diff --git a/docker-build/README.md b/docker-build/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4cbffa1b7eacfaffca4cb81ba6eb949d11eff48a
--- /dev/null
+++ b/docker-build/README.md
@@ -0,0 +1,2 @@
+How to build the Docker images:
+Change into the directory of the image you want to build and run build.sh
\ No newline at end of file
diff --git a/docker-build/patroni-for-openGauss/build.sh b/docker-build/patroni-for-openGauss/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..44fd481ab225373e9176d29983d0ee011a242444
--- /dev/null
+++ b/docker-build/patroni-for-openGauss/build.sh
@@ -0,0 +1 @@
+docker build -t "opengauss:1.0.6" -f dockerfile .
\ No newline at end of file
diff --git a/docker-build/patroni-for-openGauss/dockerfile b/docker-build/patroni-for-openGauss/dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..e8e45c966cad784ffa141cfb47d46b704a020d4d
--- /dev/null
+++ b/docker-build/patroni-for-openGauss/dockerfile
@@ -0,0 +1,56 @@
+FROM centos:7.6.1810
+
+# prepare openGauss
+ENV LC_ALL=en_US.utf-8
+ENV LANG en_US.utf8
+
+COPY gosu-amd64 /usr/local/bin/gosu
+COPY wal2json.so /tmp
+COPY etcd.conf.sample /home/omm/etcd.conf
+COPY patroni.yaml.sample /home/omm/patroni.yaml
+COPY psycopg2_binary-2.8.6-cp36-cp36m-manylinux1_x86_64.whl /var/lib
+ENV GOSU_VERSION 1.12
+ADD patroni-2.0.2.tar.gz /var/lib
+ADD etcd.tar.gz /usr/local/bin
+ADD openGauss-2.0.1-CentOS-64bit.tar.bz2 /usr/local/opengauss
+
+
+RUN set -eux; \
+    yum install -y \
+        python36 libaio libaio-devel libkeyutils-dev locales libreadline-dev python3-setuptools \
+        python36-pip libpq-devel python-devel python3-devel gcc bind-utils &> /dev/null && \
+    yum clean all && \
+    groupadd -g 70 omm && \
+    useradd -u 70 -g omm -d /home/omm -s /bin/bash omm; \
+    mkdir -p /var/lib/opengauss && \
+    mkdir -p /usr/local/opengauss && \
+    mkdir -p /var/run/opengauss && \
+    mkdir -p /home/omm && \
+    mkdir -p /var/log/ && \
+    mkdir /docker-entrypoint-initdb.d && \
+    chown omm:omm /var/lib/opengauss /home/omm /var/run/opengauss /docker-entrypoint-initdb.d \
+        /var/log /var/lib/patroni-2.0.2 -R && \
+    cp /tmp/wal2json.so /usr/local/opengauss && \
+    echo "export GAUSSHOME=/usr/local/opengauss" >> /home/omm/.bashrc && \
+    echo "export PATH=\$GAUSSHOME/bin:\$PATH " >> /home/omm/.bashrc && \
+    echo "export LD_LIBRARY_PATH=\$GAUSSHOME/lib:\$LD_LIBRARY_PATH" >> /home/omm/.bashrc && \
+    chown omm:omm /home/omm -R && \
+    chmod +x /usr/local/bin/gosu && \
+    cd /var/lib/ && pip3 install psycopg2_binary-2.8.6-cp36-cp36m-manylinux1_x86_64.whl && \
+    rm psycopg2_binary-2.8.6-cp36-cp36m-manylinux1_x86_64.whl -f && \
+    pip3 install six python-etcd importlib-metadata && \
+    cd /var/lib/patroni-2.0.2/ && python3 setup.py build && python3 setup.py install && \
+    rm /var/lib/patroni-2.0.2/ -rf && \
+    mv /etc/localtime /etc/localtime_bak && \
+    cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
+
+# prepare etcd and patroni
+
+ENV PGDATA /var/lib/opengauss/data
+
+COPY entrypoint.sh /usr/local/bin/
+RUN chmod 755 /usr/local/bin/entrypoint.sh;ln -s /usr/local/bin/entrypoint.sh / # backwards compat
+ENTRYPOINT ["entrypoint.sh"]
+EXPOSE 5432
+USER omm
+CMD ["patroni"]
diff --git a/docker-build/patroni-for-openGauss/entrypoint.sh b/docker-build/patroni-for-openGauss/entrypoint.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5baf17a6ef32df26d7948971b246d2d228876a52
--- /dev/null
+++ b/docker-build/patroni-for-openGauss/entrypoint.sh
@@ -0,0 +1,624 @@
+#!/usr/bin/env bash
+set -Eeo pipefail
+
+# usage: file_env VAR [DEFAULT]
+# ie: 
file_env 'XYZ_DB_PASSWORD' 'example' +# (will allow for "$XYZ_DB_PASSWORD_FILE" to fill in the value of +# "$XYZ_DB_PASSWORD" from a file, especially for Docker's secrets feature) + +export GAUSSHOME=/usr/local/opengauss +export PATH=$GAUSSHOME/bin:$PATH +export LD_LIBRARY_PATH=$GAUSSHOME/lib:$LD_LIBRARY_PATH +export LANG=en_US.UTF-8 + +file_env() { + local var="$1" + local fileVar="${var}_FILE" + local def="${2:-}" + if [ "${!var:-}" ] && [ "${!fileVar:-}" ]; then + echo >&2 "error: both $var and $fileVar are set (but are exclusive)" + exit 1 + fi + local val="$def" + if [ "${!var:-}" ]; then + val="${!var}" + elif [ "${!fileVar:-}" ]; then + val="$(< "${!fileVar}")" + fi + export "$var"="$val" + unset "$fileVar" +} + +# check to see if this file is being run or sourced from another script +_is_sourced() { + [ "${#FUNCNAME[@]}" -ge 2 ] \ + && [ "${FUNCNAME[0]}" = '_is_sourced' ] \ + && [ "${FUNCNAME[1]}" = 'source' ] +} + +# used to create initial opengauss directories and if run as root, ensure ownership belong to the omm user +docker_create_db_directories() { + local user; user="$(id -u)" + + mkdir -p "$PGDATA" + chmod 700 "$PGDATA" + + # ignore failure since it will be fine when using the image provided directory; + mkdir -p /var/run/opengauss || : + chmod 775 /var/run/opengauss || : + + # Create the transaction log directory before initdb is run so the directory is owned by the correct user + if [ -n "$POSTGRES_INITDB_XLOGDIR" ]; then + mkdir -p "$POSTGRES_INITDB_XLOGDIR" + if [ "$user" = '0' ]; then + find "$POSTGRES_INITDB_XLOGDIR" \! -user postgres -exec chown postgres '{}' + + fi + chmod 700 "$POSTGRES_INITDB_XLOGDIR" + fi + + # allow the container to be started with `--user` + if [ "$user" = '0' ]; then + find "$PGDATA" \! -user omm -exec chown omm '{}' + + find /var/run/opengauss \! -user omm -exec chown omm '{}' + + fi +} + +# initialize empty PGDATA directory with new database via 'initdb' +# arguments to `initdb` can be passed via POSTGRES_INITDB_ARGS or as arguments to this function +# `initdb` automatically creates the "postgres", "template0", and "template1" dbnames +# this is also where the database user is created, specified by `GS_USER` env +docker_init_database_dir() { + # "initdb" is particular about the current user existing in "/etc/passwd", so we use "nss_wrapper" to fake that if necessary + if ! 
getent passwd "$(id -u)" &> /dev/null && [ -e /usr/lib/libnss_wrapper.so ]; then + export LD_PRELOAD='/usr/lib/libnss_wrapper.so' + export NSS_WRAPPER_PASSWD="$(mktemp)" + export NSS_WRAPPER_GROUP="$(mktemp)" + echo "postgres:x:$(id -u):$(id -g):PostgreSQL:$PGDATA:/bin/false" > "$NSS_WRAPPER_PASSWD" + echo "postgres:x:$(id -g):" > "$NSS_WRAPPER_GROUP" + fi + + if [ -n "$POSTGRES_INITDB_XLOGDIR" ]; then + set -- --xlogdir "$POSTGRES_INITDB_XLOGDIR" "$@" + fi + + gs_initdb -w "$GS_PASSWORD" --nodename=opengauss --encoding=UTF-8 --locale=en_US.UTF-8 --dbcompatibility=PG -D $PGDATA + # unset/cleanup "nss_wrapper" bits + if [ "${LD_PRELOAD:-}" = '/usr/lib/libnss_wrapper.so' ]; then + rm -f "$NSS_WRAPPER_PASSWD" "$NSS_WRAPPER_GROUP" + unset LD_PRELOAD NSS_WRAPPER_PASSWD NSS_WRAPPER_GROUP + fi +} + +# print large warning if GS_PASSWORD is long +# error if both GS_PASSWORD is empty and GS_HOST_AUTH_METHOD is not 'trust' +# print large warning if GS_HOST_AUTH_METHOD is set to 'trust' +# assumes database is not set up, ie: [ -z "$DATABASE_ALREADY_EXISTS" ] +docker_verify_minimum_env() { + # check password first so we can output the warning before postgres + # messes it up + if [[ "$GS_PASSWORD" =~ ^(.{8,}).*$ ]] && [[ "$GS_PASSWORD" =~ ^(.*[a-z]+).*$ ]] && [[ "$GS_PASSWORD" =~ ^(.*[A-Z]).*$ ]] && [[ "$GS_PASSWORD" =~ ^(.*[0-9]).*$ ]] && [[ "$GS_PASSWORD" =~ ^(.*[#?!@$%^&*-]).*$ ]]; then + cat >&2 <<-'EOWARN' + + Message: The supplied GS_PASSWORD is meet requirements. + +EOWARN + else + cat >&2 <<-'EOWARN' + + Error: The supplied GS_PASSWORD is not meet requirements. + Please Check if the password contains uppercase, lowercase, numbers, special characters, and password length(8). + At least one uppercase, lowercase, numeric, special character. + Example: Enmo@123 +EOWARN + exit 1 + fi + if [ -z "$GS_PASSWORD" ] && [ 'trust' != "$GS_HOST_AUTH_METHOD" ]; then + # The - option suppresses leading tabs but *not* spaces. :) + cat >&2 <<-'EOE' + Error: Database is uninitialized and superuser password is not specified. + You must specify GS_PASSWORD to a non-empty value for the + superuser. For example, "-e GS_PASSWORD=password" on "docker run". + + You may also use "GS_HOST_AUTH_METHOD=trust" to allow all + connections without a password. This is *not* recommended. + +EOE + exit 1 + fi + if [ 'trust' = "$GS_HOST_AUTH_METHOD" ]; then + cat >&2 <<-'EOWARN' + ******************************************************************************** + WARNING: GS_HOST_AUTH_METHOD has been set to "trust". This will allow + anyone with access to the opengauss port to access your database without + a password, even if GS_PASSWORD is set. + It is not recommended to use GS_HOST_AUTH_METHOD=trust. Replace + it with "-e GS_PASSWORD=password" instead to set a password in + "docker run". + ******************************************************************************** +EOWARN + fi +} + +# usage: docker_process_init_files [file [file [...]]] +# ie: docker_process_init_files /always-initdb.d/* +# process initializer files, based on file extensions and permissions +docker_process_init_files() { + # gsql here for backwards compatiblilty "${gsql[@]}" + gsql=( docker_process_sql ) + + echo + local f + for f; do + case "$f" in + *.sh) + if [ -x "$f" ]; then + echo "$0: running $f" + "$f" + else + echo "$0: sourcing $f" + . 
"$f" + fi + ;; + *.sql) echo "$0: running $f"; docker_process_sql -f "$f"; echo ;; + *.sql.gz) echo "$0: running $f"; gunzip -c "$f" | docker_process_sql; echo ;; + *.sql.xz) echo "$0: running $f"; xzcat "$f" | docker_process_sql; echo ;; + *) echo "$0: ignoring $f" ;; + esac + echo + done +} + +# Execute sql script, passed via stdin (or -f flag of pqsl) +# usage: docker_process_sql [gsql-cli-args] +# ie: docker_process_sql --dbname=mydb <<<'INSERT ...' +# ie: docker_process_sql -f my-file.sql +# ie: docker_process_sql > $PGDATA/pg_hba.conf +} + +# append parameter to postgres.conf for connections +opengauss_setup_postgresql_conf() { + if [ -n "$PORT" ]; then + gs_guc set -D $PGDATA -c "port=$PORT" + else + gs_guc set -D $PGDATA -c "PORT=5432" + fi + gs_guc set -D $PGDATA -c "password_encryption_type = 1" \ + -c "wal_level=logical" \ + -c "max_wal_senders=16" \ + -c "max_replication_slots=9" \ + -c "wal_sender_timeout=0s" \ + -c "wal_receiver_timeout=0s" + + if [ -n "$SERVER_MODE" ]; then + gs_guc set -D $PGDATA -c "listen_addresses = '${HOST_IP}'" \ + -c "most_available_sync = on" \ + -c "remote_read_mode = off" \ + -c "pgxc_node_name = '$HOST_NAME'" \ + -c "application_name = '$HOST_NAME'" + set_REPLCONNINFO + if [ -n "$SYNCHRONOUS_STANDBY_NAMES" ]; then + gs_guc set -D $PGDATA -c "synchronous_standby_names=$SYNCHRONOUS_STANDBY_NAMES" + fi + else + gs_guc set -D $PGDATA -c "listen_addresses = '*'" + fi + + if [ -n "$db_config" ]; then + OLD_IFS="$IFS" + IFS="#" + db_config=($db_config) + for s in ${db_config[@]}; do + gs_guc set -D $PGDATA -c "$s" + done + IFS="$OLD_IFS" + fi + if [ -f "/tmp/db_config.conf" ]; then + cat /tmp/db_config.conf >> "$PGDATA/postgresql.conf" + fi +} + +opengauss_setup_mot_conf() { + echo "enable_numa = false" >> "$PGDATA/mot.conf" +} + +# start socket-only postgresql server for setting up or running scripts +# all arguments will be passed along as arguments to `postgres` (via pg_ctl) +docker_temp_server_start() { + if [ "$1" = 'gaussdb' ]; then + shift + fi + + # internal start of server in order to allow setup using gsql client + # does not listen on external TCP/IP and waits until start finishes + set -- "$@" -c listen_addresses='' -p "${PORT:-5432}" + + PGUSER="${PGUSER:-$GS_USER}" \ + gs_ctl -D "$PGDATA" \ + -o "$(printf '%q ' "$@")" \ + -w start +} + +# stop postgresql server after done setting up user and running scripts +docker_temp_server_stop() { + PGUSER="${PGUSER:-postgres}" \ + gs_ctl -D "$PGDATA" -m fast -w stop +} + +docker_slave_full_backup() { + echo "rebuild standby" + set +e + while : + do + gs_ctl restart -D "$PGDATA" -M $SERVER_MODE + gs_ctl build -D "$PGDATA" -M $SERVER_MODE -b full + if [ $? -eq 0 ]; then + break + else + echo "errcode=$?" 
+ echo "build failed" + sleep 1s + fi + done + set -e +} + +_create_config_og() { + docker_setup_env + # setup data directories and permissions (when run as root) + docker_create_db_directories + if [ "$(id -u)" = '0' ]; then + # then restart script as postgres user + exec gosu omm "$BASH_SOURCE" "$@" + fi + + # only run initialization on an empty data directory + if [ -z "$DATABASE_ALREADY_EXISTS" ]; then + docker_verify_minimum_env + + # check dir permissions to reduce likelihood of half-initialized database + ls /docker-entrypoint-initdb.d/ > /dev/null + + docker_init_database_dir + opengauss_setup_hba_conf + opengauss_setup_postgresql_conf + opengauss_setup_mot_conf + + # PGPASSWORD is required for gsql when authentication is required for 'local' connections via pg_hba.conf and is otherwise harmless + # e.g. when '--auth=md5' or '--auth-local=md5' is used in POSTGRES_INITDB_ARGS + export PGPASSWORD="${PGPASSWORD:-$GS_PASSWORD}" + docker_temp_server_start "$@" + if [ -z "$SERVER_MODE" ] || [ "$SERVER_MODE" = "primary" ]; then + docker_setup_db + docker_setup_user + docker_process_init_files /docker-entrypoint-initdb.d/* + fi + + if [ -n "$SERVER_MODE" ] && [ "$SERVER_MODE" != "primary" ]; then + docker_slave_full_backup + fi + docker_temp_server_stop + unset PGPASSWORD + + echo + echo 'openGauss init process complete; ready for start up.' + echo + else + echo + echo 'openGauss Database directory appears to contain a database; Skipping initialization' + echo + fi +} + +# process PEER_IPS, PEER_HOST_NAMES +# uses environment variables for input: PEER_IPS, PEER_HOST_NAMES +process_check_PEERS () { + # process PEER_IPS and PEER_HOST_NAMES to array + PEER_IPS_ARR=(${PEER_IPS//,/ }) + PEER_HOST_NAMES_ARR=(${PEER_HOST_NAMES//,/ }) + local len_ips=${#PEER_IPS_ARR[*]} + local len_names=${#PEER_HOST_NAMES_ARR[*]} + echo "len_ips=$len_ips" + echo "len_names=$len_names" + if [ ${len_ips} -ne ${len_names} ]; then + cat >&2 <<-'EOE' + Error: PEER_IPS are not matched with PEER_HOST_NAMES! + +EOE + exit 1 + fi + if [ ${len_ips} -gt 8 ]; then + cat >&2 <<-'EOE' + Error: Opengauss support 8 standbies at most! 
+ +EOE + exit 1 + fi + set +e + for i in $(seq 0 $(($len_ips - 1))); do + while : + do + + local tempip=`host ${PEER_IPS_ARR[$i]} | grep -Eo "[0-9]+.[0-9]+.[0-9]+.[0-9]+"` + if [ -n "$tempip" ]; then + PEER_IPS_ARR[$i]="$tempip" + break + else + sleep 1s + fi + done + done + set -e + PEER_NUM=$len_ips + echo "export STANDBY_NUM=$PEER_NUM" >> /home/omm/.bashrc +} + +# get etcd's parameter ETCD_MEMBERS +get_ETCD_MEMBERS () { + echo "----get_ETCD_MEMBERS-----" + ETCD_MEMBERS="${HOST_NAME}=http://${HOST_IP}:2380" + echo "ETCD_MEMBERS=$ETCD_MEMBERS" + local len=$(($PEER_NUM - 1)) + for i in $(seq 0 ${len}); do + echo "${i} ${PEER_HOST_NAMES_ARR[$i]} ${PEER_IPS_ARR[$i]}" + ETCD_MEMBERS="${ETCD_MEMBERS},${PEER_HOST_NAMES_ARR[$i]}=http://${PEER_IPS_ARR[$i]}:2380" + done + echo "ETCD_MEMBERS=$ETCD_MEMBERS" +} + +# get database's parameter replconninfoi +# uses environment variables for input: HOST_IP, PORT +get_replconninfoi () { + replconninfoi="localhost=${HOST_IP} localport=$((${PORT} + 1)) localheartbeatport=$((${PORT} + 2)) localservice=$((${PORT} + 4)) remotehost=$1 remoteport=$(($2 + 1)) remoteheartbeatport=$(($2 + 2)) remoteservice=$(($2 + 4))" +} + +# set database's parameter REPL_CONN_INFO +# uses environment variables for input: PEER_IPS +set_REPLCONNINFO () { + REPL_CONN_INFO="" + local len=$(($PEER_NUM - 1)) + for i in $(seq 0 $len); do + get_replconninfoi "${PEER_IPS_ARR[$i]}" $PORT + gs_guc set -D $PGDATA -c "replconninfo$((${i} + 1)) = '${replconninfoi}'" + done +} + +# change etcd's config +# uses environment variables for input: HOST_NAME, HOST_IP, INITIAL_CLUSTER_STATE +change_etcd_config() { + get_ETCD_MEMBERS + sed -i "s/^name: 'default'/name: '${HOST_NAME}'/" /home/omm/etcd.conf && \ + sed -i "s/^listen-peer-urls: http:\/\/localhost:2380/listen-peer-urls: http:\/\/${HOST_IP}:2380/" /home/omm/etcd.conf && \ + sed -i "s/^initial-advertise-peer-urls: http:\/\/localhost:2380/initial-advertise-peer-urls: http:\/\/${HOST_IP}:2380/" /home/omm/etcd.conf && \ + sed -i "s/^advertise-client-urls: http:\/\/localhost:2379/advertise-client-urls: http:\/\/${HOST_IP}:2379/" /home/omm/etcd.conf && \ + sed -i "s|^initial-cluster: initial-cluster|initial-cluster: ${ETCD_MEMBERS}|" /home/omm/etcd.conf + if [ -n "${INITIAL_CLUSTER_STATE}" ] && [ "${INITIAL_CLUSTER_STATE}" == "existing" ]; then + sed -i "s/initial-cluster-state: 'new'/initial-cluster-state: 'existing'/" /home/omm/etcd.conf + fi +} + +# get ETCD_HOSTS +get_ETCD_HOSTS () { + ETCD_HOSTS="${HOST_IP}:2379" + for i in $(seq 0 $len); do + ETCD_HOSTS="${ETCD_HOSTS},${PEER_IPS_ARR[$i]}:2379" + done +} + +# change patroni's config +# uses environment variables for input: HOST_NAME, HOST_IP, PORT, GS_PASSWORD, GS_PASSWORD +change_patroni_config() { + get_ETCD_HOSTS + sed -i "s/^name: name/name: ${HOST_NAME}/" /home/omm/patroni.yaml && \ + sed -i "s/^ listen: localhost:8008/ listen: ${HOST_IP}:8008/" /home/omm/patroni.yaml && \ + sed -i "s/^ connect_address: localhost:8008/ connect_address: ${HOST_IP}:8008/" /home/omm/patroni.yaml && \ + sed -i "s/^ host: localhost:2379/ hosts: ${ETCD_HOSTS}/" /home/omm/patroni.yaml && \ + sed -i "s/^ listen: localhost:16000/ listen: ${HOST_IP}:${PORT}/" /home/omm/patroni.yaml && \ + sed -i "s/^ connect_address: localhost:16000/ connect_address: ${HOST_IP}:${PORT}/" /home/omm/patroni.yaml + if [ -n "$GS_USERNAME" ] && [ "$GS_USERNAME" != "admin" ]; then + sed -i "s/^ username: admin/ username: $GS_USERNAME/" /home/omm/patroni.yaml + fi + sed -i "s/^ password: huawei_123/ password: $GS_PASSWORD/" 
/home/omm/patroni.yaml
+}
+
+# add new members
+# uses environment variables for input: NEW_MEMBER_IPS, NEW_MEMBER_NAMES
+add_standby () {
+    source /home/omm/.bashrc
+    echo "STANDBY_NUM=$STANDBY_NUM"
+    if [ $STANDBY_NUM -gt 8 ]; then
+        cat >&2 <<-'EOE'
+            Error: openGauss supports at most 8 standbys, and there are already 8 standbys now!
+
+EOE
+        exit 1
+    fi
+    echo "NEW_MEMBER_IPS=$NEW_MEMBER_IPS"
+    echo "NEW_MEMBER_NAMES=$NEW_MEMBER_NAMES"
+    NEW_MEMBER_IPS_ARR=(${NEW_MEMBER_IPS//,/ })
+    NEW_MEMBER_NAMES_ARR=(${NEW_MEMBER_NAMES//,/ })
+    echo "NEW_MEMBER_IPS_ARR=${NEW_MEMBER_IPS_ARR[*]}"
+    echo "NEW_MEMBER_NAMES_ARR=${NEW_MEMBER_NAMES_ARR[*]}"
+    local len_ips=${#NEW_MEMBER_IPS_ARR[*]}
+    local len_names=${#NEW_MEMBER_NAMES_ARR[*]}
+    echo "len_ips=$len_ips"
+    echo "len_names=$len_names"
+    if [ $len_ips -ne $len_names ]; then
+        cat >&2 <<-'EOE'
+            Error: NEW_MEMBER_IPS do not match NEW_MEMBER_NAMES!
+
+EOE
+        exit 1
+    fi
+    if [ $len_ips -eq 0 ]; then
+        cat >&2 <<-'EOE'
+            Error: No new members!
+
+EOE
+        exit 1
+    fi
+    if [ $(($STANDBY_NUM + len_ips)) -gt 8 ]; then
+        cat >&2 <<-EOE
+            Error: The cluster already has $STANDBY_NUM standbys, so $len_ips more standbys can't be added!
+
+EOE
+        exit 1
+    fi
+    local len=$(($len_ips - 1))
+    local member_list=`etcdctl member list`
+    echo -e "member_list=$member_list"
+    for i in $(seq 0 $len); do
+        if [[ $member_list =~ " started, ${NEW_MEMBER_NAMES_ARR[$i]}" ]]; then
+            echo "${NEW_MEMBER_NAMES_ARR[$i]} is already in the cluster."
+        else
+            while :
+            do
+                host ${NEW_MEMBER_IPS_ARR[$i]} && echo "" > /dev/null
+                if [ $? -eq 0 ]; then
+                    NEW_MEMBER_IPS_ARR[$i]=`host ${NEW_MEMBER_IPS_ARR[$i]} | grep -Eo "[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+"`
+                    echo "NEW_MEMBER_IPS: $i ${NEW_MEMBER_IPS_ARR[$i]}"
+                    break
+                fi
+            done
+            if [[ $member_list == *unstarted*${NEW_MEMBER_IPS_ARR[$i]}* ]]; then
+                echo "${NEW_MEMBER_NAMES_ARR[$i]} is already in the etcd cluster."
+            else
+                etcdctl member add ${NEW_MEMBER_NAMES_ARR[$i]} --peer-urls="http://${NEW_MEMBER_IPS_ARR[$i]}:2380"
+            fi
+            get_replconninfoi "${NEW_MEMBER_IPS_ARR[$i]}" $PORT
+            gs_guc reload -D $PGDATA -c "replconninfo$((${STANDBY_NUM} + ${i} + 1 ))='${replconninfoi}'"
+        fi
+    done
+    sed -i "s|STANDBY_NUM=${STANDBY_NUM}|STANDBY_NUM=$(($STANDBY_NUM + len_ips))|" /home/omm/.bashrc
+    echo "etcd and the database are ready for the new member to join. Please start the new member."
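+    # Illustrative note (not part of the original script): one assumed way to reach this code path
+    # is to exec into an already running container, for example:
+    #   docker exec -e NEW_MEMBER_IPS=10.0.0.4 -e NEW_MEMBER_NAMES=node4 <container> entrypoint.sh add_standby
+    # Here <container>, the IP and the node name are placeholders; the exact wiring depends on how
+    # the container and its environment were started.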
+} + +_main() { + if [ "$(id -u)" = '0' ]; then + id + # then restart script as postgres user + if [ -d "/var/lib/opengauss/data/" ]; then + chown omm:omm /var/lib/opengauss/data/ -R + fi + exec gosu omm "$BASH_SOURCE" "$@" + elif [ $# = 1 ] && [ "$1" = "patroni" ]; then + process_check_PEERS + # change etcd config file + echo "-------------------------change etcd config-------------------------" + change_etcd_config + # start etcd + echo "-------------------------start etcd-------------------------" + etcd --config-file /home/omm/etcd.conf > /var/log/etcd.log 2>&1 & + echo "-------------------------etcd.log-------------------------" + sleep 1s + cat /var/log/etcd.log + + # create and config database + echo "-------------------------prepare start-------------------------" + echo "-------------------------create and config opengauss-------------------------" + if [ "`ls -A $PGDATA`" = "" ]; then + _create_config_og + else + echo "database directory has already been exist" + cp "$PGDATA/pg_hba0.conf" "$PGDATA/pg_hba.conf" -f + opengauss_setup_hba_conf + opengauss_setup_postgresql_conf + if [ -n "$SERVER_MODE" ] && [ "$SERVER_MODE" != "primary" ]; then + docker_slave_full_backup + docker_temp_server_stop + fi + fi + + # change patroni config file + change_patroni_config + # start patroni + source /home/omm/.bashrc + exec patroni /home/omm/patroni.yaml 2>&1 | tee /var/log/patroni.log + elif [ "$1" = "list" ] || [ "$1" = "switchover" ] || [ "$1" = "failover" ]; then + patronictl -c /home/omm/patroni.yaml $1 + elif [ "$1" = "add_standby" ]; then + add_standby "$@" + else + exec "$@" + fi +} + +if ! _is_sourced; then + _main "$@" +fi diff --git a/docker-build/patroni-for-openGauss/etcd.conf.sample b/docker-build/patroni-for-openGauss/etcd.conf.sample new file mode 100644 index 0000000000000000000000000000000000000000..cbf3bce71d07b66078039a769c29043f2a32022a --- /dev/null +++ b/docker-build/patroni-for-openGauss/etcd.conf.sample @@ -0,0 +1,140 @@ +# This is the configuration file for the etcd server. + +# Human-readable name for this member. +name: 'default' + +# Path to the data directory. +data-dir: /tmp/etcd.data + +# Path to the dedicated wal directory. +wal-dir: /tmp/etcd.wal + +# Number of committed transactions to trigger a snapshot to disk. +snapshot-count: 10000 + +# Time (in milliseconds) of a heartbeat interval. +heartbeat-interval: 100 + +# Time (in milliseconds) for an election to timeout. +election-timeout: 1000 + +# Raise alarms when backend size exceeds the given quota. 0 means use the +# default quota. +quota-backend-bytes: 0 + +# List of comma separated URLs to listen on for peer traffic. +listen-peer-urls: http://localhost:2380 + +# List of comma separated URLs to listen on for client traffic. +listen-client-urls: http://0.0.0.0:2379 + +# Maximum number of snapshot files to retain (0 is unlimited). +max-snapshots: 5 + +# Maximum number of wal files to retain (0 is unlimited). +max-wals: 5 + +# Comma-separated white list of origins for CORS (cross-origin resource sharing). +cors: + +# List of this member's peer URLs to advertise to the rest of the cluster. +# The URLs needed to be a comma-separated list. +initial-advertise-peer-urls: http://localhost:2380 + +# List of this member's client URLs to advertise to the public. +# The URLs needed to be a comma-separated list. +advertise-client-urls: http://localhost:2379 + +# Discovery URL used to bootstrap the cluster. 
+discovery: + +# Valid values include 'exit', 'proxy' +discovery-fallback: 'proxy' + +# HTTP proxy to use for traffic to discovery service. +discovery-proxy: + +# DNS domain used to bootstrap initial cluster. +discovery-srv: + +# Initial cluster configuration for bootstrapping. +initial-cluster: initial-cluster + +# Initial cluster token for the etcd cluster during bootstrap. +initial-cluster-token: 'etcd-cluster' + +# Initial cluster state ('new' or 'existing'). +initial-cluster-state: 'new' + +# Reject reconfiguration requests that would cause quorum loss. +strict-reconfig-check: false + +# Accept etcd V2 client requests +enable-v2: true + +# Enable runtime profiling data via HTTP server +enable-pprof: true + +# Valid values include 'on', 'readonly', 'off' +proxy: 'off' + +# Time (in milliseconds) an endpoint will be held in a failed state. +proxy-failure-wait: 5000 + +# Time (in milliseconds) of the endpoints refresh interval. +proxy-refresh-interval: 30000 + +# Time (in milliseconds) for a dial to timeout. +proxy-dial-timeout: 1000 + +# Time (in milliseconds) for a write to timeout. +proxy-write-timeout: 5000 + +# Time (in milliseconds) for a read to timeout. +proxy-read-timeout: 0 + +client-transport-security: + # Path to the client server TLS cert file. + cert-file: + + # Path to the client server TLS key file. + key-file: + + # Enable client cert authentication. + client-cert-auth: false + + # Path to the client server TLS trusted CA cert file. + trusted-ca-file: + + # Client TLS using generated certificates + auto-tls: false + +peer-transport-security: + # Path to the peer server TLS cert file. + cert-file: + + # Path to the peer server TLS key file. + key-file: + + # Enable peer client cert authentication. + client-cert-auth: false + + # Path to the peer server TLS trusted CA cert file. + trusted-ca-file: + + # Peer TLS using generated certificates. + auto-tls: false + +# Enable debug-level logging for etcd. +log-level: debug + +logger: zap + +# Specify 'stdout' or 'stderr' to skip journald logging even when running under systemd. +log-outputs: [stderr] + +# Force to create a new one member cluster. 
+force-new-cluster: false + +auto-compaction-mode: periodic +auto-compaction-retention: "1" diff --git a/docker-build/patroni-for-openGauss/gosu-amd64 b/docker-build/patroni-for-openGauss/gosu-amd64 new file mode 100644 index 0000000000000000000000000000000000000000..834951f8910b8f4f88d626a05c9c0cdce527f967 Binary files /dev/null and b/docker-build/patroni-for-openGauss/gosu-amd64 differ diff --git a/docker-build/patroni-for-openGauss/gosu-arm64 b/docker-build/patroni-for-openGauss/gosu-arm64 new file mode 100644 index 0000000000000000000000000000000000000000..925e16ed1ff020e522c815ae4b6b8349634ef30a Binary files /dev/null and b/docker-build/patroni-for-openGauss/gosu-arm64 differ diff --git a/docker-build/patroni-for-openGauss/openGauss-2.0.1-CentOS-64bit.tar.bz2 b/docker-build/patroni-for-openGauss/openGauss-2.0.1-CentOS-64bit.tar.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..e55a4d3a640faa2b147e0f2407151343dd3bfb98 Binary files /dev/null and b/docker-build/patroni-for-openGauss/openGauss-2.0.1-CentOS-64bit.tar.bz2 differ diff --git a/docker-build/patroni-for-openGauss/patroni-2.0.2.tar.gz b/docker-build/patroni-for-openGauss/patroni-2.0.2.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..7bc9e041cc4769ddc9a984903a57850d88c5bf42 Binary files /dev/null and b/docker-build/patroni-for-openGauss/patroni-2.0.2.tar.gz differ diff --git a/docker-build/patroni-for-openGauss/patroni.yaml.sample b/docker-build/patroni-for-openGauss/patroni.yaml.sample new file mode 100644 index 0000000000000000000000000000000000000000..6c4c5a0302dbd0ded8d148360c9d0b6ca5a4ef49 --- /dev/null +++ b/docker-build/patroni-for-openGauss/patroni.yaml.sample @@ -0,0 +1,66 @@ +scope: opengauss +namespace: /service +name: name + +restapi: + listen: localhost:8008 + connect_address: localhost:8008 + +etcd: + host: localhost:2379 + +bootstrap: + dcs: + ttl: 30 + loop_wait: 10 + retry_timeout: 10 + maximum_lag_on_failover: 1048576 + master_start_timeout: 300 + synchronous_mode: false + postgresql: + use_pg_rewind: true + use_slots: true + parameters: + wal_level: hotstandby + hot_standby: "on" + wal_keep_segments: 16 + max_wal_sender: 10 + max_replication_slots: 10 + wal_log_hints: "on" + + + + initdb: + - encoding: UTF8 + - data-checksums + + + + +postgresql: + listen: localhost:16000 + connect_address: localhost:16000 + data_dir: /var/lib/opengauss/data + bin_dir: /usr/local/opengauss/bin + config_dir: /var/lib/opengauss/data + custom_conf: /var/lib/opengauss/data/postgresql.conf + + authentication: + replication: + username: admin + password: huawei_123 + superuser: + username: admin + password: huawei_123 + rewind: + username: admin + password: huawei_123 + + + + +tags: + nofailover: false + noloadbalance: false + clonefrom: false + nosync: false diff --git a/docker-build/patroni-for-openGauss/psycopg2-2.8.6.tar.gz b/docker-build/patroni-for-openGauss/psycopg2-2.8.6.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..2d586450fade2413e10a39e4d2936ca8f7c87fa6 Binary files /dev/null and b/docker-build/patroni-for-openGauss/psycopg2-2.8.6.tar.gz differ diff --git a/docker-build/patroni-for-openGauss/psycopg2_binary-2.8.6-cp36-cp36m-manylinux1_x86_64.whl b/docker-build/patroni-for-openGauss/psycopg2_binary-2.8.6-cp36-cp36m-manylinux1_x86_64.whl new file mode 100644 index 0000000000000000000000000000000000000000..4009ef36f37b6004954e6c61e17b0af4576519c0 Binary files /dev/null and 
b/docker-build/patroni-for-openGauss/psycopg2_binary-2.8.6-cp36-cp36m-manylinux1_x86_64.whl differ diff --git a/docker-build/patroni-for-openGauss/wal2json.so b/docker-build/patroni-for-openGauss/wal2json.so new file mode 100644 index 0000000000000000000000000000000000000000..fc1effe31f424172f7fbab803b9e879a11cfe2ac Binary files /dev/null and b/docker-build/patroni-for-openGauss/wal2json.so differ diff --git a/docker-build/shardingsphere/apache-zookeeper-3.7.0-bin.tar.gz b/docker-build/shardingsphere/apache-zookeeper-3.7.0-bin.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..0d778f702ca74e0ebd458fbce01bf1e984437979 Binary files /dev/null and b/docker-build/shardingsphere/apache-zookeeper-3.7.0-bin.tar.gz differ diff --git a/docker-build/shardingsphere/build.sh b/docker-build/shardingsphere/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..05f2ff3636f6b159b356e29198f8975c334dc675 --- /dev/null +++ b/docker-build/shardingsphere/build.sh @@ -0,0 +1 @@ +docker build -t "shardingsphere:1.0.4" -f dockerfile . \ No newline at end of file diff --git a/docker-build/shardingsphere/dockerfile b/docker-build/shardingsphere/dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..27a98e5918d92a8f2e229efa12be872262b5e2d1 --- /dev/null +++ b/docker-build/shardingsphere/dockerfile @@ -0,0 +1,21 @@ +FROM java:8 + +ADD apache-zookeeper-3.7.0-bin.tar.gz /usr/ +ADD shardingsphere-scaling.tar.gz /usr/ +ADD shardingsphere-proxy.tar.gz /usr/ +COPY entrypoint.sh / + +ARG APP_NAME +ENV PROXY_PATH /usr/shardingsphere-proxy +ENV SCALING_PATH /usr/shardingsphere-scaling +ENV ZOOKEEPER_PATH /usr/apache-zookeeper-3.7.0-bin + +RUN set -eux; \ + mkdir -p ${PROXY_PATH}/ext-lib && \ + mkdir -p ${SCALING_PATH}/ext-lib && \ + mv ${ZOOKEEPER_PATH}/conf/zoo_sample.cfg ${ZOOKEEPER_PATH}/conf/zoo.cfg && \ + chmod 755 /entrypoint.sh && \ + echo "export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64" >> /etc/profile + +EXPOSE 8888 +ENTRYPOINT ["/entrypoint.sh"] diff --git a/docker-build/shardingsphere/entrypoint.sh b/docker-build/shardingsphere/entrypoint.sh new file mode 100644 index 0000000000000000000000000000000000000000..9de06cd1819869305164dea0b9fc986de9f681f5 --- /dev/null +++ b/docker-build/shardingsphere/entrypoint.sh @@ -0,0 +1,38 @@ +#!/bin/bash -e +# config +#!/bin/bash +set -Eeo pipefail +if [ ! -f "/tmp/config-sharding.yaml" ]; then + cat >&2 <<-'EOE' + Error: Config file dose not exist! 
+ +EOE + exit 1 +fi +cp /tmp/config-sharding.yaml ${PROXY_PATH}/conf/config-sharding.yaml -f +if [ -f "/tmp/server.yaml" ]; then + cp /tmp/server.yaml ${PROXY_PATH}/conf/server.yaml -f +fi +if [ -f "/tmp/config-database-discovery.yaml" ]; then + cp /tmp/config-database-discovery.yaml ${PROXY_PATH}/conf/config-database-discovery.yaml -f +fi +if [ -f "/tmp/config-encrypt.yaml" ]; then + cp /tmp/config-encrypt.yaml ${PROXY_PATH}/conf/config-encrypt.yaml -f +fi +if [ -f "/tmp/config-readwrite-splitting.yaml" ]; then + cp /tmp/config-readwrite-splitting.yaml ${PROXY_PATH}/conf/config-readwrite-splitti.yaml -f +fi +if [ -f "/tmp/config-shadow.yaml" ]; then + cp /tmp/config-shadow.yaml ${PROXY_PATH}/conf/config-shadow.yaml -f +fi +if [ -f "/tmp/logback.xml" ]; then + cp /tmp/logback.xml ${PROXY_PATH}/conf/logback.xml -f +fi +if [ -f "/tmp/scaling_server.yaml" ]; then + cp /tmp/scaling_server.xml ${SCALING_PATH}/conf/server.xml -f +fi + +nohup ${ZOOKEEPER_PATH}/bin/zkServer.sh start & +sleep 3 +nohup ${SCALING_PATH}/bin/start.sh server & +${PROXY_PATH}/bin/start.sh && tail -f ${PROXY_PATH}/logs/stdout.log diff --git a/docker-build/shardingsphere/shardingsphere-proxy.tar.gz b/docker-build/shardingsphere/shardingsphere-proxy.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..ccdebca7b9276359f21e6c8dcd9c1ac82f433885 Binary files /dev/null and b/docker-build/shardingsphere/shardingsphere-proxy.tar.gz differ diff --git a/docker-build/shardingsphere/shardingsphere-scaling.tar.gz b/docker-build/shardingsphere/shardingsphere-scaling.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..209d47897fac345d44ceedcfcfb68780d27c2415 Binary files /dev/null and b/docker-build/shardingsphere/shardingsphere-scaling.tar.gz differ diff --git a/patroni-for-openGauss/__init__.py b/patroni-for-openGauss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d3463a56281f7ceaf6cccb95d9c91756877439a --- /dev/null +++ b/patroni-for-openGauss/__init__.py @@ -0,0 +1,205 @@ +import logging +import os +import signal +import sys +import time + +from .daemon import AbstractPatroniDaemon, abstract_main +from .version import __version__ + +logger = logging.getLogger(__name__) + +PATRONI_ENV_PREFIX = 'PATRONI_' +KUBERNETES_ENV_PREFIX = 'KUBERNETES_' + + +class Patroni(AbstractPatroniDaemon): + + def __init__(self, config): + from patroni.api import RestApiServer + from patroni.dcs import get_dcs + from patroni.ha import Ha + from patroni.postgresql import Postgresql + from patroni.request import PatroniRequest + from patroni.watchdog import Watchdog + + super(Patroni, self).__init__(config) + + self.version = __version__ + self.dcs = get_dcs(self.config) + self.watchdog = Watchdog(self.config) + self.load_dynamic_configuration() + + self.postgresql = Postgresql(self.config['postgresql']) + self.api = RestApiServer(self, self.config['restapi']) + self.request = PatroniRequest(self.config, True) + self.ha = Ha(self) + + self.tags = self.get_tags() + self.next_run = time.time() + self.scheduled_restart = {} + + def load_dynamic_configuration(self): + from patroni.exceptions import DCSError + while True: + try: + cluster = self.dcs.get_cluster() + if cluster and cluster.config and cluster.config.data: + if self.config.set_dynamic_configuration(cluster.config): + self.dcs.reload_config(self.config) + self.watchdog.reload_config(self.config) + elif not self.config.dynamic_configuration and 'bootstrap' in self.config: + if 
self.config.set_dynamic_configuration(self.config['bootstrap']['dcs']): + self.dcs.reload_config(self.config) + break + except DCSError: + logger.warning('Can not get cluster from dcs') + time.sleep(5) + + def get_tags(self): + return {tag: value for tag, value in self.config.get('tags', {}).items() + if tag not in ('clonefrom', 'nofailover', 'noloadbalance', 'nosync') or value} + + @property + def nofailover(self): + return bool(self.tags.get('nofailover', False)) + + @property + def nosync(self): + return bool(self.tags.get('nosync', False)) + + def reload_config(self, sighup=False, local=False): + try: + super(Patroni, self).reload_config(sighup, local) + if local: + self.tags = self.get_tags() + self.request.reload_config(self.config) + self.api.reload_config(self.config['restapi']) + self.watchdog.reload_config(self.config) + self.postgresql.reload_config(self.config['postgresql'], sighup) + self.dcs.reload_config(self.config) + except Exception: + logger.exception('Failed to reload config_file=%s', self.config.config_file) + + @property + def replicatefrom(self): + return self.tags.get('replicatefrom') + + @property + def noloadbalance(self): + return bool(self.tags.get('noloadbalance', False)) + + def schedule_next_run(self): + self.next_run += self.dcs.loop_wait + current_time = time.time() + nap_time = self.next_run - current_time + if nap_time <= 0: + self.next_run = current_time + # Release the GIL so we don't starve anyone waiting on async_executor lock + time.sleep(0.001) + # Warn user that Patroni is not keeping up + logger.warning("Loop time exceeded, rescheduling immediately.") + elif self.ha.watch(nap_time): + self.next_run = time.time() + + def run(self): + self.api.start() + self.next_run = time.time() + super(Patroni, self).run() + + def _run_cycle(self): + logger.info(self.ha.run_cycle()) + + if self.dcs.cluster and self.dcs.cluster.config and self.dcs.cluster.config.data \ + and self.config.set_dynamic_configuration(self.dcs.cluster.config): + self.reload_config() + + if self.postgresql.role != 'uninitialized': + self.config.save_cache() + + self.schedule_next_run() + + def _shutdown(self): + try: + self.api.shutdown() + except Exception: + logger.exception('Exception during RestApi.shutdown') + try: + self.ha.shutdown() + except Exception: + logger.exception('Exception during Ha.shutdown') + + +def patroni_main(): + from multiprocessing import freeze_support + from patroni.validator import schema + + freeze_support() + abstract_main(Patroni, schema) + + +def fatal(string, *args): + sys.stderr.write('FATAL: ' + string.format(*args) + '\n') + sys.exit(1) + + +def check_psycopg2(): + min_psycopg2 = (2, 5, 4) + min_psycopg2_str = '.'.join(map(str, min_psycopg2)) + + def parse_version(version): + for e in version.split('.'): + try: + yield int(e) + except ValueError: + break + + try: + import psycopg2 + version_str = psycopg2.__version__.split(' ')[0] + version = tuple(parse_version(version_str)) + if version < min_psycopg2: + fatal('Patroni requires psycopg2>={0}, but only {1} is available', min_psycopg2_str, version_str) + except ImportError: + fatal('Patroni requires psycopg2>={0} or psycopg2-binary', min_psycopg2_str) + + +def main(): + if os.getpid() != 1: + check_psycopg2() + return patroni_main() + + # Patroni started with PID=1, it looks like we are in the container + pid = 0 + + # Looks like we are in a docker, so we will act like init + def sigchld_handler(signo, stack_frame): + try: + while True: + ret = os.waitpid(-1, os.WNOHANG) + if ret == (0, 0): + break + 
elif ret[0] != pid: + logger.info('Reaped pid=%s, exit status=%s', *ret) + except OSError: + pass + + def passtochild(signo, stack_frame): + if pid: + os.kill(pid, signo) + + if os.name != 'nt': + signal.signal(signal.SIGCHLD, sigchld_handler) + signal.signal(signal.SIGHUP, passtochild) + signal.signal(signal.SIGQUIT, passtochild) + signal.signal(signal.SIGUSR1, passtochild) + signal.signal(signal.SIGUSR2, passtochild) + signal.signal(signal.SIGINT, passtochild) + signal.signal(signal.SIGABRT, passtochild) + signal.signal(signal.SIGTERM, passtochild) + + import multiprocessing + patroni = multiprocessing.Process(target=patroni_main) + patroni.start() + pid = patroni.pid + patroni.join() diff --git a/patroni-for-openGauss/__main__.py b/patroni-for-openGauss/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..3abcbfc37e36145fc1c3134c93076c950f52eaa1 --- /dev/null +++ b/patroni-for-openGauss/__main__.py @@ -0,0 +1,5 @@ +from patroni import main + + +if __name__ == '__main__': + main() diff --git a/patroni-for-openGauss/api.py b/patroni-for-openGauss/api.py new file mode 100644 index 0000000000000000000000000000000000000000..c738a2cc847a9161be7584e7b634091c7b435925 --- /dev/null +++ b/patroni-for-openGauss/api.py @@ -0,0 +1,702 @@ +import base64 +import hmac +import json +import logging +import psycopg2 +import time +import traceback +import dateutil.parser +import datetime +import os +import six +import socket +import sys + +from six.moves.BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer +from six.moves.socketserver import ThreadingMixIn +from six.moves.urllib_parse import urlparse, parse_qs +from threading import Thread + +from .exceptions import PostgresConnectionException, PostgresException +from .postgresql.misc import postgres_version_to_int +from .utils import deep_compare, enable_keepalive, parse_bool, patch_config, Retry, \ + RetryFailedError, parse_int, split_host_port, tzutc, uri, cluster_as_json + +logger = logging.getLogger(__name__) + + +class RestApiHandler(BaseHTTPRequestHandler): + + def _write_status_code_only(self, status_code): + message = self.responses[status_code][0] + self.wfile.write('{0} {1} {2}\r\n\r\n'.format(self.protocol_version, status_code, message).encode('utf-8')) + self.log_request(status_code) + + def _write_response(self, status_code, body, content_type='text/html', headers=None): + self.send_response(status_code) + headers = headers or {} + if content_type: + headers['Content-Type'] = content_type + for name, value in headers.items(): + self.send_header(name, value) + for name, value in self.server.http_extra_headers.items(): + self.send_header(name, value) + self.end_headers() + self.wfile.write(body.encode('utf-8')) + + def _write_json_response(self, status_code, response): + self._write_response(status_code, json.dumps(response), content_type='application/json') + + def check_auth(func): + """Decorator function to check authorization header or client certificates + + Usage example: + @check_auth + def do_PUT_foo(): + pass + """ + + def wrapper(self, *args, **kwargs): + if self.server.check_auth(self): + return func(self, *args, **kwargs) + + return wrapper + + def _write_status_response(self, status_code, response): + patroni = self.server.patroni + tags = patroni.ha.get_effective_tags() + if tags: + response['tags'] = tags + if patroni.postgresql.sysid: + response['database_system_identifier'] = patroni.postgresql.sysid + if patroni.postgresql.pending_restart: + response['pending_restart'] = True + 
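+        # expose this Patroni instance's version and configured scope in the status payload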
response['patroni'] = {'version': patroni.version, 'scope': patroni.postgresql.scope} + if patroni.scheduled_restart and isinstance(patroni.scheduled_restart, dict): + response['scheduled_restart'] = patroni.scheduled_restart.copy() + del response['scheduled_restart']['postmaster_start_time'] + response['scheduled_restart']['schedule'] = (response['scheduled_restart']['schedule']).isoformat() + if not patroni.ha.watchdog.is_healthy: + response['watchdog_failed'] = True + if patroni.ha.is_paused(): + response['pause'] = True + qsize = patroni.logger.queue_size + if qsize > patroni.logger.NORMAL_LOG_QUEUE_SIZE: + response['logger_queue_size'] = qsize + lost = patroni.logger.records_lost + if lost: + response['logger_records_lost'] = lost + self._write_json_response(status_code, response) + + def do_GET(self, write_status_code_only=False): + """Default method for processing all GET requests which can not be routed to other methods""" + + path = '/master' if self.path == '/' else self.path + response = self.get_postgresql_status() + + patroni = self.server.patroni + cluster = patroni.dcs.cluster + + leader_optime = cluster and cluster.last_leader_operation or 0 + replayed_location = response.get('xlog', {}).get('replayed_location', 0) + max_replica_lag = parse_int(self.path_query.get('lag', [sys.maxsize])[0], 'B') + if max_replica_lag is None: + max_replica_lag = sys.maxsize + is_lagging = leader_optime and leader_optime > replayed_location + max_replica_lag + + replica_status_code = 200 if not patroni.noloadbalance and not is_lagging and \ + response.get('role') == 'replica' and response.get('state') == 'running' else 503 + + if not cluster and patroni.ha.is_paused(): + primary_status_code = 200 if response.get('role') == 'master' else 503 + standby_leader_status_code = 200 if response.get('role') == 'standby_leader' else 503 + elif patroni.ha.is_leader(): + if patroni.ha.is_standby_cluster(): + primary_status_code = replica_status_code = 503 + standby_leader_status_code = 200 if response.get('role') in ('replica', 'standby_leader') else 503 + else: + primary_status_code = 200 + standby_leader_status_code = 503 + else: + primary_status_code = standby_leader_status_code = 503 + + status_code = 503 + + if 'standby_leader' in path or 'standby-leader' in path: + status_code = standby_leader_status_code + elif 'master' in path or 'leader' in path or 'primary' in path or 'read-write' in path: + status_code = primary_status_code + elif 'replica' in path: + status_code = replica_status_code + elif 'read-only' in path: + status_code = 200 if 200 in (primary_status_code, standby_leader_status_code) else replica_status_code + elif 'health' in path: + status_code = 200 if response.get('state') == 'running' else 503 + elif cluster: # dcs is available + is_synchronous = cluster.is_synchronous_mode() and cluster.sync \ + and patroni.postgresql.name in cluster.sync.members + if path in ('/sync', '/synchronous') and is_synchronous: + status_code = replica_status_code + elif path in ('/async', '/asynchronous') and not is_synchronous: + status_code = replica_status_code + + if write_status_code_only: # when haproxy sends OPTIONS request it reads only status code and nothing more + self._write_status_code_only(status_code) + else: + self._write_status_response(status_code, response) + + def do_OPTIONS(self): + self.do_GET(write_status_code_only=True) + + def do_GET_liveness(self): + self._write_status_code_only(200) + + def do_GET_readiness(self): + patroni = self.server.patroni + if patroni.ha.is_leader(): + 
status_code = 200 + elif patroni.postgresql.state == 'running': + status_code = 200 if patroni.dcs.cluster else 503 + else: + status_code = 503 + self._write_status_code_only(status_code) + + def do_GET_patroni(self): + response = self.get_postgresql_status(True) + self._write_status_response(200, response) + + def do_GET_cluster(self): + cluster = self.server.patroni.dcs.get_cluster(True) + self._write_json_response(200, cluster_as_json(cluster)) + + def do_GET_history(self): + cluster = self.server.patroni.dcs.cluster or self.server.patroni.dcs.get_cluster() + self._write_json_response(200, cluster.history and cluster.history.lines or []) + + def do_GET_config(self): + cluster = self.server.patroni.dcs.cluster or self.server.patroni.dcs.get_cluster() + if cluster.config: + self._write_json_response(200, cluster.config.data) + else: + self.send_error(502) + + def _read_json_content(self, body_is_optional=False): + if 'content-length' not in self.headers: + return self.send_error(411) if not body_is_optional else {} + try: + content_length = int(self.headers.get('content-length')) + if content_length == 0 and body_is_optional: + return {} + request = json.loads(self.rfile.read(content_length).decode('utf-8')) + if isinstance(request, dict) and (request or body_is_optional): + return request + except Exception: + logger.exception('Bad request') + self.send_error(400) + + @check_auth + def do_PATCH_config(self): + request = self._read_json_content() + if request: + cluster = self.server.patroni.dcs.get_cluster(True) + if not (cluster.config and cluster.config.modify_index): + return self.send_error(503) + data = cluster.config.data.copy() + if patch_config(data, request): + value = json.dumps(data, separators=(',', ':')) + if not self.server.patroni.dcs.set_config_value(value, cluster.config.index): + return self.send_error(409) + self.server.patroni.ha.wakeup() + self._write_json_response(200, data) + + @check_auth + def do_PUT_config(self): + request = self._read_json_content() + if request: + cluster = self.server.patroni.dcs.get_cluster() + if not deep_compare(request, cluster.config.data): + value = json.dumps(request, separators=(',', ':')) + if not self.server.patroni.dcs.set_config_value(value): + return self.send_error(502) + self._write_json_response(200, request) + + @check_auth + def do_POST_reload(self): + self.server.patroni.sighup_handler() + self._write_response(202, 'reload scheduled') + + @staticmethod + def parse_schedule(schedule, action): + """ parses the given schedule and validates at """ + error = None + scheduled_at = None + try: + scheduled_at = dateutil.parser.parse(schedule) + if scheduled_at.tzinfo is None: + error = 'Timezone information is mandatory for the scheduled {0}'.format(action) + status_code = 400 + elif scheduled_at < datetime.datetime.now(tzutc): + error = 'Cannot schedule {0} in the past'.format(action) + status_code = 422 + else: + status_code = None + except (ValueError, TypeError): + logger.exception('Invalid scheduled %s time: %s', action, schedule) + error = 'Unable to parse scheduled timestamp. It should be in an unambiguous format, e.g. 
ISO 8601' + status_code = 422 + return (status_code, error, scheduled_at) + + @check_auth + def do_POST_restart(self): + status_code = 500 + data = 'restart failed' + request = self._read_json_content(body_is_optional=True) + cluster = self.server.patroni.dcs.get_cluster() + if request is None: + # failed to parse the json + return + if request: + logger.debug("received restart request: {0}".format(request)) + + if cluster.is_paused() and 'schedule' in request: + self._write_response(status_code, "Can't schedule restart in the paused state") + return + + for k in request: + if k == 'schedule': + (_, data, request[k]) = self.parse_schedule(request[k], "restart") + if _: + status_code = _ + break + elif k == 'role': + if request[k] not in ('master', 'replica'): + status_code = 400 + data = "PostgreSQL role should be either master or replica" + break + elif k == 'postgres_version': + try: + postgres_version_to_int(request[k]) + except PostgresException as e: + status_code = 400 + data = e.value + break + elif k == 'timeout': + request[k] = parse_int(request[k], 's') + if request[k] is None or request[k] <= 0: + status_code = 400 + data = "Timeout should be a positive number of seconds" + break + elif k != 'restart_pending': + status_code = 400 + data = "Unknown filter for the scheduled restart: {0}".format(k) + break + else: + if 'schedule' not in request: + try: + status, data = self.server.patroni.ha.restart(request) + status_code = 200 if status else 503 + except Exception: + logger.exception('Exception during restart') + status_code = 400 + else: + if self.server.patroni.ha.schedule_future_restart(request): + data = "Restart scheduled" + status_code = 202 + else: + data = "Another restart is already scheduled" + status_code = 409 + self._write_response(status_code, data) + + @check_auth + def do_DELETE_restart(self): + if self.server.patroni.ha.delete_future_restart(): + data = "scheduled restart deleted" + code = 200 + else: + data = "no restarts are scheduled" + code = 404 + self._write_response(code, data) + + @check_auth + def do_DELETE_switchover(self): + failover = self.server.patroni.dcs.get_cluster().failover + if failover and failover.scheduled_at: + if not self.server.patroni.dcs.manual_failover('', '', index=failover.index): + return self.send_error(409) + else: + data = "scheduled switchover deleted" + code = 200 + else: + data = "no switchover is scheduled" + code = 404 + self._write_response(code, data) + + @check_auth + def do_POST_reinitialize(self): + request = self._read_json_content(body_is_optional=True) + + if request: + logger.debug('received reinitialize request: %s', request) + + force = isinstance(request, dict) and parse_bool(request.get('force')) or False + + data = self.server.patroni.ha.reinitialize(force) + if data is None: + status_code = 200 + data = 'reinitialize started' + else: + status_code = 503 + self._write_response(status_code, data) + + def poll_failover_result(self, leader, candidate, action): + timeout = max(10, self.server.patroni.dcs.loop_wait) + for _ in range(0, timeout*2): + time.sleep(1) + try: + cluster = self.server.patroni.dcs.get_cluster() + if not cluster.is_unlocked() and cluster.leader.name != leader: + if not candidate or candidate == cluster.leader.name: + return 200, 'Successfully {0}ed over to "{1}"'.format(action[:-4], cluster.leader.name) + else: + return 200, '{0}ed over to "{1}" instead of "{2}"'.format(action[:-4].title(), + cluster.leader.name, candidate) + if not cluster.failover: + return 503, action.title() + ' failed' + 
except Exception as e: + logger.debug('Exception occured during polling %s result: %s', action, e) + return 503, action.title() + ' status unknown' + + def is_failover_possible(self, cluster, leader, candidate, action): + if leader and (not cluster.leader or cluster.leader.name != leader): + return 'leader name does not match' + if candidate: + if action == 'switchover' and cluster.is_synchronous_mode() and candidate not in cluster.sync.members: + return 'candidate name does not match with sync_standby' + members = [m for m in cluster.members if m.name == candidate] + if not members: + return 'candidate does not exists' + elif cluster.is_synchronous_mode(): + members = [m for m in cluster.members if m.name in cluster.sync.members] + if not members: + return action + ' is not possible: can not find sync_standby' + else: + members = [m for m in cluster.members if m.name != cluster.leader.name and m.api_url] + if not members: + return action + ' is not possible: cluster does not have members except leader' + for st in self.server.patroni.ha.fetch_nodes_statuses(members): + if st.failover_limitation() is None: + return None + return action + ' is not possible: no good candidates have been found' + + @check_auth + def do_POST_failover(self, action='failover'): + request = self._read_json_content() + (status_code, data) = (400, '') + if not request: + return + + leader = request.get('leader') + candidate = request.get('candidate') or request.get('member') + scheduled_at = request.get('scheduled_at') + cluster = self.server.patroni.dcs.get_cluster() + + logger.info("received %s request with leader=%s candidate=%s scheduled_at=%s", + action, leader, candidate, scheduled_at) + + if action == 'failover' and not candidate: + data = 'Failover could be performed only to a specific candidate' + elif action == 'switchover' and not leader: + data = 'Switchover could be performed only from a specific leader' + + if not data and scheduled_at: + if not leader: + data = 'Scheduled {0} is possible only from a specific leader'.format(action) + if not data and cluster.is_paused(): + data = "Can't schedule {0} in the paused state".format(action) + if not data: + (status_code, data, scheduled_at) = self.parse_schedule(scheduled_at, action) + + if not data and cluster.is_paused() and not candidate: + data = action.title() + ' is possible only to a specific candidate in a paused state' + + if not data and not scheduled_at: + data = self.is_failover_possible(cluster, leader, candidate, action) + if data: + status_code = 412 + + if not data: + if self.server.patroni.dcs.manual_failover(leader, candidate, scheduled_at=scheduled_at): + self.server.patroni.ha.wakeup() + if scheduled_at: + data = action.title() + ' scheduled' + status_code = 202 + else: + status_code, data = self.poll_failover_result(cluster.leader and cluster.leader.name, + candidate, action) + else: + data = 'failed to write {0} key into DCS'.format(action) + status_code = 503 + self._write_response(status_code, data) + + def do_POST_switchover(self): + self.do_POST_failover(action='switchover') + + def parse_request(self): + """Override parse_request method to enrich basic functionality of `BaseHTTPRequestHandler` class + + Original class can only invoke do_GET, do_POST, do_PUT, etc method implementations if they are defined. 
+ But we would like to have at least some simple routing mechanism, i.e.: + GET /uri1/part2 request should invoke `do_GET_uri1()` + POST /other should invoke `do_POST_other()` + + If the `do__` method does not exists we'll fallback to original behavior.""" + + ret = BaseHTTPRequestHandler.parse_request(self) + if ret: + urlpath = urlparse(self.path) + self.path = urlpath.path + self.path_query = parse_qs(urlpath.query) or {} + mname = self.path.lstrip('/').split('/')[0] + mname = self.command + ('_' + mname if mname else '') + if hasattr(self, 'do_' + mname): + self.command = mname + return ret + + def query(self, sql, *params, **kwargs): + if not kwargs.get('retry', False): + return self.server.query(sql, *params) + retry = Retry(delay=1, retry_exceptions=PostgresConnectionException) + return retry(self.server.query, sql, *params) + + def get_postgresql_status(self, retry=False): + postgresql = self.server.patroni.postgresql + try: + cluster = self.server.patroni.dcs.cluster + + if postgresql.state not in ('running', 'restarting', 'starting'): + raise RetryFailedError('') + stmt = ("SELECT " + postgresql.POSTMASTER_START_TIME + ", " + postgresql.TL_LSN + "," + " pg_catalog.to_char(pg_catalog.pg_last_xact_replay_timestamp(), 'YYYY-MM-DD HH24:MI:SS.MS TZ')," + " pg_catalog.array_to_json(pg_catalog.array_agg(pg_catalog.row_to_json(ri))) " + "FROM (SELECT (SELECT rolname FROM pg_authid WHERE oid = %s) AS usename," + " application_name, client_addr, w.state, sync_state, sync_priority" + " FROM pg_catalog.pg_stat_get_wal_senders() w, pg_catalog.pg_stat_get_activity(%s)) AS ri") + init_user_oid = 10 + gaussdb_pid = os.popen('ps x | grep gaussdb | grep -v grep').readlines()[0].split()[0] + + row = self.query(stmt.format(postgresql.wal_name, postgresql.lsn_name), + init_user_oid, gaussdb_pid, retry=retry)[0] + + result = { + 'state': postgresql.state, + 'postmaster_start_time': row[0], + 'role': 'replica' if row[1] == 0 else 'master', + 'server_version': postgresql.server_version, + 'cluster_unlocked': bool(not cluster or cluster.is_unlocked()), + 'xlog': ({ + 'received_location': row[4] or row[3], + 'replayed_location': row[3], + 'replayed_timestamp': row[6], + 'paused': row[5]} if row[1] == 0 else { + 'location': row[2] + }) + } + + if result['role'] == 'replica' and self.server.patroni.ha.is_standby_cluster(): + result['role'] = postgresql.role + + if row[1] > 0: + result['timeline'] = row[1] + else: + leader_timeline = None if not cluster or cluster.is_unlocked() else cluster.leader.timeline + result['timeline'] = postgresql.replica_cached_timeline(leader_timeline) + + if row[7]: + result['replication'] = row[7] + + return result + except (psycopg2.Error, RetryFailedError, PostgresConnectionException): + state = postgresql.state + if state == 'running': + logger.exception('get_postgresql_status') + state = 'unknown' + return {'state': state, 'role': postgresql.role} + + def handle_one_request(self): + self.__start_time = time.time() + BaseHTTPRequestHandler.handle_one_request(self) + + def log_message(self, fmt, *args): + latency = 1000.0 * (time.time() - self.__start_time) + logger.debug("API thread: %s - - %s latency: %0.3f ms", self.client_address[0], fmt % args, latency) + + +class RestApiServer(ThreadingMixIn, HTTPServer, Thread): + # On 3.7+ the `ThreadingMixIn` gathers all non-daemon worker threads in order to join on them at server close. + daemon_threads = True # Make worker threads "fire and forget" to prevent a memory leak. 
+ + def __init__(self, patroni, config): + self.patroni = patroni + self.__listen = None + self.__ssl_options = None + self.http_extra_headers = {} + self.reload_config(config) + self.daemon = True + + def query(self, sql, *params): + cursor = None + try: + with self.patroni.postgresql.connection().cursor() as cursor: + cursor.execute(sql, params) + return [r for r in cursor] + except psycopg2.Error as e: + if cursor and cursor.connection.closed == 0: + raise e + raise PostgresConnectionException('connection problems') + + @staticmethod + def _set_fd_cloexec(fd): + if os.name != 'nt': + import fcntl + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) + + def check_basic_auth_key(self, key): + return hmac.compare_digest(self.__auth_key, key.encode('utf-8')) + + def check_auth_header(self, auth_header): + if self.__auth_key: + if auth_header is None: + return 'no auth header received' + if not auth_header.startswith('Basic ') or not self.check_basic_auth_key(auth_header[6:]): + return 'not authenticated' + + def check_auth(self, rh): + if not hasattr(rh.request, 'getpeercert') or not rh.request.getpeercert(): # valid client cert isn't present + if self.__protocol == 'https' and self.__ssl_options.get('verify_client') in ('required', 'optional'): + return rh._write_response(403, 'client certificate required') + + reason = self.check_auth_header(rh.headers.get('Authorization')) + if reason: + headers = {'WWW-Authenticate': 'Basic realm="' + self.patroni.__class__.__name__ + '"'} + return rh._write_response(401, reason, headers=headers) + return True + + @staticmethod + def __has_dual_stack(): + if hasattr(socket, 'AF_INET6') and hasattr(socket, 'IPPROTO_IPV6') and hasattr(socket, 'IPV6_V6ONLY'): + sock = None + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False) + return True + except socket.error as e: + logger.debug('Error when working with ipv6 socket: %s', e) + finally: + if sock: + sock.close() + return False + + def __httpserver_init(self, host, port): + dual_stack = self.__has_dual_stack() + if host in ('', '*'): + host = None + + info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE) + # in case dual stack is not supported we want IPv4 to be preferred over IPv6 + info.sort(key=lambda x: x[0] == socket.AF_INET, reverse=not dual_stack) + + self.address_family = info[0][0] + try: + HTTPServer.__init__(self, info[0][-1][:2], RestApiHandler) + except socket.error: + logger.error( + "Couldn't start a service on '%s:%s', please check your `restapi.listen` configuration", host, port) + raise + + def __initialize(self, listen, ssl_options): + try: + host, port = split_host_port(listen, None) + except Exception: + raise ValueError('Invalid "restapi" config: expected : for "listen", but got "{0}"' + .format(listen)) + + reloading_config = self.__listen is not None # changing config in runtime + if reloading_config: + self.shutdown() + + self.__listen = listen + self.__ssl_options = ssl_options + + self.__httpserver_init(host, port) + Thread.__init__(self, target=self.serve_forever) + self._set_fd_cloexec(self.socket) + + # wrap socket with ssl if 'certfile' is defined in a config.yaml + # Sometime it's also needed to pass reference to a 'keyfile'. 
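
For context, the `ssl_options` dictionary consumed in the block below is assembled in `reload_config` from the `restapi` section of the configuration. A sketch with hypothetical file paths (only `certfile` is strictly needed to enable HTTPS):

    ssl_options = {
        'certfile': '/etc/patroni/server.crt',   # presence of certfile switches the API to HTTPS
        'keyfile': '/etc/patroni/server.key',
        'keyfile_password': None,                # optional passphrase for the key
        'cafile': '/etc/patroni/ca.crt',         # CA used when verifying client certificates
        'ciphers': None,                         # optional OpenSSL cipher string
        'verify_client': 'required',             # one of 'none', 'optional', 'required'
    }
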
+ self.__protocol = 'https' if ssl_options.get('certfile') else 'http' + if self.__protocol == 'https': + import ssl + ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH, cafile=ssl_options.get('cafile')) + if ssl_options.get('ciphers'): + ctx.set_ciphers(ssl_options['ciphers']) + ctx.load_cert_chain(certfile=ssl_options['certfile'], keyfile=ssl_options.get('keyfile'), + password=ssl_options.get('keyfile_password')) + verify_client = ssl_options.get('verify_client') + if verify_client: + modes = {'none': ssl.CERT_NONE, 'optional': ssl.CERT_OPTIONAL, 'required': ssl.CERT_REQUIRED} + if verify_client in modes: + ctx.verify_mode = modes[verify_client] + else: + logger.error('Bad value in the "restapi.verify_client": %s', verify_client) + self.socket = ctx.wrap_socket(self.socket, server_side=True) + if reloading_config: + self.start() + + def process_request_thread(self, request, client_address): + if isinstance(request, tuple): + sock, newsock = request + try: + request = sock.context.wrap_socket(newsock, do_handshake_on_connect=sock.do_handshake_on_connect, + suppress_ragged_eofs=sock.suppress_ragged_eofs, server_side=True) + except socket.error: + return + super(RestApiServer, self).process_request_thread(request, client_address) + + def get_request(self): + sock = self.socket + newsock, addr = socket.socket.accept(sock) + enable_keepalive(newsock, 10, 3) + if hasattr(sock, 'context'): # SSLSocket, we want to do the deferred handshake from a thread + newsock = (sock, newsock) + return newsock, addr + + def shutdown_request(self, request): + if isinstance(request, tuple): + _, request = request # SSLSocket + return super(RestApiServer, self).shutdown_request(request) + + def reload_config(self, config): + if 'listen' not in config: # changing config in runtime + raise ValueError('Can not find "restapi.listen" config') + + ssl_options = {n: config[n] for n in ('certfile', 'keyfile', 'keyfile_password', + 'cafile', 'ciphers') if n in config} + + self.http_extra_headers = config.get('http_extra_headers') or {} + self.http_extra_headers.update((config.get('https_extra_headers') or {}) if ssl_options.get('certfile') else {}) + + if isinstance(config.get('verify_client'), six.string_types): + ssl_options['verify_client'] = config['verify_client'].lower() + + if self.__listen != config['listen'] or self.__ssl_options != ssl_options: + self.__initialize(config['listen'], ssl_options) + + self.__auth_key = base64.b64encode(config['auth'].encode('utf-8')) if 'auth' in config else None + self.connection_string = uri(self.__protocol, config.get('connect_address') or self.__listen, 'patroni') + + @staticmethod + def handle_error(request, client_address): + address, port = client_address + logger.warning('Exception happened during processing of request from {}:{}'.format(address, port)) + logger.warning(traceback.format_exc()) diff --git a/patroni-for-openGauss/async_executor.py b/patroni-for-openGauss/async_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..46194fedcba843e43d617b847196ca816e6681f5 --- /dev/null +++ b/patroni-for-openGauss/async_executor.py @@ -0,0 +1,137 @@ +import logging +from threading import Event, Lock, RLock, Thread + +logger = logging.getLogger(__name__) + + +class CriticalTask(object): + """Represents a critical task in a background process that we either need to cancel or get the result of. + + Fields of this object may be accessed only when holding a lock on it. 
To perform the critical task the background + thread must, while holding lock on this object, check `is_cancelled` flag, run the task and mark the task as + complete using `complete()`. + + The main thread must hold async lock to prevent the task from completing, hold lock on critical task object, + call cancel. If the task has completed `cancel()` will return False and `result` field will contain the result of + the task. When cancel returns True it is guaranteed that the background task will notice the `is_cancelled` flag. + """ + def __init__(self): + self._lock = Lock() + self.is_cancelled = False + self.result = None + + def reset(self): + """Must be called every time the background task is finished. + + Must be called from async thread. Caller must hold lock on async executor when calling.""" + self.is_cancelled = False + self.result = None + + def cancel(self): + """Tries to cancel the task, returns True if the task has already run. + + Caller must hold lock on async executor and the task when calling.""" + if self.result is not None: + return False + self.is_cancelled = True + return True + + def complete(self, result): + """Mark task as completed along with a result. + + Must be called from async thread. Caller must hold lock on task when calling.""" + self.result = result + + def __enter__(self): + self._lock.acquire() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._lock.release() + + +class AsyncExecutor(object): + + def __init__(self, cancellable, ha_wakeup): + self._cancellable = cancellable + self._ha_wakeup = ha_wakeup + self._thread_lock = RLock() + self._scheduled_action = None + self._scheduled_action_lock = RLock() + self._is_cancelled = False + self._finish_event = Event() + self.critical_task = CriticalTask() + + @property + def busy(self): + return self.scheduled_action is not None + + def schedule(self, action): + with self._scheduled_action_lock: + if self._scheduled_action is not None: + return self._scheduled_action + self._scheduled_action = action + self._is_cancelled = False + self._finish_event.set() + return None + + @property + def scheduled_action(self): + with self._scheduled_action_lock: + return self._scheduled_action + + def reset_scheduled_action(self): + with self._scheduled_action_lock: + self._scheduled_action = None + + def run(self, func, args=()): + wakeup = False + try: + with self: + if self._is_cancelled: + return + self._finish_event.clear() + + self._cancellable.reset_is_cancelled() + # if the func returned something (not None) - wake up main HA loop + wakeup = func(*args) if args else func() + return wakeup + except Exception: + logger.exception('Exception during execution of long running task %s', self.scheduled_action) + finally: + with self: + self.reset_scheduled_action() + self._finish_event.set() + with self.critical_task: + self.critical_task.reset() + if wakeup is not None: + self._ha_wakeup() + + def run_async(self, func, args=()): + Thread(target=self.run, args=(func, args)).start() + + def try_run_async(self, action, func, args=()): + prev = self.schedule(action) + if prev is None: + return self.run_async(func, args) + return 'Failed to run {0}, {1} is already in progress'.format(action, prev) + + def cancel(self): + with self: + with self._scheduled_action_lock: + if self._scheduled_action is None: + return + logger.warning('Cancelling long running task %s', self._scheduled_action) + self._is_cancelled = True + + self._cancellable.cancel() + self._finish_event.wait() + + with self: + 
self.reset_scheduled_action() + + def __enter__(self): + self._thread_lock.acquire() + + def __exit__(self, *args): + self._thread_lock.release() diff --git a/patroni-for-openGauss/config.py b/patroni-for-openGauss/config.py new file mode 100644 index 0000000000000000000000000000000000000000..6c6c73b017c865191e016dee62b56a6a84c20f88 --- /dev/null +++ b/patroni-for-openGauss/config.py @@ -0,0 +1,434 @@ +import json +import logging +import os +import shutil +import tempfile +import yaml + +from collections import defaultdict +from copy import deepcopy +from patroni import PATRONI_ENV_PREFIX +from patroni.exceptions import ConfigParseError +from patroni.dcs import ClusterConfig +from patroni.postgresql.config import CaseInsensitiveDict, ConfigHandler +from patroni.utils import deep_compare, parse_bool, parse_int, patch_config + +logger = logging.getLogger(__name__) + +_AUTH_ALLOWED_PARAMETERS = ( + 'username', + 'password', + 'sslmode', + 'sslcert', + 'sslkey', + 'sslpassword', + 'sslrootcert', + 'sslcrl', + 'gssencmode', + 'channel_binding' +) + + +def default_validator(conf): + if not conf: + return "Config is empty." + + +class Config(object): + """ + This class is responsible for: + + 1) Building and giving access to `effective_configuration` from: + * `Config.__DEFAULT_CONFIG` -- some sane default values + * `dynamic_configuration` -- configuration stored in DCS + * `local_configuration` -- configuration from `config.yml` or environment + + 2) Saving and loading `dynamic_configuration` into 'patroni.dynamic.json' file + located in local_configuration['postgresql']['data_dir'] directory. + This is necessary to be able to restore `dynamic_configuration` + if DCS was accidentally wiped + + 3) Loading of configuration file in the old format and converting it into new format + + 4) Mimicking some of the `dict` interfaces to make it possible + to work with it as with the old `config` object. 
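
    As a rough illustration of how those layers combine (a hedged sketch with made-up
    values, not the actual merge logic in `_build_effective_configuration`):

        defaults = {'ttl': 30, 'loop_wait': 10, 'retry_timeout': 10}   # Config.__DEFAULT_CONFIG
        dynamic = {'ttl': 20}                                          # stored in DCS / patroni.dynamic.json
        local = {'scope': 'batman', 'restapi': {'listen': '0.0.0.0:8008'}}  # config.yml or PATRONI_* env

        effective = dict(defaults)
        effective.update(dynamic)   # dynamic configuration overrides the defaults
        effective.update(local)     # local configuration is applied last
        # effective -> {'ttl': 20, 'loop_wait': 10, 'retry_timeout': 10,
        #               'scope': 'batman', 'restapi': {'listen': '0.0.0.0:8008'}}
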
+ """ + + PATRONI_CONFIG_VARIABLE = PATRONI_ENV_PREFIX + 'CONFIGURATION' + + __CACHE_FILENAME = 'patroni.dynamic.json' + __DEFAULT_CONFIG = { + 'ttl': 30, 'loop_wait': 10, 'retry_timeout': 10, + 'maximum_lag_on_failover': 1048576, + 'maximum_lag_on_syncnode': -1, + 'check_timeline': False, + 'master_start_timeout': 300, + 'master_stop_timeout': 0, + 'synchronous_mode': False, + 'synchronous_mode_strict': False, + 'synchronous_node_count': 1, + 'standby_cluster': { + 'create_replica_methods': '', + 'host': '', + 'port': '', + 'primary_slot_name': '', + 'restore_command': '', + 'archive_cleanup_command': '', + 'recovery_min_apply_delay': '' + }, + 'postgresql': { + 'bin_dir': '', + 'use_slots': True, + 'parameters': CaseInsensitiveDict({p: v[0] for p, v in ConfigHandler.CMDLINE_OPTIONS.items() + if p not in ('wal_keep_segments', 'wal_keep_size')}) + }, + 'watchdog': { + 'mode': 'automatic', + } + } + + def __init__(self, configfile, validator=default_validator): + self._modify_index = -1 + self._dynamic_configuration = {} + + self.__environment_configuration = self._build_environment_configuration() + + # Patroni reads the configuration from the command-line argument if it exists, otherwise from the environment + self._config_file = configfile and os.path.exists(configfile) and configfile + if self._config_file: + self._local_configuration = self._load_config_file() + else: + config_env = os.environ.pop(self.PATRONI_CONFIG_VARIABLE, None) + self._local_configuration = config_env and yaml.safe_load(config_env) or self.__environment_configuration + if validator: + error = validator(self._local_configuration) + if error: + raise ConfigParseError(error) + + self.__effective_configuration = self._build_effective_configuration({}, self._local_configuration) + self._data_dir = self.__effective_configuration.get('postgresql', {}).get('data_dir', "") + self._cache_file = os.path.join(self._data_dir, self.__CACHE_FILENAME) + self._load_cache() + self._cache_needs_saving = False + + @property + def config_file(self): + return self._config_file + + @property + def dynamic_configuration(self): + return deepcopy(self._dynamic_configuration) + + def check_mode(self, mode): + return bool(parse_bool(self._dynamic_configuration.get(mode))) + + def _load_config_path(self, path): + """ + If path is a file, loads the yml file pointed to by path. 
+ If path is a directory, loads all yml files in that directory in alphabetical order + """ + if os.path.isfile(path): + files = [path] + elif os.path.isdir(path): + files = [os.path.join(path, f) for f in sorted(os.listdir(path)) + if (f.endswith('.yml') or f.endswith('.yaml')) and os.path.isfile(os.path.join(path, f))] + else: + logger.error('config path %s is neither directory nor file', path) + raise ConfigParseError('invalid config path') + + overall_config = {} + for fname in files: + with open(fname) as f: + config = yaml.safe_load(f) + patch_config(overall_config, config) + return overall_config + + def _load_config_file(self): + """Loads config.yaml from filesystem and applies some values which were set via ENV""" + config = self._load_config_path(self._config_file) + patch_config(config, self.__environment_configuration) + return config + + def _load_cache(self): + if os.path.isfile(self._cache_file): + try: + with open(self._cache_file) as f: + self.set_dynamic_configuration(json.load(f)) + except Exception: + logger.exception('Exception when loading file: %s', self._cache_file) + + def save_cache(self): + if self._cache_needs_saving: + tmpfile = fd = None + try: + (fd, tmpfile) = tempfile.mkstemp(prefix=self.__CACHE_FILENAME, dir=self._data_dir) + with os.fdopen(fd, 'w') as f: + fd = None + json.dump(self.dynamic_configuration, f) + tmpfile = shutil.move(tmpfile, self._cache_file) + self._cache_needs_saving = False + except Exception: + logger.exception('Exception when saving file: %s', self._cache_file) + if fd: + try: + os.close(fd) + except Exception: + logger.error('Can not close temporary file %s', tmpfile) + if tmpfile and os.path.exists(tmpfile): + try: + os.remove(tmpfile) + except Exception: + logger.error('Can not remove temporary file %s', tmpfile) + + # configuration could be either ClusterConfig or dict + def set_dynamic_configuration(self, configuration): + if isinstance(configuration, ClusterConfig): + if self._modify_index == configuration.modify_index: + return False # If the index didn't changed there is nothing to do + self._modify_index = configuration.modify_index + configuration = configuration.data + + if not deep_compare(self._dynamic_configuration, configuration): + try: + self.__effective_configuration = self._build_effective_configuration(configuration, + self._local_configuration) + self._dynamic_configuration = configuration + self._cache_needs_saving = True + return True + except Exception: + logger.exception('Exception when setting dynamic_configuration') + + def reload_local_configuration(self): + if self.config_file: + try: + configuration = self._load_config_file() + if not deep_compare(self._local_configuration, configuration): + new_configuration = self._build_effective_configuration(self._dynamic_configuration, configuration) + self._local_configuration = configuration + self.__effective_configuration = new_configuration + return True + else: + logger.info('No local configuration items changed.') + except Exception: + logger.exception('Exception when reloading local configuration from %s', self.config_file) + + @staticmethod + def _process_postgresql_parameters(parameters, is_local=False): + return {name: value for name, value in (parameters or {}).items() + if name not in ConfigHandler.CMDLINE_OPTIONS or + not is_local and ConfigHandler.CMDLINE_OPTIONS[name][1](value)} + + def _safe_copy_dynamic_configuration(self, dynamic_configuration): + config = deepcopy(self.__DEFAULT_CONFIG) + + for name, value in dynamic_configuration.items(): + if 
name == 'postgresql': + for name, value in (value or {}).items(): + if name == 'parameters': + config['postgresql'][name].update(self._process_postgresql_parameters(value)) + elif name not in ('connect_address', 'listen', 'data_dir', 'pgpass', 'authentication'): + config['postgresql'][name] = deepcopy(value) + elif name == 'standby_cluster': + for name, value in (value or {}).items(): + if name in self.__DEFAULT_CONFIG['standby_cluster']: + config['standby_cluster'][name] = deepcopy(value) + elif name in config: # only variables present in __DEFAULT_CONFIG allowed to be overriden from DCS + if name in ('synchronous_mode', 'synchronous_mode_strict'): + config[name] = value + else: + config[name] = int(value) + return config + + @staticmethod + def _build_environment_configuration(): + ret = defaultdict(dict) + + def _popenv(name): + return os.environ.pop(PATRONI_ENV_PREFIX + name.upper(), None) + + for param in ('name', 'namespace', 'scope'): + value = _popenv(param) + if value: + ret[param] = value + + def _fix_log_env(name, oldname): + value = _popenv(oldname) + name = PATRONI_ENV_PREFIX + 'LOG_' + name.upper() + if value and name not in os.environ: + os.environ[name] = value + + for name, oldname in (('level', 'loglevel'), ('format', 'logformat'), ('dateformat', 'log_datefmt')): + _fix_log_env(name, oldname) + + def _set_section_values(section, params): + for param in params: + value = _popenv(section + '_' + param) + if value: + ret[section][param] = value + + _set_section_values('restapi', ['listen', 'connect_address', 'certfile', 'keyfile', 'keyfile_password', + 'cafile', 'ciphers', 'verify_client', 'http_extra_headers', + 'https_extra_headers']) + _set_section_values('ctl', ['insecure', 'cacert', 'certfile', 'keyfile']) + _set_section_values('postgresql', ['listen', 'connect_address', 'config_dir', 'data_dir', 'pgpass', 'bin_dir']) + _set_section_values('log', ['level', 'traceback_level', 'format', 'dateformat', 'max_queue_size', + 'dir', 'file_size', 'file_num', 'loggers']) + + def _parse_dict(value): + if not value.strip().startswith('{'): + value = '{{{0}}}'.format(value) + try: + return yaml.safe_load(value) + except Exception: + logger.exception('Exception when parsing dict %s', value) + return None + + value = ret.get('log', {}).pop('loggers', None) + if value: + value = _parse_dict(value) + if value: + ret['log']['loggers'] = value + + def _get_auth(name, params=None): + ret = {} + for param in params or _AUTH_ALLOWED_PARAMETERS[:2]: + value = _popenv(name + '_' + param) + if value: + ret[param] = value + return ret + + restapi_auth = _get_auth('restapi') + if restapi_auth: + ret['restapi']['authentication'] = restapi_auth + + authentication = {} + for user_type in ('replication', 'superuser', 'rewind'): + entry = _get_auth(user_type, _AUTH_ALLOWED_PARAMETERS) + if entry: + authentication[user_type] = entry + + if authentication: + ret['postgresql']['authentication'] = authentication + + def _parse_list(value): + if not (value.strip().startswith('-') or '[' in value): + value = '[{0}]'.format(value) + try: + return yaml.safe_load(value) + except Exception: + logger.exception('Exception when parsing list %s', value) + return None + + _set_section_values('raft', ['data_dir', 'self_addr', 'partner_addrs', 'password', 'bind_addr']) + if 'raft' in ret and 'partner_addrs' in ret['raft']: + ret['raft']['partner_addrs'] = _parse_list(ret['raft']['partner_addrs']) + + for param in list(os.environ.keys()): + if param.startswith(PATRONI_ENV_PREFIX): + # 
PATRONI_(ETCD|CONSUL|ZOOKEEPER|EXHIBITOR|...)_(HOSTS?|PORT|..) + name, suffix = (param[8:].split('_', 1) + [''])[:2] + if suffix in ('HOST', 'HOSTS', 'PORT', 'USE_PROXIES', 'PROTOCOL', 'SRV', 'URL', 'PROXY', + 'CACERT', 'CERT', 'KEY', 'VERIFY', 'TOKEN', 'CHECKS', 'DC', 'CONSISTENCY', + 'REGISTER_SERVICE', 'SERVICE_CHECK_INTERVAL', 'NAMESPACE', 'CONTEXT', + 'USE_ENDPOINTS', 'SCOPE_LABEL', 'ROLE_LABEL', 'POD_IP', 'PORTS', 'LABELS', + 'BYPASS_API_SERVICE', 'KEY_PASSWORD', 'USE_SSL') and name: + value = os.environ.pop(param) + if suffix == 'PORT': + value = value and parse_int(value) + elif suffix in ('HOSTS', 'PORTS', 'CHECKS'): + value = value and _parse_list(value) + elif suffix == 'LABELS': + value = _parse_dict(value) + elif suffix in ('USE_PROXIES', 'REGISTER_SERVICE', 'USE_ENDPOINTS', 'BYPASS_API_SERVICE'): + value = parse_bool(value) + if value: + ret[name.lower()][suffix.lower()] = value + for dcs in ('etcd', 'etcd3'): + if dcs in ret: + ret[dcs].update(_get_auth(dcs)) + + users = {} + for param in list(os.environ.keys()): + if param.startswith(PATRONI_ENV_PREFIX): + name, suffix = (param[8:].rsplit('_', 1) + [''])[:2] + # PATRONI__PASSWORD=, PATRONI__OPTIONS= + # CREATE USER "" WITH PASSWORD '' + if name and suffix == 'PASSWORD': + password = os.environ.pop(param) + if password: + users[name] = {'password': password} + options = os.environ.pop(param[:-9] + '_OPTIONS', None) + options = options and _parse_list(options) + if options: + users[name]['options'] = options + if users: + ret['bootstrap']['users'] = users + + return ret + + def _build_effective_configuration(self, dynamic_configuration, local_configuration): + config = self._safe_copy_dynamic_configuration(dynamic_configuration) + for name, value in local_configuration.items(): + if name == 'postgresql': + for name, value in (value or {}).items(): + if name == 'parameters': + config['postgresql'][name].update(self._process_postgresql_parameters(value, True)) + elif name != 'use_slots': # replication slots must be enabled/disabled globally + config['postgresql'][name] = deepcopy(value) + elif name not in config or name in ['watchdog']: + config[name] = deepcopy(value) if value else {} + + # restapi server expects to get restapi.auth = 'username:password' + if 'restapi' in config and 'authentication' in config['restapi']: + config['restapi']['auth'] = '{username}:{password}'.format(**config['restapi']['authentication']) + + # special treatment for old config + + # 'exhibitor' inside 'zookeeper': + if 'zookeeper' in config and 'exhibitor' in config['zookeeper']: + config['exhibitor'] = config['zookeeper'].pop('exhibitor') + config.pop('zookeeper') + + pg_config = config['postgresql'] + # no 'authentication' in 'postgresql', but 'replication' and 'superuser' + if 'authentication' not in pg_config: + pg_config['use_pg_rewind'] = 'pg_rewind' in pg_config + pg_config['authentication'] = {u: pg_config[u] for u in ('replication', 'superuser') if u in pg_config} + # no 'superuser' in 'postgresql'.'authentication' + if 'superuser' not in pg_config['authentication'] and 'pg_rewind' in pg_config: + pg_config['authentication']['superuser'] = pg_config['pg_rewind'] + + # handle setting additional connection parameters that may be available + # in the configuration file, such as SSL connection parameters + for name, value in pg_config['authentication'].items(): + pg_config['authentication'][name] = {n: v for n, v in value.items() if n in _AUTH_ALLOWED_PARAMETERS} + + # no 'name' in config + if 'name' not in config and 'name' in pg_config: + 
config['name'] = pg_config['name'] + + updated_fields = ( + 'name', + 'scope', + 'retry_timeout', + 'synchronous_mode', + 'synchronous_mode_strict', + 'synchronous_node_count', + 'maximum_lag_on_syncnode' + ) + + pg_config.update({p: config[p] for p in updated_fields if p in config}) + + return config + + def get(self, key, default=None): + return self.__effective_configuration.get(key, default) + + def __contains__(self, key): + return key in self.__effective_configuration + + def __getitem__(self, key): + return self.__effective_configuration[key] + + def copy(self): + return deepcopy(self.__effective_configuration) diff --git a/patroni-for-openGauss/ctl.py b/patroni-for-openGauss/ctl.py new file mode 100644 index 0000000000000000000000000000000000000000..77e80f12a09b8f29548233f4c5d38c8d45b5636c --- /dev/null +++ b/patroni-for-openGauss/ctl.py @@ -0,0 +1,1312 @@ +''' +Patroni Control +''' + +import click +import codecs +import datetime +import dateutil.parser +import dateutil.tz +import copy +import difflib +import io +import json +import logging +import os +import random +import six +import subprocess +import sys +import tempfile +import time +import yaml + +from click import ClickException +from collections import defaultdict +from contextlib import contextmanager +from prettytable import ALL, FRAME, PrettyTable +from six.moves.urllib_parse import urlparse + +try: + from ydiff import markup_to_pager, PatchStream +except ImportError: # pragma: no cover + from cdiff import markup_to_pager, PatchStream + +from .dcs import get_dcs as _get_dcs +from .exceptions import PatroniException +from .postgresql import Postgresql +from .postgresql.misc import postgres_version_to_int +from .utils import cluster_as_json, find_executable, patch_config, polling_loop +from .request import PatroniRequest +from .version import __version__ + +CONFIG_DIR_PATH = click.get_app_dir('patroni') +CONFIG_FILE_PATH = os.path.join(CONFIG_DIR_PATH, 'patronictl.yaml') +DCS_DEFAULTS = {'zookeeper': {'port': 2181, 'template': "zookeeper:\n hosts: ['{host}:{port}']"}, + 'exhibitor': {'port': 8181, 'template': "exhibitor:\n hosts: [{host}]\n port: {port}"}, + 'consul': {'port': 8500, 'template': "consul:\n host: '{host}:{port}'"}, + 'etcd': {'port': 2379, 'template': "etcd:\n host: '{host}:{port}'"}} + + +class PatroniCtlException(ClickException): + pass + + +class PatronictlPrettyTable(PrettyTable): + + def __init__(self, header, *args, **kwargs): + PrettyTable.__init__(self, *args, **kwargs) + self.__table_header = header + self.__hline_num = 0 + self.__hline = None + + def _is_first_hline(self): + return self.__hline_num == 0 + + def _set_hline(self, value): + self.__hline = value + + def _get_hline(self): + ret = self.__hline + + # Inject nice table header + if self._is_first_hline() and self.__table_header: + header = self.__table_header[:len(ret) - 2] + ret = "".join([ret[0], header, ret[1 + len(header):]]) + + self.__hline_num += 1 + return ret + + _hrule = property(_get_hline, _set_hline) + + +def parse_dcs(dcs): + if dcs is None: + return None + elif '//' not in dcs: + dcs = '//' + dcs + + parsed = urlparse(dcs) + scheme = parsed.scheme + port = int(parsed.port) if parsed.port else None + + if scheme == '': + scheme = ([k for k, v in DCS_DEFAULTS.items() if v['port'] == port] or ['etcd'])[0] + elif scheme not in DCS_DEFAULTS: + raise PatroniCtlException('Unknown dcs scheme: {}'.format(scheme)) + + default = DCS_DEFAULTS[scheme] + return yaml.safe_load(default['template'].format(host=parsed.hostname or 'localhost', 
port=port or default['port'])) + + +def load_config(path, dcs): + from patroni.config import Config + + if not (os.path.exists(path) and os.access(path, os.R_OK)): + if path != CONFIG_FILE_PATH: # bail if non-default config location specified but file not found / readable + raise PatroniCtlException('Provided config file {0} not existing or no read rights.' + ' Check the -c/--config-file parameter'.format(path)) + else: + logging.debug('Ignoring configuration file "%s". It does not exists or is not readable.', path) + else: + logging.debug('Loading configuration from file %s', path) + config = Config(path, validator=None).copy() + + dcs = parse_dcs(dcs) or parse_dcs(config.get('dcs_api')) or {} + if dcs: + for d in DCS_DEFAULTS: + config.pop(d, None) + config.update(dcs) + return config + + +def store_config(config, path): + dir_path = os.path.dirname(path) + if dir_path and not os.path.isdir(dir_path): + os.makedirs(dir_path) + with open(path, 'w') as fd: + yaml.dump(config, fd) + + +option_format = click.option('--format', '-f', 'fmt', help='Output format (pretty, tsv, json, yaml)', default='pretty') +option_watchrefresh = click.option('-w', '--watch', type=float, help='Auto update the screen every X seconds') +option_watch = click.option('-W', is_flag=True, help='Auto update the screen every 2 seconds') +option_force = click.option('--force', is_flag=True, help='Do not ask for confirmation at any point') +arg_cluster_name = click.argument('cluster_name', required=False, + default=lambda: click.get_current_context().obj.get('scope')) +option_insecure = click.option('-k', '--insecure', is_flag=True, help='Allow connections to SSL sites without certs') + + +@click.group() +@click.option('--config-file', '-c', help='Configuration file', + envvar='PATRONICTL_CONFIG_FILE', default=CONFIG_FILE_PATH) +@click.option('--dcs', '-d', help='Use this DCS', envvar='DCS') +@option_insecure +@click.pass_context +def ctl(ctx, config_file, dcs, insecure): + level = 'WARNING' + for name in ('LOGLEVEL', 'PATRONI_LOGLEVEL', 'PATRONI_LOG_LEVEL'): + level = os.environ.get(name, level) + logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=level) + logging.captureWarnings(True) # Capture eventual SSL warning + ctx.obj = load_config(config_file, dcs) + # backward compatibility for configuration file where ctl section is not define + ctx.obj.setdefault('ctl', {})['insecure'] = ctx.obj.get('ctl', {}).get('insecure') or insecure + + +def get_dcs(config, scope): + config.update({'scope': scope, 'patronictl': True}) + config.setdefault('name', scope) + try: + return _get_dcs(config) + except PatroniException as e: + raise PatroniCtlException(str(e)) + + +def request_patroni(member, method='GET', endpoint=None, data=None): + ctx = click.get_current_context() # the current click context + request_executor = ctx.obj.get('__request_patroni') + if not request_executor: + request_executor = ctx.obj['__request_patroni'] = PatroniRequest(ctx.obj) + return request_executor(member, method, endpoint, data) + + +def print_output(columns, rows, alignment=None, fmt='pretty', header=None, delimiter='\t'): + if fmt in {'json', 'yaml', 'yml'}: + elements = [{k: v for k, v in zip(columns, r) if not header or str(v)} for r in rows] + func = json.dumps if fmt == 'json' else format_config_for_editing + click.echo(func(elements)) + elif fmt in {'pretty', 'tsv', 'topology'}: + list_cluster = bool(header and columns and columns[0] == 'Cluster') + if list_cluster and 'Tags' in columns: # we want to format member tags 
as YAML + i = columns.index('Tags') + for row in rows: + if row[i]: + row[i] = format_config_for_editing(row[i], fmt != 'pretty').strip() + if list_cluster and fmt != 'tsv': # skip cluster name if pretty-printing + columns = columns[1:] if columns else [] + rows = [row[1:] for row in rows] + + if fmt == 'tsv': + for r in ([columns] if columns else []) + rows: + click.echo(delimiter.join(map(str, r))) + else: + hrules = ALL if any(any(isinstance(c, six.string_types) and '\n' in c for c in r) for r in rows) else FRAME + table = PatronictlPrettyTable(header, columns, hrules=hrules) + table.align = 'l' + for k, v in (alignment or {}).items(): + table.align[k] = v + for r in rows: + table.add_row(r) + click.echo(table) + + +def watching(w, watch, max_count=None, clear=True): + """ + >>> len(list(watching(True, 1, 0))) + 1 + >>> len(list(watching(True, 1, 1))) + 2 + >>> len(list(watching(True, None, 0))) + 1 + """ + + if w and not watch: + watch = 2 + if watch and clear: + click.clear() + yield 0 + + if max_count is not None and max_count < 1: + return + + counter = 1 + while watch and counter <= (max_count or counter): + time.sleep(watch) + counter += 1 + if clear: + click.clear() + yield 0 + + +def get_all_members(cluster, role='master'): + if role == 'master': + if cluster.leader is not None and cluster.leader.name: + yield cluster.leader + return + + leader_name = (cluster.leader.member.name if cluster.leader else None) + for m in cluster.members: + if role == 'any' or role == 'replica' and m.name != leader_name: + yield m + + +def get_any_member(cluster, role='master', member=None): + members = get_all_members(cluster, role) + for m in members: + if member is None or m.name == member: + return m + + +def get_all_members_leader_first(cluster): + leader_name = cluster.leader.member.name if cluster.leader and cluster.leader.member.api_url else None + if leader_name: + yield cluster.leader.member + for member in cluster.members: + if member.api_url and member.name != leader_name: + yield member + + +def get_cursor(cluster, connect_parameters, role='master', member=None): + member = get_any_member(cluster, role=role, member=member) + if member is None: + return None + + params = member.conn_kwargs(connect_parameters) + params.update({'fallback_application_name': 'Patroni ctl', 'connect_timeout': '5'}) + if 'database' in connect_parameters: + params['database'] = connect_parameters['database'] + else: + params.pop('database') + + import psycopg2 + conn = psycopg2.connect(**params) + conn.autocommit = True + cursor = conn.cursor() + if role == 'any': + return cursor + + cursor.execute('SELECT pg_catalog.pg_is_in_recovery()') + in_recovery = cursor.fetchone()[0] + + if in_recovery and role == 'replica' or not in_recovery and role == 'master': + return cursor + + conn.close() + + return None + + +def get_members(cluster, cluster_name, member_names, role, force, action, ask_confirmation=True): + candidates = {m.name: m for m in cluster.members} + + if not force or role: + if not member_names and not candidates: + raise PatroniCtlException('{0} cluster doesn\'t have any members'.format(cluster_name)) + output_members(cluster, cluster_name) + + if role: + role_names = [m.name for m in get_all_members(cluster, role)] + if member_names: + member_names = list(set(member_names) & set(role_names)) + if not member_names: + raise PatroniCtlException('No {0} among provided members'.format(role)) + else: + member_names = role_names + + if not member_names and not force: + member_names = [click.prompt('Which 
member do you want to {0} [{1}]?'.format(action, + ', '.join(candidates.keys())), type=str, default='')] + + for member_name in member_names: + if member_name not in candidates: + raise PatroniCtlException('{0} is not a member of cluster'.format(member_name)) + + members = [candidates[n] for n in member_names] + if ask_confirmation: + confirm_members_action(members, force, action) + return members + + +def confirm_members_action(members, force, action, scheduled_at=None): + if scheduled_at: + if not force: + confirm = click.confirm('Are you sure you want to schedule {0} of members {1} at {2}?' + .format(action, ', '.join([m.name for m in members]), scheduled_at)) + if not confirm: + raise PatroniCtlException('Aborted scheduled {0}'.format(action)) + else: + if not force: + confirm = click.confirm('Are you sure you want to {0} members {1}?' + .format(action, ', '.join([m.name for m in members]))) + if not confirm: + raise PatroniCtlException('Aborted {0}'.format(action)) + + +@ctl.command('dsn', help='Generate a dsn for the provided member, defaults to a dsn of the master') +@click.option('--role', '-r', help='Give a dsn of any member with this role', type=click.Choice(['master', 'replica', + 'any']), default=None) +@click.option('--member', '-m', help='Generate a dsn for this member', type=str) +@arg_cluster_name +@click.pass_obj +def dsn(obj, cluster_name, role, member): + if role is not None and member is not None: + raise PatroniCtlException('--role and --member are mutually exclusive options') + if member is None and role is None: + role = 'master' + + cluster = get_dcs(obj, cluster_name).get_cluster() + m = get_any_member(cluster, role=role, member=member) + if m is None: + raise PatroniCtlException('Can not find a suitable member') + + params = m.conn_kwargs() + click.echo('host={host} port={port}'.format(**params)) + + +@ctl.command('query', help='Query a Patroni PostgreSQL member') +@arg_cluster_name +@click.option('--format', 'fmt', help='Output format (pretty, tsv, json, yaml)', default='tsv') +@click.option('--file', '-f', 'p_file', help='Execute the SQL commands from this file', type=click.File('rb')) +@click.option('--password', help='force password prompt', is_flag=True) +@click.option('-U', '--username', help='database user name', type=str) +@option_watch +@option_watchrefresh +@click.option('--role', '-r', help='The role of the query', type=click.Choice(['master', 'replica', 'any']), + default=None) +@click.option('--member', '-m', help='Query a specific member', type=str) +@click.option('--delimiter', help='The column delimiter', default='\t') +@click.option('--command', '-c', help='The SQL commands to execute') +@click.option('-d', '--dbname', help='database name to connect to', type=str) +@click.pass_obj +def query( + obj, + cluster_name, + role, + member, + w, + watch, + delimiter, + command, + p_file, + password, + username, + dbname, + fmt='tsv', +): + if role is not None and member is not None: + raise PatroniCtlException('--role and --member are mutually exclusive options') + if member is None and role is None: + role = 'master' + + if p_file is not None and command is not None: + raise PatroniCtlException('--file and --command are mutually exclusive options') + + if p_file is None and command is None: + raise PatroniCtlException('You need to specify either --command or --file') + + connect_parameters = {} + if username: + connect_parameters['username'] = username + if password: + connect_parameters['password'] = click.prompt('Password', hide_input=True, type=str) + 
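
The parameters collected in this command are eventually merged with the member's `conn_kwargs()` and handed to `psycopg2.connect` in `get_cursor` above, which also verifies the requested role. A self-contained sketch of that role check, with hypothetical host and credentials:

    import psycopg2

    # hypothetical connection values; get_cursor() derives these from the DCS member entry
    conn = psycopg2.connect(host='127.0.0.1', port=5432, dbname='postgres',
                            user='patroni', password='secret', connect_timeout=5,
                            fallback_application_name='Patroni ctl')
    conn.autocommit = True
    cur = conn.cursor()
    cur.execute('SELECT pg_catalog.pg_is_in_recovery()')
    in_recovery = cur.fetchone()[0]
    # role == 'replica' requires in_recovery, role == 'master' requires not in_recovery,
    # role == 'any' accepts either; otherwise the connection is closed and None is returned
    conn.close()
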
if dbname: + connect_parameters['database'] = dbname + + if p_file is not None: + command = p_file.read() + + dcs = get_dcs(obj, cluster_name) + + cursor = None + for _ in watching(w, watch, clear=False): + if cursor is None: + cluster = dcs.get_cluster() + + output, header = query_member(cluster, cursor, member, role, command, connect_parameters) + print_output(header, output, fmt=fmt, delimiter=delimiter) + + +def query_member(cluster, cursor, member, role, command, connect_parameters): + import psycopg2 + try: + if cursor is None: + cursor = get_cursor(cluster, connect_parameters, role=role, member=member) + + if cursor is None: + if role is None: + message = 'No connection to member {0} is available'.format(member) + else: + message = 'No connection to role={0} is available'.format(role) + logging.debug(message) + return [[timestamp(0), message]], None + + cursor.execute(command) + return cursor.fetchall(), [d.name for d in cursor.description] + except (psycopg2.OperationalError, psycopg2.DatabaseError) as oe: + logging.debug(oe) + if cursor is not None and not cursor.connection.closed: + cursor.connection.close() + message = oe.pgcode or oe.pgerror or str(oe) + message = message.replace('\n', ' ') + return [[timestamp(0), 'ERROR, SQLSTATE: {0}'.format(message)]], None + + +@ctl.command('remove', help='Remove cluster from DCS') +@click.argument('cluster_name') +@option_format +@click.pass_obj +def remove(obj, cluster_name, fmt): + dcs = get_dcs(obj, cluster_name) + cluster = dcs.get_cluster() + + output_members(cluster, cluster_name, fmt=fmt) + + confirm = click.prompt('Please confirm the cluster name to remove', type=str) + if confirm != cluster_name: + raise PatroniCtlException('Cluster names specified do not match') + + message = 'Yes I am aware' + confirm = \ + click.prompt('You are about to remove all information in DCS for {0}, please type: "{1}"'.format(cluster_name, + message), type=str) + if message != confirm: + raise PatroniCtlException('You did not exactly type "{0}"'.format(message)) + + if cluster.leader and cluster.leader.name: + confirm = click.prompt('This cluster currently is healthy. Please specify the master name to continue') + if confirm != cluster.leader.name: + raise PatroniCtlException('You did not specify the current master of the cluster') + + dcs.delete_cluster() + + +def check_response(response, member_name, action_name, silent_success=False): + if response.status >= 400: + click.echo('Failed: {0} for member {1}, status code={2}, ({3})'.format( + action_name, member_name, response.status, response.data.decode('utf-8') + )) + return False + elif not silent_success: + click.echo('Success: {0} for member {1}'.format(action_name, member_name)) + return True + + +def parse_scheduled(scheduled): + if (scheduled or 'now') != 'now': + try: + scheduled_at = dateutil.parser.parse(scheduled) + if scheduled_at.tzinfo is None: + scheduled_at = scheduled_at.replace(tzinfo=dateutil.tz.tzlocal()) + except (ValueError, TypeError): + message = 'Unable to parse scheduled timestamp ({0}). It should be in an unambiguous format (e.g. 
ISO 8601)' + raise PatroniCtlException(message.format(scheduled)) + return scheduled_at + + return None + + +@ctl.command('reload', help='Reload cluster member configuration') +@click.argument('cluster_name') +@click.argument('member_names', nargs=-1) +@click.option('--role', '-r', help='Reload only members with this role', default='any', + type=click.Choice(['master', 'replica', 'any'])) +@option_force +@click.pass_obj +def reload(obj, cluster_name, member_names, force, role): + cluster = get_dcs(obj, cluster_name).get_cluster() + + members = get_members(cluster, cluster_name, member_names, role, force, 'reload') + + for member in members: + r = request_patroni(member, 'post', 'reload') + if r.status == 200: + click.echo('No changes to apply on member {0}'.format(member.name)) + elif r.status == 202: + click.echo('Reload request received for member {0} and will be processed within {1} seconds'.format( + member.name, cluster.config.data.get('loop_wait')) + ) + else: + click.echo('Failed: reload for member {0}, status code={1}, ({2})'.format( + member.name, r.status, r.data.decode('utf-8')) + ) + + +@ctl.command('restart', help='Restart cluster member') +@click.argument('cluster_name') +@click.argument('member_names', nargs=-1) +@click.option('--role', '-r', help='Restart only members with this role', default='any', + type=click.Choice(['master', 'replica', 'any'])) +@click.option('--any', 'p_any', help='Restart a single member only', is_flag=True) +@click.option('--scheduled', help='Timestamp of a scheduled restart in unambiguous format (e.g. ISO 8601)', + default=None) +@click.option('--pg-version', 'version', help='Restart if the PostgreSQL version is less than provided (e.g. 9.5.2)', + default=None) +@click.option('--pending', help='Restart if pending', is_flag=True) +@click.option('--timeout', + help='Return error and fail over if necessary when restarting takes longer than this.') +@option_force +@click.pass_obj +def restart(obj, cluster_name, member_names, force, role, p_any, scheduled, version, pending, timeout): + cluster = get_dcs(obj, cluster_name).get_cluster() + + members = get_members(cluster, cluster_name, member_names, role, force, 'restart', False) + if scheduled is None and not force: + next_hour = (datetime.datetime.now() + datetime.timedelta(hours=1)).strftime('%Y-%m-%dT%H:%M') + scheduled = click.prompt('When should the restart take place (e.g. ' + next_hour + ') ', + type=str, default='now') + + scheduled_at = parse_scheduled(scheduled) + confirm_members_action(members, force, 'restart', scheduled_at) + + if p_any: + random.shuffle(members) + members = members[:1] + + if version is None and not force: + version = click.prompt('Restart if the PostgreSQL version is less than provided (e.g. 
9.5.2) ', + type=str, default='') + + content = {} + if pending: + content['restart_pending'] = True + + if version: + try: + postgres_version_to_int(version) + except PatroniException as e: + raise PatroniCtlException(e.value) + + content['postgres_version'] = version + + if scheduled_at: + if cluster.is_paused(): + raise PatroniCtlException("Can't schedule restart in the paused state") + content['schedule'] = scheduled_at.isoformat() + + if timeout is not None: + content['timeout'] = timeout + + for member in members: + if 'schedule' in content: + if force and member.data.get('scheduled_restart'): + r = request_patroni(member, 'delete', 'restart') + check_response(r, member.name, 'flush scheduled restart', True) + + r = request_patroni(member, 'post', 'restart', content) + if r.status == 200: + click.echo('Success: restart on member {0}'.format(member.name)) + elif r.status == 202: + click.echo('Success: restart scheduled on member {0}'.format(member.name)) + elif r.status == 409: + click.echo('Failed: another restart is already scheduled on member {0}'.format(member.name)) + else: + click.echo('Failed: restart for member {0}, status code={1}, ({2})'.format( + member.name, r.status, r.data.decode('utf-8')) + ) + + +@ctl.command('reinit', help='Reinitialize cluster member') +@click.argument('cluster_name') +@click.argument('member_names', nargs=-1) +@option_force +@click.option('--wait', help='Wait until reinitialization completes', is_flag=True) +@click.pass_obj +def reinit(obj, cluster_name, member_names, force, wait): + cluster = get_dcs(obj, cluster_name).get_cluster() + members = get_members(cluster, cluster_name, member_names, None, force, 'reinitialize') + + wait_on_members = [] + for member in members: + body = {'force': force} + while True: + r = request_patroni(member, 'post', 'reinitialize', body) + started = check_response(r, member.name, 'reinitialize') + if not started and r.data.endswith(b' already in progress') \ + and not force and click.confirm('Do you want to cancel it and reinitialize anyway?'): + body['force'] = True + continue + break + if started and wait: + wait_on_members.append(member) + + last_display = [] + while wait_on_members: + if wait_on_members != last_display: + click.echo('Waiting for reinitialize to complete on: {0}'.format( + ", ".join(member.name for member in wait_on_members)) + ) + last_display[:] = wait_on_members + time.sleep(2) + for member in wait_on_members: + data = json.loads(request_patroni(member, 'get', 'patroni').data.decode('utf-8')) + if data.get('state') != 'creating replica': + click.echo('Reinitialize is completed on: {0}'.format(member.name)) + wait_on_members.remove(member) + + +def _do_failover_or_switchover(obj, action, cluster_name, master, candidate, force, scheduled=None): + """ + We want to trigger a failover or switchover for the specified cluster name. + + We verify that the cluster name, master name and candidate name are correct. + If so, we trigger an action and keep the client up to date. 
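
    For reference, when the checks in this helper pass, the request body posted to the
    member's REST API (and handled by `do_POST_failover` in the API module above) looks
    roughly like this, with hypothetical node names:

        failover_value = {
            'leader': 'pg-node-1',      # current master; required for a switchover
            'candidate': 'pg-node-2',   # required for a failover, optional for a switchover
            'scheduled_at': None,       # or an ISO 8601 timestamp for a scheduled switchover
        }
        # request_patroni(member, 'post', action, failover_value)
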
+ """ + + dcs = get_dcs(obj, cluster_name) + cluster = dcs.get_cluster() + + if action == 'switchover' and (cluster.leader is None or not cluster.leader.name): + raise PatroniCtlException('This cluster has no master') + + if master is None: + if force or action == 'failover': + master = cluster.leader and cluster.leader.name + else: + master = click.prompt('Master', type=str, default=cluster.leader.member.name) + + if master is not None and cluster.leader and cluster.leader.member.name != master: + raise PatroniCtlException('Member {0} is not the leader of cluster {1}'.format(master, cluster_name)) + + # excluding members with nofailover tag + candidate_names = [str(m.name) for m in cluster.members if m.name != master and not m.nofailover] + # We sort the names for consistent output to the client + candidate_names.sort() + + if not candidate_names: + raise PatroniCtlException('No candidates found to {0} to'.format(action)) + + if candidate is None and not force: + candidate = click.prompt('Candidate ' + str(candidate_names), type=str, default='') + + if action == 'failover' and not candidate: + raise PatroniCtlException('Failover could be performed only to a specific candidate') + + if candidate == master: + raise PatroniCtlException(action.title() + ' target and source are the same.') + + if candidate and candidate not in candidate_names: + raise PatroniCtlException('Member {0} does not exist in cluster {1}'.format(candidate, cluster_name)) + + scheduled_at_str = None + scheduled_at = None + + if action == 'switchover': + if scheduled is None and not force: + next_hour = (datetime.datetime.now() + datetime.timedelta(hours=1)).strftime('%Y-%m-%dT%H:%M') + scheduled = click.prompt('When should the switchover take place (e.g. ' + next_hour + ' ) ', + type=str, default='now') + + scheduled_at = parse_scheduled(scheduled) + if scheduled_at: + if cluster.is_paused(): + raise PatroniCtlException("Can't schedule switchover in the paused state") + scheduled_at_str = scheduled_at.isoformat() + + failover_value = {'leader': master, 'candidate': candidate, 'scheduled_at': scheduled_at_str} + + logging.debug(failover_value) + + # By now we have established that the leader exists and the candidate exists + click.echo('Current cluster topology') + output_members(dcs.get_cluster(), cluster_name) + + if not force: + demote_msg = ', demoting current master ' + master if master else '' + if scheduled_at_str: + if not click.confirm('Are you sure you want to schedule {0} of cluster {1} at {2}{3}?' + .format(action, cluster_name, scheduled_at_str, demote_msg)): + raise PatroniCtlException('Aborting scheduled ' + action) + else: + if not click.confirm('Are you sure you want to {0} cluster {1}{2}?' 
+ .format(action, cluster_name, demote_msg)): + raise PatroniCtlException('Aborting ' + action) + + r = None + try: + member = cluster.leader.member if cluster.leader else cluster.get_member(candidate, False) + + r = request_patroni(member, 'post', action, failover_value) + + # probably old patroni, which doesn't support switchover yet + if r.status == 501 and action == 'switchover' and b'Server does not support this operation' in r.data: + r = request_patroni(member, 'post', 'failover', failover_value) + + if r.status in (200, 202): + logging.debug(r) + cluster = dcs.get_cluster() + logging.debug(cluster) + click.echo('{0} {1}'.format(timestamp(), r.data.decode('utf-8'))) + else: + click.echo('{0} failed, details: {1}, {2}'.format(action.title(), r.status, r.data.decode('utf-8'))) + return + except Exception: + logging.exception(r) + logging.warning('Failing over to DCS') + click.echo('{0} Could not {1} using Patroni api, falling back to DCS'.format(timestamp(), action)) + dcs.manual_failover(master, candidate, scheduled_at=scheduled_at) + + output_members(cluster, cluster_name) + + +@ctl.command('failover', help='Failover to a replica') +@arg_cluster_name +@click.option('--master', help='The name of the current master', default=None) +@click.option('--candidate', help='The name of the candidate', default=None) +@option_force +@click.pass_obj +def failover(obj, cluster_name, master, candidate, force): + action = 'switchover' if master else 'failover' + _do_failover_or_switchover(obj, action, cluster_name, master, candidate, force) + + +@ctl.command('switchover', help='Switchover to a replica') +@arg_cluster_name +@click.option('--master', help='The name of the current master', default=None) +@click.option('--candidate', help='The name of the candidate', default=None) +@click.option('--scheduled', help='Timestamp of a scheduled switchover in unambiguous format (e.g. 
ISO 8601)', + default=None) +@option_force +@click.pass_obj +def switchover(obj, cluster_name, master, candidate, force, scheduled): + _do_failover_or_switchover(obj, 'switchover', cluster_name, master, candidate, force, scheduled) + + +def generate_topology(level, member, topology): + members = topology.get(member['name'], []) + + if level > 0: + member['name'] = '{0}+ {1}'.format((' ' * (level - 1) * 2), member['name']) + + if member['name']: + yield member + + for member in members: + for member in generate_topology(level + 1, member, topology): + yield member + + +def topology_sort(members): + topology = defaultdict(list) + leader = next((m for m in members if m['role'].endswith('leader')), {'name': None}) + replicas = set(member['name'] for member in members if not member['role'].endswith('leader')) + for member in members: + if not member['role'].endswith('leader'): + parent = member.get('tags', {}).get('replicatefrom') + parent = parent if parent and parent != member['name'] and parent in replicas else leader['name'] + topology[parent].append(member) + for member in generate_topology(0, leader, topology): + yield member + + +def output_members(cluster, name, extended=False, fmt='pretty'): + rows = [] + logging.debug(cluster) + initialize = {None: 'uninitialized', '': 'initializing'}.get(cluster.initialize, cluster.initialize) + cluster = cluster_as_json(cluster) + + columns = ['Cluster', 'Member', 'Host', 'Role', 'State', 'TL', 'Lag in MB'] + for c in ('Pending restart', 'Scheduled restart', 'Tags'): + if extended or any(m.get(c.lower().replace(' ', '_')) for m in cluster['members']): + columns.append(c) + + # Show Host as 'host:port' if somebody is running on non-standard port or two nodes are running on the same host + members = [m for m in cluster['members'] if 'host' in m] + append_port = any('port' in m and m['port'] != 5432 for m in members) or\ + len(set(m['host'] for m in members)) < len(members) + + sort = topology_sort if fmt == 'topology' else iter + for m in sort(cluster['members']): + logging.debug(m) + + lag = m.get('lag', '') + m.update(cluster=name, member=m['name'], host=m.get('host', ''), tl=m.get('timeline', ''), + role=m['role'].replace('_', ' ').title(), + lag_in_mb=round(lag/1024/1024) if isinstance(lag, six.integer_types) else lag, + pending_restart='*' if m.get('pending_restart') else '') + + if append_port and m['host'] and m.get('port'): + m['host'] = ':'.join([m['host'], str(m['port'])]) + + if 'scheduled_restart' in m: + value = m['scheduled_restart']['schedule'] + if 'postgres_version' in m['scheduled_restart']: + value += ' if version < {0}'.format(m['scheduled_restart']['postgres_version']) + m['scheduled_restart'] = value + + rows.append([m.get(n.lower().replace(' ', '_'), '') for n in columns]) + + print_output(columns, rows, {'Lag in MB': 'r', 'TL': 'r'}, fmt, ' Cluster: {0} ({1}) '.format(name, initialize)) + + if fmt not in ('pretty', 'topology'): # Omit service info when using machine-readable formats + return + + service_info = [] + if cluster.get('pause'): + service_info.append('Maintenance mode: on') + + if 'scheduled_switchover' in cluster: + info = 'Switchover scheduled at: ' + cluster['scheduled_switchover']['at'] + for name in ('from', 'to'): + if name in cluster['scheduled_switchover']: + info += '\n{0:>24}: {1}'.format(name, cluster['scheduled_switchover'][name]) + service_info.append(info) + + if service_info: + click.echo(' ' + '\n '.join(service_info)) + + +@ctl.command('list', help='List the Patroni members for a given Patroni') 
+@click.argument('cluster_names', nargs=-1) +@click.option('--extended', '-e', help='Show some extra information', is_flag=True) +@click.option('--timestamp', '-t', 'ts', help='Print timestamp', is_flag=True) +@option_format +@option_watch +@option_watchrefresh +@click.pass_obj +def members(obj, cluster_names, fmt, watch, w, extended, ts): + if not cluster_names: + if 'scope' in obj: + cluster_names = [obj['scope']] + if not cluster_names: + return logging.warning('Listing members: No cluster names were provided') + + for cluster_name in cluster_names: + dcs = get_dcs(obj, cluster_name) + + for _ in watching(w, watch): + if ts: + click.echo(timestamp(0)) + + cluster = dcs.get_cluster() + output_members(cluster, cluster_name, extended, fmt) + + +@ctl.command('topology', help='Prints ASCII topology for given cluster') +@click.argument('cluster_names', nargs=-1) +@option_watch +@option_watchrefresh +@click.pass_obj +@click.pass_context +def topology(ctx, obj, cluster_names, watch, w): + ctx.forward(members, fmt='topology') + + +def timestamp(precision=6): + return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:precision - 7] + + +@ctl.command('configure', help='Create configuration file') +@click.option('--config-file', '-c', help='Configuration file', prompt='Configuration file', default=CONFIG_FILE_PATH) +@click.option('--dcs', '-d', help='The DCS connect url', prompt='DCS connect url', default='etcd://localhost:2379') +@click.option('--namespace', '-n', help='The namespace', prompt='Namespace', default='/service/') +def configure(config_file, dcs, namespace): + store_config({'dcs_api': str(dcs), 'namespace': str(namespace)}, config_file) + + +def touch_member(config, dcs): + ''' Rip-off of the ha.touch_member without inter-class dependencies ''' + p = Postgresql(config['postgresql']) + p.set_state('running') + p.set_role('master') + + def restapi_connection_string(config): + protocol = 'https' if config.get('certfile') else 'http' + connect_address = config.get('connect_address') + listen = config['listen'] + return '{0}://{1}/patroni'.format(protocol, connect_address or listen) + + data = { + 'conn_url': p.connection_string, + 'api_url': restapi_connection_string(config['restapi']), + 'state': p.state, + 'role': p.role + } + + return dcs.touch_member(data, permanent=True) + + +def set_defaults(config, cluster_name): + """fill-in some basic configuration parameters if config file is not set """ + config['postgresql'].setdefault('name', cluster_name) + config['postgresql'].setdefault('scope', cluster_name) + config['postgresql'].setdefault('listen', '127.0.0.1') + config['postgresql']['authentication'] = {'replication': None} + config['restapi']['listen'] = ':' in config['restapi']['listen'] and config['restapi']['listen'] or '127.0.0.1:8008' + + +@ctl.command('scaffold', help='Create a structure for the cluster in DCS') +@click.argument('cluster_name') +@click.option('--sysid', '-s', help='System ID of the cluster to put into the initialize key', default="") +@click.pass_obj +def scaffold(obj, cluster_name, sysid): + dcs = get_dcs(obj, cluster_name) + cluster = dcs.get_cluster() + if cluster and cluster.initialize is not None: + raise PatroniCtlException("This cluster is already initialized") + + if not dcs.initialize(create_new=True, sysid=sysid): + # initialize key already exists, don't touch this cluster + raise PatroniCtlException("Initialize key for cluster {0} already exists".format(cluster_name)) + + set_defaults(obj, cluster_name) + + # make sure the leader keys will 
never expire + if not (touch_member(obj, dcs) and dcs.attempt_to_acquire_leader(permanent=True)): + # we did initialize this cluster, but failed to write the leader or member keys, wipe it down completely. + dcs.delete_cluster() + raise PatroniCtlException("Unable to install permanent leader for cluster {0}".format(cluster_name)) + click.echo("Cluster {0} has been created successfully".format(cluster_name)) + + +@ctl.command('flush', help='Discard scheduled events') +@click.argument('cluster_name') +@click.argument('member_names', nargs=-1) +@click.argument('target', type=click.Choice(['restart', 'switchover'])) +@click.option('--role', '-r', help='Flush only members with this role', default='any', + type=click.Choice(['master', 'replica', 'any'])) +@option_force +@click.pass_obj +def flush(obj, cluster_name, member_names, force, role, target): + dcs = get_dcs(obj, cluster_name) + cluster = dcs.get_cluster() + + if target == 'restart': + for member in get_members(cluster, cluster_name, member_names, role, force, 'flush'): + if member.data.get('scheduled_restart'): + r = request_patroni(member, 'delete', 'restart') + check_response(r, member.name, 'flush scheduled restart') + else: + click.echo('No scheduled restart for member {0}'.format(member.name)) + elif target == 'switchover': + failover = cluster.failover + if not failover or not failover.scheduled_at: + return click.echo('No pending scheduled switchover') + for member in get_all_members_leader_first(cluster): + try: + r = request_patroni(member, 'delete', 'switchover') + if r.status in (200, 404): + prefix = 'Success' if r.status == 200 else 'Failed' + return click.echo('{0}: {1}'.format(prefix, r.data.decode('utf-8'))) + except Exception as err: + logging.warning(str(err)) + logging.warning('Member %s is not accessible', member.name) + + click.echo('Failed: member={0}, status_code={1}, ({2})'.format( + member.name, r.status, r.data.decode('utf-8'))) + + logging.warning('Failing over to DCS') + click.echo('{0} Could not find any accessible member of cluster {1}'.format(timestamp(), cluster_name)) + dcs.manual_failover('', '', index=failover.index) + + +def wait_until_pause_is_applied(dcs, paused, old_cluster): + click.echo("'{0}' request sent, waiting until it is recognized by all nodes".format(paused and 'pause' or 'resume')) + old = {m.name: m.index for m in old_cluster.members if m.api_url} + loop_wait = old_cluster.config.data.get('loop_wait', dcs.loop_wait) + + for _ in polling_loop(loop_wait + 1): + cluster = dcs.get_cluster() + if all(m.data.get('pause', False) == paused for m in cluster.members if m.name in old): + break + else: + remaining = [m.name for m in cluster.members if m.data.get('pause', False) != paused + and m.name in old and old[m.name] != m.index] + if remaining: + return click.echo("{0} members didn't recognize the pause state after {1} seconds" + .format(', '.join(remaining), loop_wait)) + return click.echo('Success: cluster management is {0}'.format(paused and 'paused' or 'resumed')) + + +def toggle_pause(config, cluster_name, paused, wait): + dcs = get_dcs(config, cluster_name) + cluster = dcs.get_cluster() + if cluster.is_paused() == paused: + raise PatroniCtlException('Cluster is {0} paused'.format(paused and 'already' or 'not')) + + for member in get_all_members_leader_first(cluster): + try: + r = request_patroni(member, 'patch', 'config', {'pause': paused or None}) + except Exception as err: + logging.warning(str(err)) + logging.warning('Member %s is not accessible', member.name) + continue + + if r.status
== 200: + if wait: + wait_until_pause_is_applied(dcs, paused, cluster) + else: + click.echo('Success: cluster management is {0}'.format(paused and 'paused' or 'resumed')) + else: + click.echo('Failed: {0} cluster management status code={1}, ({2})'.format( + paused and 'pause' or 'resume', r.status, r.data.decode('utf-8'))) + break + else: + raise PatroniCtlException('Can not find accessible cluster member') + + +@ctl.command('pause', help='Disable auto failover') +@arg_cluster_name +@click.pass_obj +@click.option('--wait', help='Wait until pause is applied on all nodes', is_flag=True) +def pause(obj, cluster_name, wait): + return toggle_pause(obj, cluster_name, True, wait) + + +@ctl.command('resume', help='Resume auto failover') +@arg_cluster_name +@click.option('--wait', help='Wait until pause is cleared on all nodes', is_flag=True) +@click.pass_obj +def resume(obj, cluster_name, wait): + return toggle_pause(obj, cluster_name, False, wait) + + +@contextmanager +def temporary_file(contents, suffix='', prefix='tmp'): + """Creates a temporary file with specified contents that persists for the context. + + :param contents: binary string that will be written to the file. + :param prefix: will be prefixed to the filename. + :param suffix: will be appended to the filename. + :returns path of the created file. + """ + tmp = tempfile.NamedTemporaryFile(suffix=suffix, prefix=prefix, delete=False) + with tmp: + tmp.write(contents) + + try: + yield tmp.name + finally: + os.unlink(tmp.name) + + +def show_diff(before_editing, after_editing): + """Shows a diff between two strings. + + If the output is to a tty the diff will be colored. Inputs are expected to be unicode strings. + """ + def listify(string): + return [line + '\n' for line in string.rstrip('\n').split('\n')] + + unified_diff = difflib.unified_diff(listify(before_editing), listify(after_editing)) + + if sys.stdout.isatty(): + buf = io.StringIO() + for line in unified_diff: + # Force cast to unicode as difflib on Python 2.7 returns a mix of unicode and str. + buf.write(six.text_type(line)) + buf.seek(0) + + class opts: + side_by_side = False + width = 80 + tab_width = 8 + wrap = True + if find_executable('less'): + pager = None + else: + pager = 'more.com' if sys.platform == 'win32' else 'more' + pager_options = None + + markup_to_pager(PatchStream(buf), opts) + else: + for line in unified_diff: + click.echo(line.rstrip('\n')) + + +def format_config_for_editing(data, default_flow_style=False): + """Formats configuration as YAML for human consumption. + + :param data: configuration as nested dictionaries + :returns unicode YAML of the configuration""" + return yaml.safe_dump(data, default_flow_style=default_flow_style, encoding=None, allow_unicode=True, width=200) + + +def apply_config_changes(before_editing, data, kvpairs): + """Applies config changes specified as a list of key-value pairs. + + Keys are interpreted as dotted paths into the configuration data structure. Except for paths beginning with + `postgresql.parameters` where rest of the path is used directly to allow for PostgreSQL GUCs containing dots. + Values are interpreted as YAML values. 
+ + :param before_editing: human representation before editing + :param data: configuration datastructure + :param kvpairs: list of strings containing key value pairs separated by = + :returns tuple of human readable and parsed datastructure after changes + """ + changed_data = copy.deepcopy(data) + + def set_path_value(config, path, value, prefix=()): + # Postgresql GUCs can't be nested, but can contain dots so we re-flatten the structure for this case + if prefix == ('postgresql', 'parameters'): + path = ['.'.join(path)] + + key = path[0] + if len(path) == 1: + if value is None: + config.pop(key, None) + else: + config[key] = value + else: + if not isinstance(config.get(key), dict): + config[key] = {} + set_path_value(config[key], path[1:], value, prefix + (key,)) + if config[key] == {}: + del config[key] + + for pair in kvpairs: + if not pair or "=" not in pair: + raise PatroniCtlException("Invalid parameter setting {0}".format(pair)) + key_path, value = pair.split("=", 1) + set_path_value(changed_data, key_path.strip().split("."), yaml.safe_load(value)) + + return format_config_for_editing(changed_data), changed_data + + +def apply_yaml_file(data, filename): + """Applies changes from a YAML file to configuration + + :param data: configuration datastructure + :param filename: name of the YAML file, - is taken to mean standard input + :returns tuple of human readable and parsed datastructure after changes + """ + changed_data = copy.deepcopy(data) + + if filename == '-': + new_options = yaml.safe_load(sys.stdin) + else: + with open(filename) as fd: + new_options = yaml.safe_load(fd) + + patch_config(changed_data, new_options) + + return format_config_for_editing(changed_data), changed_data + + +def invoke_editor(before_editing, cluster_name): + """Starts editor command to edit configuration in human readable format + + :param before_editing: human representation before editing + :returns tuple of human readable and parsed datastructure after changes + """ + + editor_cmd = os.environ.get('EDITOR') + if not editor_cmd: + for editor in ('editor', 'vi'): + editor_cmd = find_executable(editor) + if editor_cmd: + logging.debug('Setting fallback editor_cmd=%s', editor) + break + if not editor_cmd: + raise PatroniCtlException('EDITOR environment variable is not set. editor or vi are not available') + + with temporary_file(contents=before_editing.encode('utf-8'), + suffix='.yaml', + prefix='{0}-config-'.format(cluster_name)) as tmpfile: + ret = subprocess.call([editor_cmd, tmpfile]) + if ret: + raise PatroniCtlException("Editor exited with return code {0}".format(ret)) + + with codecs.open(tmpfile, encoding='utf-8') as fd: + after_editing = fd.read() + + return after_editing, yaml.safe_load(after_editing) + + +@ctl.command('edit-config', help="Edit cluster configuration") +@arg_cluster_name +@click.option('--quiet', '-q', is_flag=True, help='Do not show changes') +@click.option('--set', '-s', 'kvpairs', multiple=True, + help='Set specific configuration value. Can be specified multiple times') +@click.option('--pg', '-p', 'pgkvpairs', multiple=True, + help='Set specific PostgreSQL parameter value. Shorthand for -s postgresql.parameters. ' + 'Can be specified multiple times') +@click.option('--apply', 'apply_filename', help='Apply configuration from file. Use - for stdin.') +@click.option('--replace', 'replace_filename', help='Apply configuration from file, replacing existing configuration.' 
+ ' Use - for stdin.') +@option_force +@click.pass_obj +def edit_config(obj, cluster_name, force, quiet, kvpairs, pgkvpairs, apply_filename, replace_filename): + dcs = get_dcs(obj, cluster_name) + cluster = dcs.get_cluster() + + before_editing = format_config_for_editing(cluster.config.data) + + after_editing = None # Serves as a flag if any changes were requested + changed_data = cluster.config.data + + if replace_filename: + after_editing, changed_data = apply_yaml_file({}, replace_filename) + + if apply_filename: + after_editing, changed_data = apply_yaml_file(changed_data, apply_filename) + + if kvpairs or pgkvpairs: + all_pairs = list(kvpairs) + ['postgresql.parameters.'+v.lstrip() for v in pgkvpairs] + after_editing, changed_data = apply_config_changes(before_editing, changed_data, all_pairs) + + # If no changes were specified on the command line invoke editor + if after_editing is None: + after_editing, changed_data = invoke_editor(before_editing, cluster_name) + + if cluster.config.data == changed_data: + if not quiet: + click.echo("Not changed") + return + + if not quiet: + show_diff(before_editing, after_editing) + + if (apply_filename == '-' or replace_filename == '-') and not force: + click.echo("Use --force option to apply changes") + return + + if force or click.confirm('Apply these changes?'): + if not dcs.set_config_value(json.dumps(changed_data), cluster.config.index): + raise PatroniCtlException("Config modification aborted due to concurrent changes") + click.echo("Configuration changed") + + +@ctl.command('show-config', help="Show cluster configuration") +@arg_cluster_name +@click.pass_obj +def show_config(obj, cluster_name): + cluster = get_dcs(obj, cluster_name).get_cluster() + + click.echo(format_config_for_editing(cluster.config.data)) + + +@ctl.command('version', help='Output version of patronictl command or a running Patroni instance') +@click.argument('cluster_name', required=False) +@click.argument('member_names', nargs=-1) +@click.pass_obj +def version(obj, cluster_name, member_names): + click.echo("patronictl version {0}".format(__version__)) + + if not cluster_name: + return + + click.echo("") + cluster = get_dcs(obj, cluster_name).get_cluster() + for m in cluster.members: + if m.api_url: + if not member_names or m.name in member_names: + try: + response = request_patroni(m) + data = json.loads(response.data.decode('utf-8')) + version = data.get('patroni', {}).get('version') + pg_version = data.get('server_version') + pg_version_str = " PostgreSQL {0}".format(format_pg_version(pg_version)) if pg_version else "" + click.echo("{0}: Patroni {1}{2}".format(m.name, version, pg_version_str)) + except Exception as e: + click.echo("{0}: failed to get version: {1}".format(m.name, e)) + + +@ctl.command('history', help="Show the history of failovers/switchovers") +@arg_cluster_name +@option_format +@click.pass_obj +def history(obj, cluster_name, fmt): + cluster = get_dcs(obj, cluster_name).get_cluster() + history = cluster.history and cluster.history.lines or [] + for line in history: + if len(line) < 4: + line.append('') + print_output(['TL', 'LSN', 'Reason', 'Timestamp'], history, {'TL': 'r', 'LSN': 'r'}, fmt) + + +def format_pg_version(version): + if version < 100000: + return "{0}.{1}.{2}".format(version // 10000, version // 100 % 100, version % 100) + else: + return "{0}.{1}".format(version // 10000, version % 100) diff --git a/patroni-for-openGauss/daemon.py b/patroni-for-openGauss/daemon.py new file mode 100644 index 
0000000000000000000000000000000000000000..5bb999903fad324d604de9b0fdafc696f7ec1440 --- /dev/null +++ b/patroni-for-openGauss/daemon.py @@ -0,0 +1,104 @@ +import abc +import os +import signal +import six +import sys + +from threading import Lock + + +@six.add_metaclass(abc.ABCMeta) +class AbstractPatroniDaemon(object): + + def __init__(self, config): + from patroni.log import PatroniLogger + + self.setup_signal_handlers() + + self.logger = PatroniLogger() + self.config = config + AbstractPatroniDaemon.reload_config(self, local=True) + + def sighup_handler(self, *args): + self._received_sighup = True + + def sigterm_handler(self, *args): + with self._sigterm_lock: + if not self._received_sigterm: + self._received_sigterm = True + sys.exit() + + def setup_signal_handlers(self): + self._received_sighup = False + self._sigterm_lock = Lock() + self._received_sigterm = False + if os.name != 'nt': + signal.signal(signal.SIGHUP, self.sighup_handler) + signal.signal(signal.SIGTERM, self.sigterm_handler) + + @property + def received_sigterm(self): + with self._sigterm_lock: + return self._received_sigterm + + def reload_config(self, sighup=False, local=False): + if local: + self.logger.reload_config(self.config.get('log', {})) + + @abc.abstractmethod + def _run_cycle(self): + """_run_cycle""" + + def run(self): + self.logger.start() + while not self.received_sigterm: + if self._received_sighup: + self._received_sighup = False + self.reload_config(True, self.config.reload_local_configuration()) + + self._run_cycle() + + @abc.abstractmethod + def _shutdown(self): + """_shutdown""" + + def shutdown(self): + with self._sigterm_lock: + self._received_sigterm = True + self._shutdown() + self.logger.shutdown() + + +def abstract_main(cls, validator=None): + import argparse + + from .config import Config, ConfigParseError + from .version import __version__ + + parser = argparse.ArgumentParser() + parser.add_argument('--version', action='version', version='%(prog)s {0}'.format(__version__)) + if validator: + parser.add_argument('--validate-config', action='store_true', help='Run config validator and exit') + parser.add_argument('configfile', nargs='?', default='', + help='Patroni may also read the configuration from the {0} environment variable' + .format(Config.PATRONI_CONFIG_VARIABLE)) + args = parser.parse_args() + try: + if validator and args.validate_config: + Config(args.configfile, validator=validator) + sys.exit() + + config = Config(args.configfile) + except ConfigParseError as e: + if e.value: + print(e.value) + parser.print_help() + sys.exit(1) + + controller = cls(config) + try: + controller.run() + except KeyboardInterrupt: + pass + finally: + controller.shutdown() diff --git a/patroni-for-openGauss/dcs/__init__.py b/patroni-for-openGauss/dcs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5503f3e9b334620106b2ebc16c1bc5f199362f9 --- /dev/null +++ b/patroni-for-openGauss/dcs/__init__.py @@ -0,0 +1,831 @@ +import abc +import dateutil +import importlib +import inspect +import json +import logging +import os +import pkgutil +import re +import six +import sys +import time + +from collections import defaultdict, namedtuple +from copy import deepcopy +from patroni.exceptions import PatroniFatalException +from patroni.utils import parse_bool, uri +from random import randint +from six.moves.urllib_parse import urlparse, urlunparse, parse_qsl +from threading import Event, Lock + +slot_name_re = re.compile('^[a-z0-9_]{1,63}$') +logger = logging.getLogger(__name__) + + +def 
slot_name_from_member_name(member_name): + """Translate member name to valid PostgreSQL slot name. + + PostgreSQL replication slot names must be valid PostgreSQL names. This function maps the wider space of + member names to valid PostgreSQL names. Names are lowercased, dashes and periods common in hostnames + are replaced with underscores, other characters are encoded as their unicode codepoint. Name is truncated + to 64 characters. Multiple different member names may map to a single slot name.""" + + def replace_char(match): + c = match.group(0) + return '_' if c in '-.' else "u{:04d}".format(ord(c)) + + slot_name = re.sub('[^a-z0-9_]', replace_char, member_name.lower()) + return slot_name[0:63] + + +def parse_connection_string(value): + """Original Governor stores connection strings for each cluster members if a following format: + postgres://{username}:{password}@{connect_address}/postgres + Since each of our patroni instances provides own REST API endpoint it's good to store this information + in DCS among with postgresql connection string. In order to not introduce new keys and be compatible with + original Governor we decided to extend original connection string in a following way: + postgres://{username}:{password}@{connect_address}/postgres?application_name={api_url} + This way original Governor could use such connection string as it is, because of feature of `libpq` library. + + This method is able to split connection string stored in DCS into two parts, `conn_url` and `api_url`""" + + scheme, netloc, path, params, query, fragment = urlparse(value) + conn_url = urlunparse((scheme, netloc, path, params, '', fragment)) + api_url = ([v for n, v in parse_qsl(query) if n == 'application_name'] or [None])[0] + return conn_url, api_url + + +def dcs_modules(): + """Get names of DCS modules, depending on execution environment. If being packaged with PyInstaller, + modules aren't discoverable dynamically by scanning source directory because `FrozenImporter` doesn't + implement `iter_modules` method. But it is still possible to find all potential DCS modules by + iterating through `toc`, which contains list of all "frozen" resources.""" + + dcs_dirname = os.path.dirname(__file__) + module_prefix = __package__ + '.' 
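+    # e.g. with the etcd and consul backends present (both are bundled in this tree), the result
+    # typically includes module_prefix + 'etcd' and module_prefix + 'consul'.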
+ + if getattr(sys, 'frozen', False): + toc = set() + for importer in pkgutil.iter_importers(dcs_dirname): + if hasattr(importer, 'toc'): + toc |= importer.toc + return [module for module in toc if module.startswith(module_prefix) and module.count('.') == 2] + else: + return [module_prefix + name for _, name, is_pkg in pkgutil.iter_modules([dcs_dirname]) if not is_pkg] + + +def get_dcs(config): + modules = dcs_modules() + + for module_name in modules: + name = module_name.split('.')[-1] + if name in config: # we will try to import only modules which have configuration section in the config file + try: + module = importlib.import_module(module_name) + for key, item in module.__dict__.items(): # iterate through the module content + # try to find implementation of AbstractDCS interface, class name must match with module_name + if key.lower() == name and inspect.isclass(item) and issubclass(item, AbstractDCS): + # propagate some parameters + config[name].update({p: config[p] for p in ('namespace', 'name', 'scope', 'loop_wait', + 'patronictl', 'ttl', 'retry_timeout') if p in config}) + return item(config[name]) + except ImportError: + logger.debug('Failed to import %s', module_name) + + available_implementations = [] + for module_name in modules: + name = module_name.split('.')[-1] + try: + module = importlib.import_module(module_name) + available_implementations.extend(name for key, item in module.__dict__.items() if key.lower() == name + and inspect.isclass(item) and issubclass(item, AbstractDCS)) + except ImportError: + logger.info('Failed to import %s', module_name) + raise PatroniFatalException("""Can not find suitable configuration of distributed configuration store +Available implementations: """ + ', '.join(sorted(set(available_implementations)))) + + +class Member(namedtuple('Member', 'index,name,session,data')): + + """Immutable object (namedtuple) which represents single member of PostgreSQL cluster. + Consists of the following fields: + :param index: modification index of a given member key in a Configuration Store + :param name: name of PostgreSQL cluster member + :param session: either session id or just ttl in seconds + :param data: arbitrary data i.e. conn_url, api_url, xlog location, state, role, tags, etc... + + There are two mandatory keys in a data: + conn_url: connection string containing host, user and password which could be used to access this member. 
+ api_url: REST API url of patroni instance""" + + @staticmethod + def from_node(index, name, session, data): + """ + >>> Member.from_node(-1, '', '', '{"conn_url": "postgres://foo@bar/postgres"}') is not None + True + >>> Member.from_node(-1, '', '', '{') + Member(index=-1, name='', session='', data={}) + """ + if data.startswith('postgres'): + conn_url, api_url = parse_connection_string(data) + data = {'conn_url': conn_url, 'api_url': api_url} + else: + try: + data = json.loads(data) + except (TypeError, ValueError): + data = {} + return Member(index, name, session, data) + + @property + def conn_url(self): + conn_url = self.data.get('conn_url') + if conn_url: + return conn_url + + conn_kwargs = self.data.get('conn_kwargs') + if conn_kwargs: + conn_url = uri('postgresql', (conn_kwargs.get('host'), conn_kwargs.get('port', 5432))) + self.data['conn_url'] = conn_url + return conn_url + + def conn_kwargs(self, auth=None): + defaults = { + "host": None, + "port": None, + "database": None + } + ret = self.data.get('conn_kwargs') + if ret: + defaults.update(ret) + ret = defaults + else: + conn_url = self.conn_url + if not conn_url: + return {} # due to the invalid conn_url we don't care about authentication parameters + r = urlparse(conn_url) + ret = { + 'host': r.hostname, + 'port': r.port or 5432, + 'database': r.path[1:] + } + self.data['conn_kwargs'] = ret.copy() + + # apply any remaining authentication parameters + if auth and isinstance(auth, dict): + ret.update({k: v for k, v in auth.items() if v is not None}) + if 'username' in auth: + ret['user'] = ret.pop('username') + return ret + + @property + def api_url(self): + return self.data.get('api_url') + + @property + def tags(self): + return self.data.get('tags', {}) + + @property + def nofailover(self): + return self.tags.get('nofailover', False) + + @property + def replicatefrom(self): + return self.tags.get('replicatefrom') + + @property + def clonefrom(self): + return self.tags.get('clonefrom', False) and bool(self.conn_url) + + @property + def state(self): + return self.data.get('state', 'unknown') + + @property + def is_running(self): + return self.state == 'running' + + +class RemoteMember(Member): + """ Represents a remote master for a standby cluster + """ + def __new__(cls, name, data): + return super(RemoteMember, cls).__new__(cls, None, name, None, data) + + @staticmethod + def allowed_keys(): + return ('primary_slot_name', + 'create_replica_methods', + 'restore_command', + 'archive_cleanup_command', + 'recovery_min_apply_delay', + 'no_replication_slot') + + def __getattr__(self, name): + if name in RemoteMember.allowed_keys(): + return self.data.get(name) + + +class Leader(namedtuple('Leader', 'index,session,member')): + + """Immutable object (namedtuple) which represents leader key. 
+ Consists of the following fields: + :param index: modification index of a leader key in a Configuration Store + :param session: either session id or just ttl in seconds + :param member: reference to a `Member` object which represents current leader (see `Cluster.members`)""" + + @property + def name(self): + return self.member.name + + def conn_kwargs(self, auth=None): + return self.member.conn_kwargs(auth) + + @property + def conn_url(self): + return self.member.conn_url + + @property + def data(self): + return self.member.data + + @property + def timeline(self): + return self.data.get('timeline') + + @property + def checkpoint_after_promote(self): + """ + >>> Leader(1, '', Member.from_node(1, '', '', '{"version":"z"}')).checkpoint_after_promote + """ + version = self.data.get('version') + if version: + try: + # 1.5.6 is the last version which doesn't expose checkpoint_after_promote: false + if tuple(map(int, version.split('.'))) > (1, 5, 6): + return self.data['role'] == 'master' and 'checkpoint_after_promote' not in self.data + except Exception: + logger.debug('Failed to parse Patroni version %s', version) + + +class Failover(namedtuple('Failover', 'index,leader,candidate,scheduled_at')): + + """ + >>> 'Failover' in str(Failover.from_node(1, '{"leader": "cluster_leader"}')) + True + >>> 'Failover' in str(Failover.from_node(1, {"leader": "cluster_leader"})) + True + >>> 'Failover' in str(Failover.from_node(1, '{"leader": "cluster_leader", "member": "cluster_candidate"}')) + True + >>> Failover.from_node(1, 'null') is None + False + >>> n = '{"leader": "cluster_leader", "member": "cluster_candidate", "scheduled_at": "2016-01-14T10:09:57.1394Z"}' + >>> 'tzinfo=' in str(Failover.from_node(1, n)) + True + >>> Failover.from_node(1, None) is None + False + >>> Failover.from_node(1, '{}') is None + False + >>> 'abc' in Failover.from_node(1, 'abc:def') + True + """ + @staticmethod + def from_node(index, value): + if isinstance(value, dict): + data = value + elif value: + try: + data = json.loads(value) + if not isinstance(data, dict): + data = {} + except ValueError: + t = [a.strip() for a in value.split(':')] + leader = t[0] + candidate = t[1] if len(t) > 1 else None + return Failover(index, leader, candidate, None) if leader or candidate else None + else: + data = {} + + if data.get('scheduled_at'): + data['scheduled_at'] = dateutil.parser.parse(data['scheduled_at']) + + return Failover(index, data.get('leader'), data.get('member'), data.get('scheduled_at')) + + def __len__(self): + return int(bool(self.leader)) + int(bool(self.candidate)) + + +class ClusterConfig(namedtuple('ClusterConfig', 'index,data,modify_index')): + + @staticmethod + def from_node(index, data, modify_index=None): + """ + >>> ClusterConfig.from_node(1, '{') is None + False + """ + + try: + data = json.loads(data) + except (TypeError, ValueError): + data = None + modify_index = 0 + if not isinstance(data, dict): + data = {} + return ClusterConfig(index, data, index if modify_index is None else modify_index) + + @property + def permanent_slots(self): + return isinstance(self.data, dict) and ( + self.data.get('permanent_replication_slots') or + self.data.get('permanent_slots') or self.data.get('slots') + ) or {} + + @property + def ignore_slots_matchers(self): + return isinstance(self.data, dict) and self.data.get('ignore_slots') or [] + + @property + def max_timelines_history(self): + return self.data.get('max_timelines_history', 0) + + +class SyncState(namedtuple('SyncState', 'index,leader,sync_standby')): + """Immutable 
object (namedtuple) which represents last observed synchronous replication state + + :param index: modification index of a synchronization key in a Configuration Store + :param leader: reference to member that was leader + :param sync_standby: synchronous standby list (comma delimited) which are last synchronized to leader + """ + + @staticmethod + def from_node(index, value): + """ + >>> SyncState.from_node(1, None).leader is None + True + >>> SyncState.from_node(1, '{}').leader is None + True + >>> SyncState.from_node(1, '{').leader is None + True + >>> SyncState.from_node(1, '[]').leader is None + True + >>> SyncState.from_node(1, '{"leader": "leader"}').leader == "leader" + True + >>> SyncState.from_node(1, {"leader": "leader"}).leader == "leader" + True + """ + if isinstance(value, dict): + data = value + elif value: + try: + data = json.loads(value) + if not isinstance(data, dict): + data = {} + except (TypeError, ValueError): + data = {} + else: + data = {} + return SyncState(index, data.get('leader'), data.get('sync_standby')) + + @property + def members(self): + """ Returns sync_standby in list """ + return self.sync_standby and self.sync_standby.split(',') or [] + + def matches(self, name): + """ + Returns if a node name matches one of the nodes in the sync state + + >>> s = SyncState(1, 'foo', 'bar,zoo') + >>> s.matches('foo') + True + >>> s.matches('bar') + True + >>> s.matches('zoo') + True + >>> s.matches('baz') + False + >>> s.matches(None) + False + >>> SyncState(1, None, None).matches('foo') + False + """ + return name is not None and name in [self.leader] + self.members + + +class TimelineHistory(namedtuple('TimelineHistory', 'index,value,lines')): + """Object representing timeline history file""" + + @staticmethod + def from_node(index, value): + """ + >>> h = TimelineHistory.from_node(1, 2) + >>> h.lines + [] + """ + try: + lines = json.loads(value) + except (TypeError, ValueError): + lines = None + if not isinstance(lines, list): + lines = [] + return TimelineHistory(index, value, lines) + + +class Cluster(namedtuple('Cluster', 'initialize,config,leader,last_leader_operation,members,failover,sync,history')): + + """Immutable object (namedtuple) which represents PostgreSQL cluster. + Consists of the following fields: + :param initialize: shows whether this cluster has initialization key stored in DCS or not. + :param config: global dynamic configuration, reference to `ClusterConfig` object + :param leader: `Leader` object which represents current leader of the cluster + :param last_leader_operation: int or long object containing position of last known leader operation. + This value is stored in `/optime/leader` key + :param members: list of Member object, all PostgreSQL cluster members including leader + :param failover: reference to `Failover` object + :param sync: reference to `SyncState` object, last observed synchronous replication state.
+ :param history: reference to `TimelineHistory` object + """ + + def is_unlocked(self): + return not (self.leader and self.leader.name) + + def has_member(self, member_name): + return any(m for m in self.members if m.name == member_name) + + def get_member(self, member_name, fallback_to_leader=True): + return ([m for m in self.members if m.name == member_name] or [self.leader if fallback_to_leader else None])[0] + + def get_clone_member(self, exclude): + exclude = [exclude] + [self.leader.name] if self.leader else [] + candidates = [m for m in self.members if m.clonefrom and m.is_running and m.name not in exclude] + return candidates[randint(0, len(candidates) - 1)] if candidates else self.leader + + def check_mode(self, mode): + return bool(self.config and parse_bool(self.config.data.get(mode))) + + def is_paused(self): + return self.check_mode('pause') + + def is_synchronous_mode(self): + return self.check_mode('synchronous_mode') + + def get_replication_slots(self, my_name, role): + # if the replicatefrom tag is set on the member - we should not create the replication slot for it on + # the current master, because that member would replicate from elsewhere. We still create the slot if + # the replicatefrom destination member is currently not a member of the cluster (fallback to the + # master), or if replicatefrom destination member happens to be the current master + use_slots = self.config and self.config.data.get('postgresql', {}).get('use_slots', True) + if role in ('master', 'standby_leader'): + slot_members = [m.name for m in self.members if use_slots and m.name != my_name and + (m.replicatefrom is None or m.replicatefrom == my_name or + not self.has_member(m.replicatefrom))] + permanent_slots = (self.config and self.config.permanent_slots or {}).copy() + else: + # only manage slots for replicas that replicate from this one, except for the leader among them + slot_members = [m.name for m in self.members if use_slots and + m.replicatefrom == my_name and m.name != self.leader.name] + permanent_slots = {} + + slots = {slot_name_from_member_name(name): {'type': 'physical'} for name in slot_members} + + if len(slots) < len(slot_members): + # Find which names are conflicting for a nicer error message + slot_conflicts = defaultdict(list) + for name in slot_members: + slot_conflicts[slot_name_from_member_name(name)].append(name) + logger.error("Following cluster members share a replication slot name: %s", + "; ".join("{} map to {}".format(", ".join(v), k) + for k, v in slot_conflicts.items() if len(v) > 1)) + + # "merge" replication slots for members with permanent_replication_slots + for name, value in permanent_slots.items(): + if not slot_name_re.match(name): + logger.error("Invalid permanent replication slot name '%s'", name) + logger.error("Slot name may only contain lower case letters, numbers, and the underscore chars") + continue + + value = deepcopy(value) if value else {'type': 'physical'} + if isinstance(value, dict): + if 'type' not in value: + value['type'] = 'logical' if value.get('database') and value.get('plugin') else 'physical' + + if value['type'] == 'physical': + # Don't try to create permanent physical replication slot for yourself + if name != slot_name_from_member_name(my_name): + slots[name] = value + continue + elif value['type'] == 'logical' and value.get('database') and value.get('plugin'): + if name in slots: + logger.error("Permanent logical replication slot {'%s': %s} is conflicting with" + + " physical replication slot for cluster member", name, value) + 
else: + slots[name] = value + continue + + logger.error("Bad value for slot '%s' in permanent_slots: %s", name, permanent_slots[name]) + + return slots + + def has_permanent_logical_slots(self, name): + slots = self.get_replication_slots(name, 'master').values() + return any(v for v in slots if v.get("type") == "logical") + + @property + def timeline(self): + """ + >>> Cluster(0, 0, 0, 0, 0, 0, 0, 0).timeline + 0 + >>> Cluster(0, 0, 0, 0, 0, 0, 0, TimelineHistory.from_node(1, '[]')).timeline + 1 + >>> Cluster(0, 0, 0, 0, 0, 0, 0, TimelineHistory.from_node(1, '[["a"]]')).timeline + 0 + """ + if self.history: + if self.history.lines: + try: + return int(self.history.lines[-1][0]) + 1 + except Exception: + logger.error('Failed to parse cluster history from DCS: %s', self.history.lines) + elif self.history.value == '[]': + return 1 + return 0 + + +@six.add_metaclass(abc.ABCMeta) +class AbstractDCS(object): + + _INITIALIZE = 'initialize' + _CONFIG = 'config' + _LEADER = 'leader' + _FAILOVER = 'failover' + _HISTORY = 'history' + _MEMBERS = 'members/' + _OPTIME = 'optime' + _LEADER_OPTIME = _OPTIME + '/' + _LEADER + _SYNC = 'sync' + + def __init__(self, config): + """ + :param config: dict, reference to config section of selected DCS. + i.e.: `zookeeper` for zookeeper, `etcd` for etcd, etc... + """ + self._name = config['name'] + self._base_path = re.sub('/+', '/', '/'.join(['', config.get('namespace', 'service'), config['scope']])) + self._set_loop_wait(config.get('loop_wait', 10)) + + self._ctl = bool(config.get('patronictl', False)) + self._cluster = None + self._cluster_valid_till = 0 + self._cluster_thread_lock = Lock() + self._last_leader_operation = '' + self.event = Event() + + def client_path(self, path): + return '/'.join([self._base_path, path.lstrip('/')]) + + @property + def initialize_path(self): + return self.client_path(self._INITIALIZE) + + @property + def config_path(self): + return self.client_path(self._CONFIG) + + @property + def members_path(self): + return self.client_path(self._MEMBERS) + + @property + def member_path(self): + return self.client_path(self._MEMBERS + self._name) + + @property + def leader_path(self): + return self.client_path(self._LEADER) + + @property + def failover_path(self): + return self.client_path(self._FAILOVER) + + @property + def history_path(self): + return self.client_path(self._HISTORY) + + @property + def leader_optime_path(self): + return self.client_path(self._LEADER_OPTIME) + + @property + def sync_path(self): + return self.client_path(self._SYNC) + + @abc.abstractmethod + def set_ttl(self, ttl): + """Set the new ttl value for leader key""" + + @abc.abstractmethod + def ttl(self): + """Get new ttl value""" + + @abc.abstractmethod + def set_retry_timeout(self, retry_timeout): + """Set the new value for retry_timeout""" + + def _set_loop_wait(self, loop_wait): + self._loop_wait = loop_wait + + def reload_config(self, config): + self._set_loop_wait(config['loop_wait']) + self.set_ttl(config['ttl']) + self.set_retry_timeout(config['retry_timeout']) + + @property + def loop_wait(self): + return self._loop_wait + + @abc.abstractmethod + def _load_cluster(self): + """Internally this method should build `Cluster` object which + represents current state and topology of the cluster in DCS. + this method supposed to be called only by `get_cluster` method. + + raise `~DCSError` in case of communication or other problems with DCS. 
+ If the current node was running as a master and exception raised, + instance would be demoted.""" + + def _bypass_caches(self): + """Used only in zookeeper""" + + def get_cluster(self, force=False): + if force: + self._bypass_caches() + try: + cluster = self._load_cluster() + except Exception: + self.reset_cluster() + raise + + with self._cluster_thread_lock: + self._cluster = cluster + self._cluster_valid_till = time.time() + self.ttl + return cluster + + @property + def cluster(self): + with self._cluster_thread_lock: + return self._cluster if self._cluster_valid_till > time.time() else None + + def reset_cluster(self): + with self._cluster_thread_lock: + self._cluster = None + self._cluster_valid_till = 0 + + @abc.abstractmethod + def _write_leader_optime(self, last_operation): + """write current xlog location into `/optime/leader` key in DCS + :param last_operation: absolute xlog location in bytes + :returns: `!True` on success.""" + + def write_leader_optime(self, last_operation): + if self._last_leader_operation != last_operation and self._write_leader_optime(last_operation): + self._last_leader_operation = last_operation + + @abc.abstractmethod + def _update_leader(self): + """Update leader key (or session) ttl + + :returns: `!True` if leader key (or session) has been updated successfully. + If not, `!False` must be returned and current instance would be demoted. + + You have to use CAS (Compare And Swap) operation in order to update leader key, + for example for etcd `prevValue` parameter must be used.""" + + def update_leader(self, last_operation, access_is_restricted=False): + """Update leader key (or session) ttl and optime/leader + + :param last_operation: absolute xlog location in bytes + :returns: `!True` if leader key (or session) has been updated successfully. + If not, `!False` must be returned and current instance would be demoted.""" + + ret = self._update_leader() + if ret and last_operation: + self.write_leader_optime(last_operation) + return ret + + @abc.abstractmethod + def attempt_to_acquire_leader(self, permanent=False): + """Attempt to acquire leader lock + This method should create `/leader` key with value=`~self._name` + :param permanent: if set to `!True`, the leader key will never expire. + Used in patronictl for the external master + :returns: `!True` if key has been created successfully. + + Key must be created atomically. In case if key already exists it should not be + overwritten and `!False` must be returned""" + + @abc.abstractmethod + def set_failover_value(self, value, index=None): + """Create or update `/failover` key""" + + def manual_failover(self, leader, candidate, scheduled_at=None, index=None): + failover_value = {} + if leader: + failover_value['leader'] = leader + + if candidate: + failover_value['member'] = candidate + + if scheduled_at: + failover_value['scheduled_at'] = scheduled_at.isoformat() + return self.set_failover_value(json.dumps(failover_value, separators=(',', ':')), index) + + @abc.abstractmethod + def set_config_value(self, value, index=None): + """Create or update `/config` key""" + + @abc.abstractmethod + def touch_member(self, data, permanent=False): + """Update member key in DCS. + This method should create or update key with the name = '/members/' + `~self._name` + and value = data in a given DCS. + + :param data: information about instance (including connection strings) + :param ttl: ttl for member key, optional parameter. 
If it is None `~self.member_ttl will be used` + :param permanent: if set to `!True`, the member key will never expire. + Used in patronictl for the external master. + :returns: `!True` on success otherwise `!False` + """ + + @abc.abstractmethod + def take_leader(self): + """This method should create leader key with value = `~self._name` and ttl=`~self.ttl` + Since it could be called only on initial cluster bootstrap it could create this key regardless, + overwriting the key if necessary.""" + + @abc.abstractmethod + def initialize(self, create_new=True, sysid=""): + """Race for cluster initialization. + + :param create_new: False if the key should already exist (in the case we are setting the system_id) + :param sysid: PostgreSQL cluster system identifier, if specified, is written to the key + :returns: `!True` if key has been created successfully. + + this method should create atomically initialize key and return `!True` + otherwise it should return `!False`""" + + @abc.abstractmethod + def _delete_leader(self): + """Remove leader key from DCS. + This method should remove leader key if current instance is the leader""" + + def delete_leader(self, last_operation=None): + """Update optime/leader and voluntarily remove leader key from DCS. + This method should remove leader key if current instance is the leader. + :param last_operation: latest checkpoint location in bytes""" + + if last_operation: + self.write_leader_optime(last_operation) + return self._delete_leader() + + @abc.abstractmethod + def cancel_initialization(self): + """ Removes the initialize key for a cluster """ + + @abc.abstractmethod + def delete_cluster(self): + """Delete cluster from DCS""" + + @staticmethod + def sync_state(leader, sync_standby): + """Build sync_state dict + sync_standby dictionary key being kept for backward compatibility + """ + return {'leader': leader, 'sync_standby': sync_standby and ','.join(sorted(sync_standby)) or None} + + def write_sync_state(self, leader, sync_standby, index=None): + sync_value = self.sync_state(leader, sync_standby) + return self.set_sync_state_value(json.dumps(sync_value, separators=(',', ':')), index) + + @abc.abstractmethod + def set_history_value(self, value): + """""" + + @abc.abstractmethod + def set_sync_state_value(self, value, index=None): + """""" + + @abc.abstractmethod + def delete_sync_state(self, index=None): + """""" + + def watch(self, leader_index, timeout): + """If the current node is a master it should just sleep. + Any other node should watch for changes of leader key with a given timeout + + :param leader_index: index of a leader key + :param timeout: timeout in seconds + :returns: `!True` if you would like to reschedule the next run of ha cycle""" + + self.event.wait(timeout) + return self.event.isSet() diff --git a/patroni-for-openGauss/dcs/consul.py b/patroni-for-openGauss/dcs/consul.py new file mode 100644 index 0000000000000000000000000000000000000000..a9acb08b8e99fb52f1281acdfdde2647b16dff71 --- /dev/null +++ b/patroni-for-openGauss/dcs/consul.py @@ -0,0 +1,555 @@ +from __future__ import absolute_import +import json +import logging +import os +import re +import socket +import ssl +import time +import urllib3 + +from collections import namedtuple +from consul import ConsulException, NotFound, base +from urllib3.exceptions import HTTPError +from six.moves.urllib.parse import urlencode, urlparse, quote +from six.moves.http_client import HTTPException + +from . 
import AbstractDCS, Cluster, ClusterConfig, Failover, Leader, Member, SyncState, TimelineHistory +from ..exceptions import DCSError +from ..utils import deep_compare, parse_bool, Retry, RetryFailedError, split_host_port, uri, USER_AGENT + +logger = logging.getLogger(__name__) + + +class ConsulError(DCSError): + pass + + +class ConsulInternalError(ConsulException): + """An internal Consul server error occurred""" + + +class InvalidSessionTTL(ConsulException): + """Session TTL is too small or too big""" + + +class InvalidSession(ConsulException): + """invalid session""" + + +Response = namedtuple('Response', 'code,headers,body,content') + + +class HTTPClient(object): + + def __init__(self, host='127.0.0.1', port=8500, token=None, scheme='http', verify=True, cert=None, ca_cert=None): + self.token = token + self._read_timeout = 10 + self.base_uri = uri(scheme, (host, port)) + kwargs = {} + if cert: + if isinstance(cert, tuple): + # Key and cert are separate + kwargs['cert_file'] = cert[0] + kwargs['key_file'] = cert[1] + else: + # combined certificate + kwargs['cert_file'] = cert + if ca_cert: + kwargs['ca_certs'] = ca_cert + kwargs['cert_reqs'] = ssl.CERT_REQUIRED if verify or ca_cert else ssl.CERT_NONE + self.http = urllib3.PoolManager(num_pools=10, maxsize=10, **kwargs) + self._ttl = None + + def set_read_timeout(self, timeout): + self._read_timeout = timeout/3.0 + + @property + def ttl(self): + return self._ttl + + def set_ttl(self, ttl): + ret = self._ttl != ttl + self._ttl = ttl + return ret + + @staticmethod + def response(response): + content = response.data + body = content.decode('utf-8') + if response.status == 500: + msg = '{0} {1}'.format(response.status, body) + if body.startswith('Invalid Session TTL'): + raise InvalidSessionTTL(msg) + elif body.startswith('invalid session'): + raise InvalidSession(msg) + else: + raise ConsulInternalError(msg) + return Response(response.status, response.headers, body, content) + + def uri(self, path, params=None): + return '{0}{1}{2}'.format(self.base_uri, path, params and '?' + urlencode(params) or '') + + def __getattr__(self, method): + if method not in ('get', 'post', 'put', 'delete'): + raise AttributeError("HTTPClient instance has no attribute '{0}'".format(method)) + + def wrapper(callback, path, params=None, data='', headers=None): + # python-consul doesn't allow specifying a ttl smaller than 10 seconds + # because session_ttl_min defaults to 10s, so we have to do this ugly dirty hack... + if method == 'put' and path == '/v1/session/create': + ttl = '"ttl": "{0}s"'.format(self._ttl) + if not data or data == '{}': + data = '{' + ttl + '}' + else: + data = data[:-1] + ', ' + ttl + '}' + if isinstance(params, list): # starting from v1.1.0 python-consul switched from `dict` to `list` for params + params = {k: v for k, v in params} + kwargs = {'retries': 0, 'preload_content': False, 'body': data} + if method == 'get' and isinstance(params, dict) and 'index' in params: + timeout = float(params['wait'][:-1]) if 'wait' in params else 300 + # According to the documentation a small random amount of additional wait time is added to the + # supplied maximum wait time to spread out the wake up time of any concurrent requests. This adds + # up to wait / 16 additional time to the maximum duration. Since our goal is actually getting a + # response rather than a read timeout, we will add a slightly bigger value to the timeout.
+ kwargs['timeout'] = timeout + max(timeout/15.0, 1) + else: + kwargs['timeout'] = self._read_timeout + kwargs['headers'] = (headers or {}).copy() + kwargs['headers'].update(urllib3.make_headers(user_agent=USER_AGENT)) + token = params.pop('token', self.token) if isinstance(params, dict) else self.token + if token: + kwargs['headers']['X-Consul-Token'] = token + return callback(self.response(self.http.request(method.upper(), self.uri(path, params), **kwargs))) + return wrapper + + +class ConsulClient(base.Consul): + + def __init__(self, *args, **kwargs): + self._cert = kwargs.pop('cert', None) + self._ca_cert = kwargs.pop('ca_cert', None) + self.token = kwargs.get('token') + super(ConsulClient, self).__init__(*args, **kwargs) + + def http_connect(self, *args, **kwargs): + kwargs.update(dict(zip(['host', 'port', 'scheme', 'verify'], args))) + if self._cert: + kwargs['cert'] = self._cert + if self._ca_cert: + kwargs['ca_cert'] = self._ca_cert + if self.token: + kwargs['token'] = self.token + return HTTPClient(**kwargs) + + def connect(self, *args, **kwargs): + return self.http_connect(*args, **kwargs) + + def reload_config(self, config): + self.http.token = self.token = config.get('token') + self.consistency = config.get('consistency', 'default') + self.dc = config.get('dc') + + +def catch_consul_errors(func): + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except (RetryFailedError, ConsulException, HTTPException, HTTPError, socket.error, socket.timeout): + return False + return wrapper + + +def force_if_last_failed(func): + def wrapper(*args, **kwargs): + if wrapper.last_result is False: + kwargs['force'] = True + wrapper.last_result = func(*args, **kwargs) + return wrapper.last_result + + wrapper.last_result = None + return wrapper + + +def service_name_from_scope_name(scope_name): + """Translate scope name to service name which can be used in dns. + + 230 = 253 - len('replica.') - len('.service.consul') + """ + + def replace_char(match): + c = match.group(0) + return '-' if c in '. 
_' else "u{:04d}".format(ord(c)) + + service_name = re.sub(r'[^a-z0-9\-]', replace_char, scope_name.lower()) + return service_name[0:230] + + +class Consul(AbstractDCS): + + def __init__(self, config): + super(Consul, self).__init__(config) + self._scope = config['scope'] + self._session = None + self.__do_not_watch = False + self._retry = Retry(deadline=config['retry_timeout'], max_delay=1, max_tries=-1, + retry_exceptions=(ConsulInternalError, HTTPException, + HTTPError, socket.error, socket.timeout)) + + kwargs = {} + if 'url' in config: + r = urlparse(config['url']) + config.update({'scheme': r.scheme, 'host': r.hostname, 'port': r.port or 8500}) + elif 'host' in config: + host, port = split_host_port(config.get('host', '127.0.0.1:8500'), 8500) + config['host'] = host + if 'port' not in config: + config['port'] = int(port) + + if config.get('cacert'): + config['ca_cert'] = config.pop('cacert') + + if config.get('key') and config.get('cert'): + config['cert'] = (config['cert'], config['key']) + + config_keys = ('host', 'port', 'token', 'scheme', 'cert', 'ca_cert', 'dc', 'consistency') + kwargs = {p: config.get(p) for p in config_keys if config.get(p)} + + verify = config.get('verify') + if not isinstance(verify, bool): + verify = parse_bool(verify) + if isinstance(verify, bool): + kwargs['verify'] = verify + + self._client = ConsulClient(**kwargs) + self.set_retry_timeout(config['retry_timeout']) + self.set_ttl(config.get('ttl') or 30) + self._last_session_refresh = 0 + self.__session_checks = config.get('checks', []) + self._register_service = config.get('register_service', False) + if self._register_service: + self._service_tags = config.get('service_tags', []) + self._service_name = service_name_from_scope_name(self._scope) + if self._scope != self._service_name: + logger.warning('Using %s as consul service name instead of scope name %s', self._service_name, + self._scope) + self._service_check_interval = config.get('service_check_interval', '5s') + if not self._ctl: + self.create_session() + + def retry(self, *args, **kwargs): + return self._retry.copy()(*args, **kwargs) + + def create_session(self): + while not self._session: + try: + self.refresh_session() + except ConsulError: + logger.info('waiting on consul') + time.sleep(5) + + def reload_config(self, config): + super(Consul, self).reload_config(config) + self._client.reload_config(config.get('consul', {})) + + def set_ttl(self, ttl): + if self._client.http.set_ttl(ttl/2.0): # Consul multiplies the TTL by 2x + self._session = None + self.__do_not_watch = True + + @property + def ttl(self): + return self._client.http.ttl + + def set_retry_timeout(self, retry_timeout): + self._retry.deadline = retry_timeout + self._client.http.set_read_timeout(retry_timeout) + + def adjust_ttl(self): + try: + settings = self._client.agent.self() + min_ttl = (settings['Config']['SessionTTLMin'] or 10000000000)/1000000000.0 + logger.warning('Changing Session TTL from %s to %s', self._client.http.ttl, min_ttl) + self._client.http.set_ttl(min_ttl) + except Exception: + logger.exception('adjust_ttl') + + def _do_refresh_session(self): + """:returns: `!True` if it had to create new session""" + if self._session and self._last_session_refresh + self._loop_wait > time.time(): + return False + + if self._session: + try: + self._client.session.renew(self._session) + except NotFound: + self._session = None + ret = not self._session + if ret: + try: + self._session = self._client.session.create(name=self._scope + '-' + self._name, + 
checks=self.__session_checks, + lock_delay=0.001, behavior='delete') + except InvalidSessionTTL: + logger.exception('session.create') + self.adjust_ttl() + raise + + self._last_session_refresh = time.time() + return ret + + def refresh_session(self): + try: + return self.retry(self._do_refresh_session) + except (ConsulException, RetryFailedError): + logger.exception('refresh_session') + raise ConsulError('Failed to renew/create session') + + def client_path(self, path): + return super(Consul, self).client_path(path)[1:] + + @staticmethod + def member(node): + return Member.from_node(node['ModifyIndex'], os.path.basename(node['Key']), node.get('Session'), node['Value']) + + def _load_cluster(self): + try: + path = self.client_path('/') + _, results = self.retry(self._client.kv.get, path, recurse=True) + + if results is None: + raise NotFound + + nodes = {} + for node in results: + node['Value'] = (node['Value'] or b'').decode('utf-8') + nodes[node['Key'][len(path):].lstrip('/')] = node + + # get initialize flag + initialize = nodes.get(self._INITIALIZE) + initialize = initialize and initialize['Value'] + + # get global dynamic configuration + config = nodes.get(self._CONFIG) + config = config and ClusterConfig.from_node(config['ModifyIndex'], config['Value']) + + # get timeline history + history = nodes.get(self._HISTORY) + history = history and TimelineHistory.from_node(history['ModifyIndex'], history['Value']) + + # get last leader operation + last_leader_operation = nodes.get(self._LEADER_OPTIME) + last_leader_operation = 0 if last_leader_operation is None else int(last_leader_operation['Value']) + + # get list of members + members = [self.member(n) for k, n in nodes.items() if k.startswith(self._MEMBERS) and k.count('/') == 1] + + # get leader + leader = nodes.get(self._LEADER) + if not self._ctl and leader and leader['Value'] == self._name \ + and self._session != leader.get('Session', 'x'): + logger.info('I am leader but not owner of the session. 
Removing leader node') + self._client.kv.delete(self.leader_path, cas=leader['ModifyIndex']) + leader = None + + if leader: + member = Member(-1, leader['Value'], None, {}) + member = ([m for m in members if m.name == leader['Value']] or [member])[0] + leader = Leader(leader['ModifyIndex'], leader.get('Session'), member) + + # failover key + failover = nodes.get(self._FAILOVER) + if failover: + failover = Failover.from_node(failover['ModifyIndex'], failover['Value']) + + # get synchronization state + sync = nodes.get(self._SYNC) + sync = SyncState.from_node(sync and sync['ModifyIndex'], sync and sync['Value']) + + return Cluster(initialize, config, leader, last_leader_operation, members, failover, sync, history) + except NotFound: + return Cluster(None, None, None, None, [], None, None, None) + except Exception: + logger.exception('get_cluster') + raise ConsulError('Consul is not responding properly') + + @catch_consul_errors + def touch_member(self, data, permanent=False): + cluster = self.cluster + member = cluster and cluster.get_member(self._name, fallback_to_leader=False) + + try: + create_member = not permanent and self.refresh_session() + except DCSError: + return False + + if member and (create_member or member.session != self._session): + self._client.kv.delete(self.member_path) + create_member = True + + if not create_member and member and deep_compare(data, member.data): + return True + + try: + args = {} if permanent else {'acquire': self._session} + self._client.kv.put(self.member_path, json.dumps(data, separators=(',', ':')), **args) + if self._register_service: + self.update_service(not create_member and member and member.data or {}, data) + return True + except InvalidSession: + self._session = None + logger.error('Our session disappeared from Consul, can not "touch_member"') + except Exception: + logger.exception('touch_member') + return False + + @catch_consul_errors + def register_service(self, service_name, **kwargs): + logger.info('Register service %s, params %s', service_name, kwargs) + return self._client.agent.service.register(service_name, **kwargs) + + @catch_consul_errors + def deregister_service(self, service_id): + logger.info('Deregister service %s', service_id) + # service_id can contain special characters, but is used as part of uri in deregister request + service_id = quote(service_id) + return self._client.agent.service.deregister(service_id) + + def _update_service(self, data): + service_name = self._service_name + role = data['role'].replace('_', '-') + state = data['state'] + api_parts = urlparse(data['api_url']) + api_parts = api_parts._replace(path='/{0}'.format(role)) + conn_parts = urlparse(data['conn_url']) + check = base.Check.http(api_parts.geturl(), self._service_check_interval, + deregister='{0}s'.format(self._client.http.ttl * 10)) + tags = self._service_tags[:] + tags.append(role) + params = { + 'service_id': '{0}/{1}'.format(self._scope, self._name), + 'address': conn_parts.hostname, + 'port': conn_parts.port, + 'check': check, + 'tags': tags + } + + if state == 'stopped': + return self.deregister_service(params['service_id']) + + if role in ['master', 'replica', 'standby-leader']: + if state != 'running': + return + return self.register_service(service_name, **params) + + logger.warning('Could not register service: unknown role type %s', role) + + @force_if_last_failed + def update_service(self, old_data, new_data, force=False): + update = False + + for key in ['role', 'api_url', 'conn_url', 'state']: + if key not in new_data: + 
logger.warning('Could not register service: not enough params in member data') + return + if old_data.get(key) != new_data[key]: + update = True + + if force or update: + return self._update_service(new_data) + + @catch_consul_errors + def _do_attempt_to_acquire_leader(self, permanent): + try: + kwargs = {} if permanent else {'acquire': self._session} + return self.retry(self._client.kv.put, self.leader_path, self._name, **kwargs) + except InvalidSession: + self._session = None + logger.error('Our session disappeared from Consul. Will try to get a new one and retry attempt') + self.refresh_session() + return self.retry(self._client.kv.put, self.leader_path, self._name, acquire=self._session) + + def attempt_to_acquire_leader(self, permanent=False): + if not self._session and not permanent: + self.refresh_session() + + ret = self._do_attempt_to_acquire_leader(permanent) + if not ret: + logger.info('Could not take out TTL lock') + + return ret + + def take_leader(self): + return self.attempt_to_acquire_leader() + + @catch_consul_errors + def set_failover_value(self, value, index=None): + return self._client.kv.put(self.failover_path, value, cas=index) + + @catch_consul_errors + def set_config_value(self, value, index=None): + return self._client.kv.put(self.config_path, value, cas=index) + + @catch_consul_errors + def _write_leader_optime(self, last_operation): + return self._client.kv.put(self.leader_optime_path, last_operation) + + @catch_consul_errors + def _update_leader(self): + if self._session: + self.retry(self._client.session.renew, self._session) + self._last_session_refresh = time.time() + return bool(self._session) + + @catch_consul_errors + def initialize(self, create_new=True, sysid=''): + kwargs = {'cas': 0} if create_new else {} + return self.retry(self._client.kv.put, self.initialize_path, sysid, **kwargs) + + @catch_consul_errors + def cancel_initialization(self): + return self.retry(self._client.kv.delete, self.initialize_path) + + @catch_consul_errors + def delete_cluster(self): + return self.retry(self._client.kv.delete, self.client_path(''), recurse=True) + + @catch_consul_errors + def set_history_value(self, value): + return self._client.kv.put(self.history_path, value) + + @catch_consul_errors + def _delete_leader(self): + cluster = self.cluster + if cluster and isinstance(cluster.leader, Leader) and cluster.leader.name == self._name: + return self._client.kv.delete(self.leader_path, cas=cluster.leader.index) + + @catch_consul_errors + def set_sync_state_value(self, value, index=None): + return self.retry(self._client.kv.put, self.sync_path, value, cas=index) + + @catch_consul_errors + def delete_sync_state(self, index=None): + return self.retry(self._client.kv.delete, self.sync_path, cas=index) + + def watch(self, leader_index, timeout): + self._last_session_refresh = 0 + if self.__do_not_watch: + self.__do_not_watch = False + return True + + if leader_index: + end_time = time.time() + timeout + while timeout >= 1: + try: + idx, _ = self._client.kv.get(self.leader_path, index=leader_index, wait=str(timeout) + 's') + return str(idx) != str(leader_index) + except (ConsulException, HTTPException, HTTPError, socket.error, socket.timeout): + logger.exception('watch') + + timeout = end_time - time.time() + + try: + return super(Consul, self).watch(None, timeout) + finally: + self.event.clear() diff --git a/patroni-for-openGauss/dcs/etcd.py b/patroni-for-openGauss/dcs/etcd.py new file mode 100644 index 
0000000000000000000000000000000000000000..b6e20ba0959b3390ca1887f2ff1f1b23e56f24eb --- /dev/null +++ b/patroni-for-openGauss/dcs/etcd.py @@ -0,0 +1,737 @@ +from __future__ import absolute_import +import abc +import etcd +import json +import logging +import os +import urllib3.util.connection +import random +import six +import socket +import time + +from dns.exception import DNSException +from dns import resolver +from urllib3 import Timeout +from urllib3.exceptions import HTTPError, ReadTimeoutError, ProtocolError +from six.moves.queue import Queue +from six.moves.http_client import HTTPException +from six.moves.urllib_parse import urlparse +from threading import Thread + +from . import AbstractDCS, Cluster, ClusterConfig, Failover, Leader, Member, SyncState, TimelineHistory +from ..exceptions import DCSError +from ..request import get as requests_get +from ..utils import Retry, RetryFailedError, split_host_port, uri, USER_AGENT + +logger = logging.getLogger(__name__) + + +class EtcdRaftInternal(etcd.EtcdException): + """Raft Internal Error""" + + +class EtcdError(DCSError): + pass + + +class DnsCachingResolver(Thread): + + def __init__(self, cache_time=600.0, cache_fail_time=30.0): + super(DnsCachingResolver, self).__init__() + self._cache = {} + self._cache_time = cache_time + self._cache_fail_time = cache_fail_time + self._resolve_queue = Queue() + self.daemon = True + self.start() + + def run(self): + while True: + (host, port), attempt = self._resolve_queue.get() + response = self._do_resolve(host, port) + if response: + self._cache[(host, port)] = (time.time(), response) + else: + if attempt < 10: + self.resolve_async(host, port, attempt + 1) + time.sleep(1) + + def resolve(self, host, port): + current_time = time.time() + cached_time, response = self._cache.get((host, port), (0, [])) + time_passed = current_time - cached_time + if time_passed > self._cache_time or (not response and time_passed > self._cache_fail_time): + new_response = self._do_resolve(host, port) + if new_response: + self._cache[(host, port)] = (current_time, new_response) + response = new_response + return response + + def resolve_async(self, host, port, attempt=0): + self._resolve_queue.put(((host, port), attempt)) + + def remove(self, host, port): + self._cache.pop((host, port), None) + + @staticmethod + def _do_resolve(host, port): + try: + return socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP) + except Exception as e: + logger.warning('failed to resolve host %s: %s', host, e) + return [] + + +@six.add_metaclass(abc.ABCMeta) +class AbstractEtcdClientWithFailover(etcd.Client): + + def __init__(self, config, dns_resolver, cache_ttl=300): + self._dns_resolver = dns_resolver + self.set_machines_cache_ttl(cache_ttl) + self._machines_cache_updated = 0 + args = {p: config.get(p) for p in ('host', 'port', 'protocol', 'use_proxies', 'username', 'password', + 'cert', 'ca_cert') if config.get(p)} + super(AbstractEtcdClientWithFailover, self).__init__(read_timeout=config['retry_timeout'], **args) + # For some reason python3-etcd on debian and ubuntu are not based on the latest version + # Workaround for the case when https://github.com/jplana/python-etcd/pull/196 is not applied + self.http.connection_pool_kw.pop('ssl_version', None) + self._config = config + self._initial_machines_cache = [] + self._load_machines_cache() + self._allow_reconnect = True + # allow passing retry argument to api_execute in params + self._comparison_conditions.add('retry') + self._read_options.add('retry') + 
self._del_conditions.add('retry') + + def _calculate_timeouts(self, etcd_nodes, timeout=None): + """Calculate a request timeout and number of retries per single etcd node. + In case if the timeout per node is too small (less than one second) we will reduce the number of nodes. + For the cluster with only one node we will try to do 2 retries. + For clusters with 2 nodes we will try to do 1 retry for every node. + No retries for clusters with 3 or more nodes. We better rely on switching to a different node.""" + + per_node_timeout = timeout = float(timeout or self.read_timeout) + + max_retries = 4 - min(etcd_nodes, 3) + per_node_retries = 1 + min_timeout = 1.0 + + while etcd_nodes > 0: + per_node_timeout = float(timeout) / etcd_nodes + if per_node_timeout >= min_timeout: + # for small clusters we will try to do more than on try on every node + while per_node_retries < max_retries and per_node_timeout / (per_node_retries + 1) >= min_timeout: + per_node_retries += 1 + per_node_timeout /= per_node_retries + break + # if the timeout per one node is to small try to reduce number of nodes + etcd_nodes -= 1 + max_retries = 1 + + return etcd_nodes, per_node_timeout, per_node_retries - 1 + + def reload_config(self, config): + self.username = config.get('username') + self.password = config.get('password') + + def _get_headers(self): + basic_auth = ':'.join((self.username, self.password)) if self.username and self.password else None + return urllib3.make_headers(basic_auth=basic_auth, user_agent=USER_AGENT) + + def _prepare_common_parameters(self, etcd_nodes, timeout=None): + kwargs = {'headers': self._get_headers(), 'redirect': self.allow_redirect, 'preload_content': False} + + if timeout is not None: + kwargs.update(retries=0, timeout=timeout) + else: + _, per_node_timeout, per_node_retries = self._calculate_timeouts(etcd_nodes) + connect_timeout = max(1, per_node_timeout/2) + kwargs.update(timeout=Timeout(connect=connect_timeout, total=per_node_timeout), retries=per_node_retries) + return kwargs + + def set_machines_cache_ttl(self, cache_ttl): + self._machines_cache_ttl = cache_ttl + + @abc.abstractmethod + def _prepare_get_members(self, etcd_nodes): + """returns: request parameters""" + + @abc.abstractmethod + def _get_members(self, base_uri, **kwargs): + """returns: list of clientURLs""" + + @property + def machines_cache(self): + base_uri, cache = self._base_uri, self._machines_cache + return ([base_uri] if base_uri in cache else []) + [machine for machine in cache if machine != base_uri] + + @property + def machines(self): + """Original `machines` method(property) of `etcd.Client` class raise exception + when it failed to get list of etcd cluster members. This method is being called + only when request failed on one of the etcd members during `api_execute` call. + For us it's more important to execute original request rather then get new topology + of etcd cluster. So we will catch this exception and return empty list of machines. + Later, during next `api_execute` call we will forcefully update machines_cache. 
+ + Also this method implements the same timeout-retry logic as `api_execute`, because + the original method was retrying 2 times with the `read_timeout` on each node.""" + + machines_cache = self.machines_cache + kwargs = self._prepare_get_members(len(machines_cache)) + + for base_uri in machines_cache: + try: + machines = list(self._get_members(base_uri, **kwargs)) + logger.debug("Retrieved list of machines: %s", machines) + if machines: + random.shuffle(machines) + self._update_dns_cache(self._dns_resolver.resolve_async, machines) + return machines + except Exception as e: + self.http.clear() + logger.error("Failed to get list of machines from %s%s: %r", base_uri, self.version_prefix, e) + + raise etcd.EtcdConnectionFailed('No more machines in the cluster') + + def set_read_timeout(self, timeout): + self._read_timeout = timeout + + def _do_http_request(self, retry, machines_cache, request_executor, method, path, fields=None, **kwargs): + if fields is not None: + kwargs['fields'] = fields + some_request_failed = False + for i, base_uri in enumerate(machines_cache): + if i > 0: + logger.info("Retrying on %s", base_uri) + try: + response = request_executor(method, base_uri + path, **kwargs) + response.data.decode('utf-8') + if some_request_failed: + self.set_base_uri(base_uri) + self._refresh_machines_cache() + return response + except (HTTPError, HTTPException, socket.error, socket.timeout) as e: + self.http.clear() + # switch to the next etcd node because we don't know exactly what happened, + # whether the key didn't received an update or there is a network problem. + if not retry and i + 1 < len(machines_cache): + self.set_base_uri(machines_cache[i + 1]) + if (isinstance(fields, dict) and fields.get("wait") == "true" and + isinstance(e, (ReadTimeoutError, ProtocolError))): + logger.debug("Watch timed out.") + raise etcd.EtcdWatchTimedOut("Watch timed out: {0}".format(e), cause=e) + logger.error("Request to server %s failed: %r", base_uri, e) + logger.info("Reconnection allowed, looking for another server.") + if not retry: + raise etcd.EtcdException('{0} {1} request failed'.format(method, path)) + some_request_failed = True + + raise etcd.EtcdConnectionFailed('No more machines in the cluster') + + @abc.abstractmethod + def _prepare_request(self, kwargs, params=None, method=None): + """returns: request_executor""" + + def api_execute(self, path, method, params=None, timeout=None): + retry = params.pop('retry', None) if isinstance(params, dict) else None + + # Update machines_cache if previous attempt of update has failed + if self._update_machines_cache: + self._load_machines_cache() + elif not self._use_proxies and time.time() - self._machines_cache_updated > self._machines_cache_ttl: + self._refresh_machines_cache() + + machines_cache = self.machines_cache + etcd_nodes = len(machines_cache) + + kwargs = self._prepare_common_parameters(etcd_nodes, timeout) + request_executor = self._prepare_request(kwargs, params, method) + + while True: + try: + response = self._do_http_request(retry, machines_cache, request_executor, method, path, **kwargs) + return self._handle_server_response(response) + except etcd.EtcdWatchTimedOut: + raise + except etcd.EtcdConnectionFailed as ex: + try: + if self._load_machines_cache(): + machines_cache = self.machines_cache + etcd_nodes = len(machines_cache) + except Exception as e: + logger.debug('Failed to update list of etcd nodes: %r', e) + sleeptime = retry.sleeptime + remaining_time = retry.stoptime - sleeptime - time.time() + nodes, timeout, retries = 
self._calculate_timeouts(etcd_nodes, remaining_time) + if nodes == 0: + self._update_machines_cache = True + raise ex + retry.sleep_func(sleeptime) + retry.update_delay() + # We still have some time left. Partially reduce `machines_cache` and retry request + kwargs.update(timeout=Timeout(connect=max(1, timeout/2), total=timeout), retries=retries) + machines_cache = machines_cache[:nodes] + + @staticmethod + def get_srv_record(host): + try: + return [(r.target.to_text(True), r.port) for r in resolver.query(host, 'SRV')] + except DNSException: + return [] + + def _get_machines_cache_from_srv(self, srv): + """Fetch list of etcd-cluster member by resolving _etcd-server._tcp. SRV record. + This record should contain list of host and peer ports which could be used to run + 'GET http://{host}:{port}/members' request (peer protocol)""" + + ret = [] + for r in ['-client-ssl', '-client', '-ssl', '', '-server-ssl', '-server']: + protocol = 'https' if '-ssl' in r else 'http' + endpoint = '/members' if '-server' in r else '' + for host, port in self.get_srv_record('_etcd{0}._tcp.{1}'.format(r, srv)): + url = uri(protocol, (host, port), endpoint) + if endpoint: + try: + response = requests_get(url, timeout=self.read_timeout, verify=False) + if response.status < 400: + for member in json.loads(response.data.decode('utf-8')): + ret.extend(member['clientURLs']) + break + except Exception: + logger.exception('GET %s', url) + else: + ret.append(url) + if ret: + self._protocol = protocol + break + else: + logger.warning('Can not resolve SRV for %s', srv) + return list(set(ret)) + + def _get_machines_cache_from_dns(self, host, port): + """One host might be resolved into multiple ip addresses. We will make list out of it""" + if self.protocol == 'http': + ret = map(lambda res: uri(self.protocol, res[-1][:2]), self._dns_resolver.resolve(host, port)) + if ret: + return list(set(ret)) + return [uri(self.protocol, (host, port))] + + def _get_machines_cache_from_config(self): + if 'proxy' in self._config: + return [uri(self.protocol, (self._config['host'], self._config['port']))] + + machines_cache = [] + if 'srv' in self._config: + machines_cache = self._get_machines_cache_from_srv(self._config['srv']) + + if not machines_cache and 'hosts' in self._config: + machines_cache = list(self._config['hosts']) + + if not machines_cache and 'host' in self._config: + machines_cache = self._get_machines_cache_from_dns(self._config['host'], self._config['port']) + return machines_cache + + @staticmethod + def _update_dns_cache(func, machines): + for url in machines: + r = urlparse(url) + port = r.port or (443 if r.scheme == 'https' else 80) + func(r.hostname, port) + + def _load_machines_cache(self): + """This method should fill up `_machines_cache` from scratch. + It could happen only in two cases: + 1. During class initialization + 2. 
When all etcd members failed""" + + self._update_machines_cache = True + + if 'srv' not in self._config and 'host' not in self._config and 'hosts' not in self._config: + raise Exception('Neither srv, hosts, host nor url are defined in etcd section of config') + + machines_cache = self._get_machines_cache_from_config() + # Can not bootstrap list of etcd-cluster members, giving up + if not machines_cache: + raise etcd.EtcdException + + # enforce resolving dns name,they might get new ips + self._update_dns_cache(self._dns_resolver.remove, machines_cache) + + # The etcd cluster could change its topology over time and depending on how we resolve the initial + # topology (list of hosts in the Patroni config or DNS records, A or SRV) we might get into the situation + # the the real topology doesn't match anymore with the topology resolved from the configuration file. + # In case if the "initial" topology is the same as before we will not override the `_machines_cache`. + ret = set(machines_cache) != set(self._initial_machines_cache) + if ret: + self._initial_machines_cache = self._machines_cache = machines_cache + + # After filling up the initial list of machines_cache we should ask etcd-cluster about actual list + self._refresh_machines_cache(True) + + self._update_machines_cache = False + return ret + + def _refresh_machines_cache(self, updating_cache=False): + if self._use_proxies: + self._machines_cache = self._get_machines_cache_from_config() + else: + try: + self._machines_cache = self.machines + except etcd.EtcdConnectionFailed: + if updating_cache: + raise etcd.EtcdException("Could not get the list of servers, " + "maybe you provided the wrong " + "host(s) to connect to?") + return + + if self._base_uri not in self._machines_cache: + self.set_base_uri(self._machines_cache[0]) + self._machines_cache_updated = time.time() + + def set_base_uri(self, value): + logger.info('Selected new etcd server %s', value) + self._base_uri = value + + +class EtcdClient(AbstractEtcdClientWithFailover): + + ERROR_CLS = EtcdError + + def __del__(self): + if self.http is not None: + try: + self.http.clear() + except (ReferenceError, TypeError, AttributeError): + pass + + def _prepare_get_members(self, etcd_nodes): + return self._prepare_common_parameters(etcd_nodes) + + def _get_members(self, base_uri, **kwargs): + response = self.http.request(self._MGET, base_uri + self.version_prefix + '/machines', **kwargs) + data = self._handle_server_response(response).data.decode('utf-8') + return [m.strip() for m in data.split(',') if m.strip()] + + def _prepare_request(self, kwargs, params=None, method=None): + kwargs['fields'] = params + if method in (self._MPOST, self._MPUT): + kwargs['encode_multipart'] = False + return self.http.request + + +class AbstractEtcd(AbstractDCS): + + def __init__(self, config, client_cls, retry_errors_cls): + super(AbstractEtcd, self).__init__(config) + self._retry = Retry(deadline=config['retry_timeout'], max_delay=1, max_tries=-1, + retry_exceptions=retry_errors_cls) + self._ttl = int(config.get('ttl') or 30) + self._client = self.get_etcd_client(config, client_cls) + self.__do_not_watch = False + self._has_failed = False + + def reload_config(self, config): + super(AbstractEtcd, self).reload_config(config) + self._client.reload_config(config.get(self.__class__.__name__.lower(), {})) + + def retry(self, *args, **kwargs): + retry = self._retry.copy() + kwargs['retry'] = retry + return retry(*args, **kwargs) + + def _handle_exception(self, e, name='', do_sleep=False, raise_ex=None): + if 
not self._has_failed: + logger.exception(name) + else: + logger.error(e) + if do_sleep: + time.sleep(1) + self._has_failed = True + if isinstance(raise_ex, Exception): + raise raise_ex + + @staticmethod + def set_socket_options(sock, socket_options): + if socket_options: + for opt in socket_options: + sock.setsockopt(*opt) + + def get_etcd_client(self, config, client_cls): + if 'proxy' in config: + config['use_proxies'] = True + config['url'] = config['proxy'] + + if 'url' in config: + r = urlparse(config['url']) + config.update({'protocol': r.scheme, 'host': r.hostname, 'port': r.port or 2379, + 'username': r.username, 'password': r.password}) + elif 'hosts' in config: + hosts = config.pop('hosts') + default_port = config.pop('port', 2379) + protocol = config.get('protocol', 'http') + + if isinstance(hosts, six.string_types): + hosts = hosts.split(',') + + config['hosts'] = [] + for value in hosts: + if isinstance(value, six.string_types): + config['hosts'].append(uri(protocol, split_host_port(value.strip(), default_port))) + elif 'host' in config: + host, port = split_host_port(config['host'], 2379) + config['host'] = host + if 'port' not in config: + config['port'] = int(port) + + if config.get('cacert'): + config['ca_cert'] = config.pop('cacert') + + if config.get('key') and config.get('cert'): + config['cert'] = (config['cert'], config['key']) + + for p in ('discovery_srv', 'srv_domain'): + if p in config: + config['srv'] = config.pop(p) + + dns_resolver = DnsCachingResolver() + + def create_connection_patched(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, + source_address=None, socket_options=None): + host, port = address + if host.startswith('['): + host = host.strip('[]') + err = None + for af, socktype, proto, _, sa in dns_resolver.resolve(host, port): + sock = None + try: + sock = socket.socket(af, socktype, proto) + self.set_socket_options(sock, socket_options) + if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: + sock.settimeout(timeout) + if source_address: + sock.bind(source_address) + sock.connect(sa) + return sock + + except socket.error as e: + err = e + if sock is not None: + sock.close() + sock = None + + if err is not None: + raise err + + raise socket.error("getaddrinfo returns an empty list") + + urllib3.util.connection.create_connection = create_connection_patched + + client = None + while not client: + try: + client = client_cls(config, dns_resolver) + if 'use_proxies' in config and not client.machines: + raise etcd.EtcdException + except etcd.EtcdException: + logger.info('waiting on etcd') + time.sleep(5) + return client + + def set_ttl(self, ttl): + ttl = int(ttl) + ret = self._ttl != ttl + self._ttl = ttl + self._client.set_machines_cache_ttl(ttl*10) + return ret + + @property + def ttl(self): + return self._ttl + + def set_retry_timeout(self, retry_timeout): + self._retry.deadline = retry_timeout + self._client.set_read_timeout(retry_timeout) + + +def catch_etcd_errors(func): + def wrapper(self, *args, **kwargs): + try: + retval = func(self, *args, **kwargs) is not None + self._has_failed = False + return retval + except (RetryFailedError, etcd.EtcdException) as e: + self._handle_exception(e) + return False + except Exception as e: + self._handle_exception(e, raise_ex=self._client.ERROR_CLS('unexpected error')) + + return wrapper + + +class Etcd(AbstractEtcd): + + def __init__(self, config): + super(Etcd, self).__init__(config, EtcdClient, (etcd.EtcdLeaderElectionInProgress, EtcdRaftInternal)) + self.__do_not_watch = False + + def set_ttl(self, ttl): + 
self.__do_not_watch = super(Etcd, self).set_ttl(ttl) + + @staticmethod + def member(node): + return Member.from_node(node.modifiedIndex, os.path.basename(node.key), node.ttl, node.value) + + def _load_cluster(self): + cluster = None + try: + result = self.retry(self._client.read, self.client_path(''), recursive=True) + nodes = {node.key[len(result.key):].lstrip('/'): node for node in result.leaves} + + # get initialize flag + initialize = nodes.get(self._INITIALIZE) + initialize = initialize and initialize.value + + # get global dynamic configuration + config = nodes.get(self._CONFIG) + config = config and ClusterConfig.from_node(config.modifiedIndex, config.value) + + # get timeline history + history = nodes.get(self._HISTORY) + history = history and TimelineHistory.from_node(history.modifiedIndex, history.value) + + # get last leader operation + last_leader_operation = nodes.get(self._LEADER_OPTIME) + last_leader_operation = 0 if last_leader_operation is None else int(last_leader_operation.value) + + # get list of members + members = [self.member(n) for k, n in nodes.items() if k.startswith(self._MEMBERS) and k.count('/') == 1] + + # get leader + leader = nodes.get(self._LEADER) + if leader: + member = Member(-1, leader.value, None, {}) + member = ([m for m in members if m.name == leader.value] or [member])[0] + index = result.etcd_index if result.etcd_index > leader.modifiedIndex else leader.modifiedIndex + 1 + leader = Leader(index, leader.ttl, member) + + # failover key + failover = nodes.get(self._FAILOVER) + if failover: + failover = Failover.from_node(failover.modifiedIndex, failover.value) + + # get synchronization state + sync = nodes.get(self._SYNC) + sync = SyncState.from_node(sync and sync.modifiedIndex, sync and sync.value) + + cluster = Cluster(initialize, config, leader, last_leader_operation, members, failover, sync, history) + except etcd.EtcdKeyNotFound: + cluster = Cluster(None, None, None, None, [], None, None, None) + except Exception as e: + self._handle_exception(e, 'get_cluster', raise_ex=EtcdError('Etcd is not responding properly')) + self._has_failed = False + return cluster + + @catch_etcd_errors + def touch_member(self, data, permanent=False): + data = json.dumps(data, separators=(',', ':')) + return self._client.set(self.member_path, data, None if permanent else self._ttl) + + @catch_etcd_errors + def take_leader(self): + return self.retry(self._client.write, self.leader_path, self._name, ttl=self._ttl) + + def attempt_to_acquire_leader(self, permanent=False): + try: + return bool(self.retry(self._client.write, + self.leader_path, + self._name, + ttl=None if permanent else self._ttl, + prevExist=False)) + except etcd.EtcdAlreadyExist: + logger.info('Could not take out TTL lock') + except (RetryFailedError, etcd.EtcdException): + pass + return False + + @catch_etcd_errors + def set_failover_value(self, value, index=None): + return self._client.write(self.failover_path, value, prevIndex=index or 0) + + @catch_etcd_errors + def set_config_value(self, value, index=None): + return self._client.write(self.config_path, value, prevIndex=index or 0) + + @catch_etcd_errors + def _write_leader_optime(self, last_operation): + return self._client.set(self.leader_optime_path, last_operation) + + @catch_etcd_errors + def _update_leader(self): + return self.retry(self._client.write, self.leader_path, self._name, prevValue=self._name, ttl=self._ttl) + + @catch_etcd_errors + def initialize(self, create_new=True, sysid=""): + return self.retry(self._client.write, 
self.initialize_path, sysid, prevExist=(not create_new)) + + @catch_etcd_errors + def _delete_leader(self): + return self._client.delete(self.leader_path, prevValue=self._name) + + @catch_etcd_errors + def cancel_initialization(self): + return self.retry(self._client.delete, self.initialize_path) + + @catch_etcd_errors + def delete_cluster(self): + return self.retry(self._client.delete, self.client_path(''), recursive=True) + + @catch_etcd_errors + def set_history_value(self, value): + return self._client.write(self.history_path, value) + + @catch_etcd_errors + def set_sync_state_value(self, value, index=None): + return self.retry(self._client.write, self.sync_path, value, prevIndex=index or 0) + + @catch_etcd_errors + def delete_sync_state(self, index=None): + return self.retry(self._client.delete, self.sync_path, prevIndex=index or 0) + + def watch(self, leader_index, timeout): + if self.__do_not_watch: + self.__do_not_watch = False + return True + + if leader_index: + end_time = time.time() + timeout + + while timeout >= 1: # when timeout is too small urllib3 doesn't have enough time to connect + try: + result = self._client.watch(self.leader_path, index=leader_index, timeout=timeout + 0.5) + self._has_failed = False + if result.action == 'compareAndSwap': + time.sleep(0.01) + # Synchronous work of all cluster members with etcd is less expensive + # than reestablishing http connection every time from every replica. + return True + except etcd.EtcdWatchTimedOut: + self._has_failed = False + return False + except (etcd.EtcdEventIndexCleared, etcd.EtcdWatcherCleared): # Watch failed + self._has_failed = False + return True # leave the loop, because watch with the same parameters will fail anyway + except etcd.EtcdException as e: + self._handle_exception(e, 'watch', True) + + timeout = end_time - time.time() + + try: + return super(Etcd, self).watch(None, timeout) + finally: + self.event.clear() + + +etcd.EtcdError.error_exceptions[300] = EtcdRaftInternal diff --git a/patroni-for-openGauss/dcs/etcd3.py b/patroni-for-openGauss/dcs/etcd3.py new file mode 100644 index 0000000000000000000000000000000000000000..2d9c0656396e6021cbe7d47b71c5f0aca9169e4f --- /dev/null +++ b/patroni-for-openGauss/dcs/etcd3.py @@ -0,0 +1,799 @@ +from __future__ import absolute_import +import base64 +import etcd +import json +import logging +import os +import six +import socket +import sys +import time +import urllib3 + +from threading import Condition, Lock, Thread + +from . 
import ClusterConfig, Cluster, Failover, Leader, Member, SyncState, TimelineHistory +from .etcd import AbstractEtcdClientWithFailover, AbstractEtcd, catch_etcd_errors +from ..exceptions import DCSError, PatroniException +from ..utils import deep_compare, enable_keepalive, iter_response_objects, RetryFailedError, USER_AGENT + +logger = logging.getLogger(__name__) + + +class Etcd3Error(DCSError): + pass + + +class UnsupportedEtcdVersion(PatroniException): + pass + + +# google.golang.org/grpc/codes +GRPCCode = type('Enum', (), {'OK': 0, 'Canceled': 1, 'Unknown': 2, 'InvalidArgument': 3, 'DeadlineExceeded': 4, + 'NotFound': 5, 'AlreadyExists': 6, 'PermissionDenied': 7, 'ResourceExhausted': 8, + 'FailedPrecondition': 9, 'Aborted': 10, 'OutOfRange': 11, 'Unimplemented': 12, + 'Internal': 13, 'Unavailable': 14, 'DataLoss': 15, 'Unauthenticated': 16}) +GRPCcodeToText = {v: k for k, v in GRPCCode.__dict__.items() if not k.startswith('__') and isinstance(v, int)} + + +class Etcd3Exception(etcd.EtcdException): + pass + + +class Etcd3ClientError(Etcd3Exception): + + def __init__(self, code=None, error=None, status=None): + if not hasattr(self, 'error'): + self.error = error and error.strip() + self.codeText = GRPCcodeToText.get(code) + self.status = status + + def __repr__(self): + return "<{0} error: '{1}', code: {2}>".format(self.__class__.__name__, self.error, self.code) + + __str__ = __repr__ + + def as_dict(self): + return {'error': self.error, 'code': self.code, 'codeText': self.codeText, 'status': self.status} + + @classmethod + def get_subclasses(cls): + for subclass in cls.__subclasses__(): + for subsubclass in subclass.get_subclasses(): + yield subsubclass + yield subclass + + +class Unknown(Etcd3ClientError): + code = GRPCCode.Unknown + + +class InvalidArgument(Etcd3ClientError): + code = GRPCCode.InvalidArgument + + +class DeadlineExceeded(Etcd3ClientError): + code = GRPCCode.DeadlineExceeded + error = "context deadline exceeded" + + +class NotFound(Etcd3ClientError): + code = GRPCCode.NotFound + + +class FailedPrecondition(Etcd3ClientError): + code = GRPCCode.FailedPrecondition + + +class Unavailable(Etcd3ClientError): + code = GRPCCode.Unavailable + + +# https://github.com/etcd-io/etcd/blob/master/etcdserver/api/v3rpc/rpctypes/error.go +class LeaseNotFound(NotFound): + error = "etcdserver: requested lease not found" + + +class UserEmpty(InvalidArgument): + error = "etcdserver: user name is empty" + + +class AuthFailed(InvalidArgument): + error = "etcdserver: authentication failed, invalid user ID or password" + + +class PermissionDenied(Etcd3ClientError): + code = GRPCCode.PermissionDenied + error = "etcdserver: permission denied" + + +class AuthNotEnabled(FailedPrecondition): + error = "etcdserver: authentication is not enabled" + + +class InvalidAuthToken(Etcd3ClientError): + code = GRPCCode.Unauthenticated + error = "etcdserver: invalid auth token" + + +errStringToClientError = {s.error: s for s in Etcd3ClientError.get_subclasses() if hasattr(s, 'error')} +errCodeToClientError = {s.code: s for s in Etcd3ClientError.__subclasses__()} + + +def _raise_for_data(data, status_code=None): + try: + error = data.get('error') or data.get('Error') + if isinstance(error, dict): # streaming response + status_code = error.get('http_code') + code = error['grpc_code'] + error = error['message'] + else: + code = data.get('code') or data.get('Code') + except Exception: + error = str(data) + code = GRPCCode.Unknown + err = errStringToClientError.get(error) or errCodeToClientError.get(code) or Unknown + 
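+ # raise the most specific Etcd3ClientError subclass matching the reported error string or gRPC code (Unknown if nothing matched), preserving the original code, message and HTTP status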
raise err(code, error, status_code) + + +def to_bytes(v): + return v if isinstance(v, bytes) else v.encode('utf-8') + + +def prefix_range_end(v): + v = bytearray(to_bytes(v)) + for i in range(len(v) - 1, -1, -1): + if v[i] < 0xff: + v[i] += 1 + break + return bytes(v) + + +def base64_encode(v): + return base64.b64encode(to_bytes(v)).decode('utf-8') + + +def base64_decode(v): + return base64.b64decode(v).decode('utf-8') + + +def build_range_request(key, range_end=None): + fields = {'key': base64_encode(key)} + if range_end: + fields['range_end'] = base64_encode(range_end) + return fields + + +class Etcd3Client(AbstractEtcdClientWithFailover): + + ERROR_CLS = Etcd3Error + + def __init__(self, config, dns_resolver, cache_ttl=300): + self._token = None + self._cluster_version = None + self.version_prefix = '/v3beta' + super(Etcd3Client, self).__init__(config, dns_resolver, cache_ttl) + + if six.PY2: # pragma: no cover + # Old grpc-gateway sometimes sends double 'transfer-encoding: chunked' headers, + # what breaks the old (python2.7) httplib.HTTPConnection (it closes the socket). + def dedup_addheader(httpm, key, value): + prev = httpm.dict.get(key) + if prev is None: + httpm.dict[key] = value + elif key != 'transfer-encoding' or prev != value: + combined = ", ".join((prev, value)) + httpm.dict[key] = combined + + import httplib + httplib.HTTPMessage.addheader = dedup_addheader + + try: + self.authenticate() + except AuthFailed as e: + logger.fatal('Etcd3 authentication failed: %r', e) + sys.exit(1) + + def _get_headers(self): + headers = urllib3.make_headers(user_agent=USER_AGENT) + if self._token and self._cluster_version >= (3, 3, 0): + headers['authorization'] = self._token + return headers + + def _prepare_request(self, kwargs, params=None, method=None): + if params is not None: + kwargs['body'] = json.dumps(params) + kwargs['headers']['Content-Type'] = 'application/json' + return self.http.urlopen + + @staticmethod + def _handle_server_response(response): + data = response.data + try: + data = data.decode('utf-8') + data = json.loads(data) + except (TypeError, ValueError, UnicodeError) as e: + if response.status < 400: + raise etcd.EtcdException('Server response was not valid JSON: %r' % e) + if response.status < 400: + return data + _raise_for_data(data, response.status) + + def _ensure_version_prefix(self, base_uri, **kwargs): + if self.version_prefix != '/v3': + response = self.http.urlopen(self._MGET, base_uri + '/version', **kwargs) + response = self._handle_server_response(response) + + server_version_str = response['etcdserver'] + server_version = tuple(int(x) for x in server_version_str.split('.')) + cluster_version_str = response['etcdcluster'] + self._cluster_version = tuple(int(x) for x in cluster_version_str.split('.')) + + if self._cluster_version < (3, 0) or server_version < (3, 0, 4): + raise UnsupportedEtcdVersion('Detected Etcd version {0} is lower than 3.0.4'.format(server_version_str)) + + if self._cluster_version < (3, 3): + if self.version_prefix != '/v3alpha': + if self._cluster_version < (3, 1): + logger.warning('Detected Etcd version %s is lower than 3.1.0, watches are not supported', + cluster_version_str) + if self.username and self.password: + logger.warning('Detected Etcd version %s is lower than 3.3.0, authentication is not supported', + cluster_version_str) + self.version_prefix = '/v3alpha' + elif self._cluster_version < (3, 4): + self.version_prefix = '/v3beta' + else: + self.version_prefix = '/v3' + + def _prepare_get_members(self, etcd_nodes): + kwargs = 
self._prepare_common_parameters(etcd_nodes) + self._prepare_request(kwargs, {}) + return kwargs + + def _get_members(self, base_uri, **kwargs): + self._ensure_version_prefix(base_uri, **kwargs) + resp = self.http.urlopen(self._MPOST, base_uri + self.version_prefix + '/cluster/member/list', **kwargs) + members = self._handle_server_response(resp)['members'] + return set(url for member in members for url in member.get('clientURLs', [])) + + def call_rpc(self, method, fields, retry=None): + fields['retry'] = retry + return self.api_execute(self.version_prefix + method, self._MPOST, fields) + + def authenticate(self): + if self._cluster_version >= (3, 3) and self.username and self.password: + logger.info('Trying to authenticate on Etcd...') + old_token, self._token = self._token, None + try: + response = self.call_rpc('/auth/authenticate', {'name': self.username, 'password': self.password}) + except AuthNotEnabled: + logger.info('Etcd authentication is not enabled') + self._token = None + except Exception: + self._token = old_token + raise + else: + self._token = response.get('token') + return old_token != self._token + + def _handle_auth_errors(func): + def wrapper(self, *args, **kwargs): + def retry(ex): + if self.username and self.password: + self.authenticate() + return func(self, *args, **kwargs) + else: + logger.fatal('Username or password not set, authentication is not possible') + raise ex + + try: + return func(self, *args, **kwargs) + except (UserEmpty, PermissionDenied) as e: # no token provided + # PermissionDenied is raised on 3.0 and 3.1 + if self._cluster_version < (3, 3) and (not isinstance(e, PermissionDenied) + or self._cluster_version < (3, 2)): + raise UnsupportedEtcdVersion('Authentication is required by Etcd cluster but not ' + 'supported on version lower than 3.3.0. Cluster version: ' + '{0}'.format('.'.join(map(str, self._cluster_version)))) + return retry(e) + except InvalidAuthToken as e: + logger.error('Invalid auth token: %s', self._token) + return retry(e) + + return wrapper + + @_handle_auth_errors + def range(self, key, range_end=None, retry=None): + params = build_range_request(key, range_end) + params['serializable'] = True # For better performance. We can tolerate stale reads. 
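+ # serializable reads are answered from the local etcd member's store without a quorum (raft) round trip, hence possibly slightly stale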
+ return self.call_rpc('/kv/range', params, retry) + + def prefix(self, key, retry=None): + return self.range(key, prefix_range_end(key), retry) + + def lease_grant(self, ttl, retry=None): + return self.call_rpc('/lease/grant', {'TTL': ttl}, retry)['ID'] + + def lease_keepalive(self, ID, retry=None): + return self.call_rpc('/lease/keepalive', {'ID': ID}, retry).get('result', {}).get('TTL') + + def txn(self, compare, success, retry=None): + return self.call_rpc('/kv/txn', {'compare': [compare], 'success': [success]}, retry).get('succeeded') + + @_handle_auth_errors + def put(self, key, value, lease=None, create_revision=None, mod_revision=None, retry=None): + fields = {'key': base64_encode(key), 'value': base64_encode(value)} + if lease: + fields['lease'] = lease + if create_revision is not None: + compare = {'target': 'CREATE', 'create_revision': create_revision} + elif mod_revision is not None: + compare = {'target': 'MOD', 'mod_revision': mod_revision} + else: + return self.call_rpc('/kv/put', fields, retry) + compare['key'] = fields['key'] + return self.txn(compare, {'request_put': fields}, retry) + + @_handle_auth_errors + def deleterange(self, key, range_end=None, mod_revision=None, retry=None): + fields = build_range_request(key, range_end) + if mod_revision is None: + return self.call_rpc('/kv/deleterange', fields, retry) + compare = {'target': 'MOD', 'mod_revision': mod_revision, 'key': fields['key']} + return self.txn(compare, {'request_delete_range': fields}, retry) + + def deleteprefix(self, key, retry=None): + return self.deleterange(key, prefix_range_end(key), retry=retry) + + def watchrange(self, key, range_end=None, start_revision=None, filters=None): + """returns: response object""" + params = build_range_request(key, range_end) + if start_revision is not None: + params['start_revision'] = start_revision + params['filters'] = filters or [] + kwargs = self._prepare_common_parameters(1, self.read_timeout) + request_executor = self._prepare_request(kwargs, {'create_request': params}) + kwargs.update(timeout=urllib3.Timeout(connect=kwargs['timeout']), retries=0) + return request_executor(self._MPOST, self._base_uri + self.version_prefix + '/watch', **kwargs) + + def watchprefix(self, key, start_revision=None, filters=None): + return self.watchrange(key, prefix_range_end(key), start_revision, filters) + + +class KVCache(Thread): + + def __init__(self, dcs, client): + Thread.__init__(self) + self.daemon = True + self._dcs = dcs + self._client = client + self.condition = Condition() + self._config_key = base64_encode(dcs.config_path) + self._leader_key = base64_encode(dcs.leader_path) + self._optime_key = base64_encode(dcs.leader_optime_path) + self._name = base64_encode(dcs._name) + self._is_ready = False + self._response = None + self._response_lock = Lock() + self._object_cache = {} + self._object_cache_lock = Lock() + self.start() + + def set(self, value, overwrite=False): + with self._object_cache_lock: + name = value['key'] + old_value = self._object_cache.get(name) + ret = not old_value or int(old_value['mod_revision']) < int(value['mod_revision']) + if ret or overwrite and old_value['mod_revision'] == value['mod_revision']: + self._object_cache[name] = value + return ret, old_value + + def delete(self, name, mod_revision): + with self._object_cache_lock: + old_value = self._object_cache.get(name) + ret = old_value and int(old_value['mod_revision']) < int(mod_revision) + if ret: + del self._object_cache[name] + return not old_value or ret, old_value + + def copy(self): + 
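+ # return a point-in-time snapshot of the cached KV objects as shallow copies, taken while holding the cache lock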
with self._object_cache_lock: + return [v.copy() for v in self._object_cache.values()] + + def get(self, name): + with self._object_cache_lock: + return self._object_cache.get(name) + + def _process_event(self, event): + kv = event['kv'] + key = kv['key'] + if event.get('type') == 'DELETE': + success, old_value = self.delete(key, kv['mod_revision']) + else: + success, old_value = self.set(kv, True) + + if success: + old_value = old_value and old_value.get('value') + new_value = kv.get('value') + + value_changed = old_value != new_value and \ + (key == self._leader_key or key == self._optime_key and new_value is not None or + key == self._config_key and old_value is not None and new_value is not None) + + if value_changed: + logger.debug('%s changed from %s to %s', key, old_value, new_value) + + # We also want to wake up HA loop on replicas if leader optime was updated + if value_changed and (key != self._optime_key or self.get(self._leader_key) != self._name): + self._dcs.event.set() + + def _process_message(self, message): + logger.debug('Received message: %s', message) + if 'error' in message: + _raise_for_data(message) + for event in message.get('result', {}).get('events', []): + self._process_event(event) + + @staticmethod + def _finish_response(response): + try: + response.close() + finally: + response.release_conn() + + def _do_watch(self, revision): + with self._response_lock: + self._response = None + response = self._client.watchprefix(self._dcs.cluster_prefix, revision) + with self._response_lock: + if self._response is None: + self._response = response + + if not self._response: + return self._finish_response(response) + + for message in iter_response_objects(response): + self._process_message(message) + + def _build_cache(self): + result = self._dcs.retry(self._client.prefix, self._dcs.cluster_prefix) + with self._object_cache_lock: + self._object_cache = {node['key']: node for node in result.get('kvs', [])} + with self.condition: + self._is_ready = True + self.condition.notify() + + try: + self._do_watch(result['header']['revision']) + except Exception as e: + logger.error('watchprefix failed: %r', e) + finally: + with self.condition: + self._is_ready = False + with self._response_lock: + response, self._response = self._response, None + if response: + self._finish_response(response) + + def run(self): + while True: + try: + self._build_cache() + except Exception as e: + logger.error('KVCache.run %r', e) + time.sleep(1) + + def kill_stream(self): + sock = None + with self._response_lock: + if self._response: + try: + sock = self._response.connection.sock + except Exception: + sock = None + else: + self._response = False + if sock: + try: + sock.shutdown(socket.SHUT_RDWR) + sock.close() + except Exception as e: + logger.debug('Error on socket.shutdown: %r', e) + + def is_ready(self): + """Must be called only when holding the lock on `condition`""" + return self._is_ready + + +class PatroniEtcd3Client(Etcd3Client): + + def __init__(self, *args, **kwargs): + self._kv_cache = None + super(PatroniEtcd3Client, self).__init__(*args, **kwargs) + + def configure(self, etcd3): + self._etcd3 = etcd3 + + def start_watcher(self): + if self._cluster_version >= (3, 1): + self._kv_cache = KVCache(self._etcd3, self) + + def _restart_watcher(self): + if self._kv_cache: + self._kv_cache.kill_stream() + + def set_base_uri(self, value): + super(PatroniEtcd3Client, self).set_base_uri(value) + self._restart_watcher() + + def authenticate(self): + ret = super(PatroniEtcd3Client, self).authenticate() 
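+ # if authentication yielded a new token, restart the long-running KV cache watch stream, since watch requests carry the auth token in their headers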
+ if ret: + self._restart_watcher() + return ret + + def _wait_cache(self, timeout): + stop_time = time.time() + timeout + while not self._kv_cache.is_ready(): + timeout = stop_time - time.time() + if timeout <= 0: + raise RetryFailedError('Exceeded retry deadline') + self._kv_cache.condition.wait(timeout) + + def get_cluster(self): + if self._kv_cache: + with self._kv_cache.condition: + self._wait_cache(self._etcd3._retry.deadline) + return self._kv_cache.copy() + else: + return self._etcd3.retry(self.prefix, self._etcd3.cluster_prefix).get('kvs', []) + + def call_rpc(self, method, fields, retry=None): + ret = super(PatroniEtcd3Client, self).call_rpc(method, fields, retry) + + if self._kv_cache: + value = delete = None + if method == '/kv/txn' and ret.get('succeeded'): + on_success = fields['success'][0] + value = on_success.get('request_put') + delete = on_success.get('request_delete_range') + elif method == '/kv/put' and ret: + value = fields + elif method == '/kv/deleterange' and ret: + delete = fields + + if value: + value['mod_revision'] = ret['header']['revision'] + self._kv_cache.set(value) + elif delete and 'range_end' not in delete: + self._kv_cache.delete(delete['key'], ret['header']['revision']) + + return ret + + +class Etcd3(AbstractEtcd): + + def __init__(self, config): + super(Etcd3, self).__init__(config, PatroniEtcd3Client, (DeadlineExceeded, Unavailable, FailedPrecondition)) + self.__do_not_watch = False + self._lease = None + self._last_lease_refresh = 0 + + self._client.configure(self) + if not self._ctl: + self._client.start_watcher() + self.create_lease() + + def set_socket_options(self, sock, socket_options): + enable_keepalive(sock, self.ttl, int(self.loop_wait + self._retry.deadline)) + + def set_ttl(self, ttl): + self.__do_not_watch = super(Etcd3, self).set_ttl(ttl) + if self.__do_not_watch: + self._lease = None + + def _do_refresh_lease(self, retry=None): + if self._lease and self._last_lease_refresh + self._loop_wait > time.time(): + return False + + if self._lease and not self._client.lease_keepalive(self._lease, retry): + self._lease = None + + ret = not self._lease + if ret: + self._lease = self._client.lease_grant(self._ttl, retry) + + self._last_lease_refresh = time.time() + return ret + + def refresh_lease(self): + try: + return self.retry(self._do_refresh_lease) + except (Etcd3ClientError, RetryFailedError): + logger.exception('refresh_lease') + raise Etcd3Error('Failed ro keepalive/grant lease') + + def create_lease(self): + while not self._lease: + try: + self.refresh_lease() + except Etcd3Error: + logger.info('waiting on etcd') + time.sleep(5) + + @property + def cluster_prefix(self): + return self.client_path('') + + @staticmethod + def member(node): + return Member.from_node(node['mod_revision'], os.path.basename(node['key']), node['lease'], node['value']) + + def _load_cluster(self): + cluster = None + try: + path_len = len(self.cluster_prefix) + + nodes = {} + for node in self._client.get_cluster(): + node['key'] = base64_decode(node['key']) + node['value'] = base64_decode(node.get('value', '')) + node['lease'] = node.get('lease') + nodes[node['key'][path_len:].lstrip('/')] = node + + # get initialize flag + initialize = nodes.get(self._INITIALIZE) + initialize = initialize and initialize['value'] + + # get global dynamic configuration + config = nodes.get(self._CONFIG) + config = config and ClusterConfig.from_node(config['mod_revision'], config['value']) + + # get timeline history + history = nodes.get(self._HISTORY) + history = history and 
TimelineHistory.from_node(history['mod_revision'], history['value']) + + # get last leader operation + last_leader_operation = nodes.get(self._LEADER_OPTIME) + last_leader_operation = 0 if last_leader_operation is None else int(last_leader_operation['value']) + + # get list of members + members = [self.member(n) for k, n in nodes.items() if k.startswith(self._MEMBERS) and k.count('/') == 1] + + # get leader + leader = nodes.get(self._LEADER) + if not self._ctl and leader and leader['value'] == self._name and self._lease != leader.get('lease'): + logger.warning('I am the leader but not owner of the lease') + + if leader: + member = Member(-1, leader['value'], None, {}) + member = ([m for m in members if m.name == leader['value']] or [member])[0] + leader = Leader(leader['mod_revision'], leader['lease'], member) + + # failover key + failover = nodes.get(self._FAILOVER) + if failover: + failover = Failover.from_node(failover['mod_revision'], failover['value']) + + # get synchronization state + sync = nodes.get(self._SYNC) + sync = SyncState.from_node(sync and sync['mod_revision'], sync and sync['value']) + + cluster = Cluster(initialize, config, leader, last_leader_operation, members, failover, sync, history) + except UnsupportedEtcdVersion: + raise + except Exception as e: + self._handle_exception(e, 'get_cluster', raise_ex=Etcd3Error('Etcd is not responding properly')) + self._has_failed = False + return cluster + + @catch_etcd_errors + def touch_member(self, data, permanent=False): + if not permanent: + try: + self.refresh_lease() + except Etcd3Error: + return False + + cluster = self.cluster + member = cluster and cluster.get_member(self._name, fallback_to_leader=False) + + if member and member.session == self._lease and deep_compare(data, member.data): + return True + + data = json.dumps(data, separators=(',', ':')) + try: + return self._client.put(self.member_path, data, None if permanent else self._lease) + except LeaseNotFound: + self._lease = None + logger.error('Our lease disappeared from Etcd, can not "touch_member"') + + @catch_etcd_errors + def take_leader(self): + return self.retry(self._client.put, self.leader_path, self._name, self._lease) + + @catch_etcd_errors + def _do_attempt_to_acquire_leader(self, permanent): + try: + return self.retry(self._client.put, self.leader_path, self._name, None if permanent else self._lease, 0) + except LeaseNotFound: + self._lease = None + logger.error('Our lease disappeared from Etcd. 
Will try to get a new one and retry attempt') + self.refresh_lease() + return self.retry(self._client.put, self.leader_path, self._name, None if permanent else self._lease, 0) + + def attempt_to_acquire_leader(self, permanent=False): + if not self._lease and not permanent: + self.refresh_lease() + + ret = self._do_attempt_to_acquire_leader(permanent) + if not ret: + logger.info('Could not take out TTL lock') + return ret + + @catch_etcd_errors + def set_failover_value(self, value, index=None): + return self._client.put(self.failover_path, value, mod_revision=index) + + @catch_etcd_errors + def set_config_value(self, value, index=None): + return self._client.put(self.config_path, value, mod_revision=index) + + @catch_etcd_errors + def _write_leader_optime(self, last_operation): + return self._client.put(self.leader_optime_path, last_operation) + + @catch_etcd_errors + def _update_leader(self): + if not self._lease: + self.refresh_lease() + elif self.retry(self._client.lease_keepalive, self._lease): + self._last_lease_refresh = time.time() + + if self._lease: + cluster = self.cluster + leader_lease = cluster and isinstance(cluster.leader, Leader) and cluster.leader.session + if leader_lease != self._lease: + self.take_leader() + return bool(self._lease) + + @catch_etcd_errors + def initialize(self, create_new=True, sysid=""): + return self.retry(self._client.put, self.initialize_path, sysid, None, 0 if create_new else None) + + @catch_etcd_errors + def _delete_leader(self): + cluster = self.cluster + if cluster and isinstance(cluster.leader, Leader) and cluster.leader.name == self._name: + return self._client.deleterange(self.leader_path, mod_revision=cluster.leader.index) + + @catch_etcd_errors + def cancel_initialization(self): + return self.retry(self._client.deleterange, self.initialize_path) + + @catch_etcd_errors + def delete_cluster(self): + return self.retry(self._client.deleteprefix, self.cluster_prefix) + + @catch_etcd_errors + def set_history_value(self, value): + return self._client.put(self.history_path, value) + + @catch_etcd_errors + def set_sync_state_value(self, value, index=None): + return self.retry(self._client.put, self.sync_path, value, mod_revision=index) + + @catch_etcd_errors + def delete_sync_state(self, index=None): + return self.retry(self._client.deleterange, self.sync_path, mod_revision=index) + + def watch(self, leader_index, timeout): + if self.__do_not_watch: + self.__do_not_watch = False + return True + + try: + return super(Etcd3, self).watch(None, timeout) + finally: + self.event.clear() diff --git a/patroni-for-openGauss/dcs/exhibitor.py b/patroni-for-openGauss/dcs/exhibitor.py new file mode 100644 index 0000000000000000000000000000000000000000..70066d6509eb2365cfee68d681bad62d3d973c85 --- /dev/null +++ b/patroni-for-openGauss/dcs/exhibitor.py @@ -0,0 +1,74 @@ +import json +import logging +import random +import time + +from patroni.dcs.zookeeper import ZooKeeper +from patroni.request import get as requests_get +from patroni.utils import uri + +logger = logging.getLogger(__name__) + + +class ExhibitorEnsembleProvider(object): + + TIMEOUT = 3.1 + + def __init__(self, hosts, port, uri_path='/exhibitor/v1/cluster/list', poll_interval=300): + self._exhibitor_port = port + self._uri_path = uri_path + self._poll_interval = poll_interval + self._exhibitors = hosts + self._master_exhibitors = hosts + self._zookeeper_hosts = '' + self._next_poll = None + while not self.poll(): + logger.info('waiting on exhibitor') + time.sleep(5) + + def poll(self): + if 
self._next_poll and self._next_poll > time.time(): + return False + + json = self._query_exhibitors(self._exhibitors) + if not json: + json = self._query_exhibitors(self._master_exhibitors) + + if isinstance(json, dict) and 'servers' in json and 'port' in json: + self._next_poll = time.time() + self._poll_interval + zookeeper_hosts = ','.join([h + ':' + str(json['port']) for h in sorted(json['servers'])]) + if self._zookeeper_hosts != zookeeper_hosts: + logger.info('ZooKeeper connection string has changed: %s => %s', self._zookeeper_hosts, zookeeper_hosts) + self._zookeeper_hosts = zookeeper_hosts + self._exhibitors = json['servers'] + return True + return False + + def _query_exhibitors(self, exhibitors): + random.shuffle(exhibitors) + for host in exhibitors: + try: + response = requests_get(uri('http', (host, self._exhibitor_port), self._uri_path), timeout=self.TIMEOUT) + return json.loads(response.data.decode('utf-8')) + except Exception: + logging.debug('Request to %s failed', host) + return None + + @property + def zookeeper_hosts(self): + return self._zookeeper_hosts + + +class Exhibitor(ZooKeeper): + + def __init__(self, config): + interval = config.get('poll_interval', 300) + self._ensemble_provider = ExhibitorEnsembleProvider(config['hosts'], config['port'], poll_interval=interval) + config = config.copy() + config['hosts'] = self._ensemble_provider.zookeeper_hosts + super(Exhibitor, self).__init__(config) + + def _load_cluster(self): + if self._ensemble_provider.poll(): + self._client.set_hosts(self._ensemble_provider.zookeeper_hosts) + return super(Exhibitor, self)._load_cluster() diff --git a/patroni-for-openGauss/dcs/kubernetes.py b/patroni-for-openGauss/dcs/kubernetes.py new file mode 100644 index 0000000000000000000000000000000000000000..abd98a5477957ccf2bf3b429d29bc1b33429f90e --- /dev/null +++ b/patroni-for-openGauss/dcs/kubernetes.py @@ -0,0 +1,1063 @@ +import datetime +import functools +import json +import logging +import os +import random +import socket +import six +import sys +import time +import urllib3 +import yaml + +from urllib3 import Timeout +from urllib3.exceptions import HTTPError +from six.moves.http_client import HTTPException +from threading import Condition, Lock, Thread + +from . 
import AbstractDCS, Cluster, ClusterConfig, Failover, Leader, Member, SyncState, TimelineHistory +from ..exceptions import DCSError +from ..utils import deep_compare, iter_response_objects, keepalive_socket_options,\ + Retry, RetryFailedError, tzutc, uri, USER_AGENT + +logger = logging.getLogger(__name__) + +KUBE_CONFIG_DEFAULT_LOCATION = os.environ.get('KUBECONFIG', '~/.kube/config') +SERVICE_HOST_ENV_NAME = 'KUBERNETES_SERVICE_HOST' +SERVICE_PORT_ENV_NAME = 'KUBERNETES_SERVICE_PORT' +SERVICE_TOKEN_FILENAME = '/var/run/secrets/kubernetes.io/serviceaccount/token' +SERVICE_CERT_FILENAME = '/var/run/secrets/kubernetes.io/serviceaccount/ca.crt' + + +class KubernetesError(DCSError): + pass + + +# this function does the same mapping of snake_case => camelCase for > 97% of cases as autogenerated swagger code +def to_camel_case(value): + reserved = {'api', 'apiv3', 'cidr', 'cpu', 'csi', 'id', 'io', 'ip', 'ipc', 'pid', 'tls', 'uri', 'url', 'uuid'} + words = value.split('_') + return words[0] + ''.join(w.upper() if w in reserved else w.title() for w in words[1:]) + + +class K8sConfig(object): + + class ConfigException(Exception): + pass + + def __init__(self): + self.pool_config = {'maxsize': 10, 'num_pools': 10} # configuration for urllib3.PoolManager + self._make_headers() + + def _make_headers(self, token=None, **kwargs): + self._headers = urllib3.make_headers(user_agent=USER_AGENT, **kwargs) + if token: + self._headers['authorization'] = 'Bearer ' + token + + def load_incluster_config(self): + if SERVICE_HOST_ENV_NAME not in os.environ or SERVICE_PORT_ENV_NAME not in os.environ: + raise self.ConfigException('Service host/port is not set.') + if not os.environ[SERVICE_HOST_ENV_NAME] or not os.environ[SERVICE_PORT_ENV_NAME]: + raise self.ConfigException('Service host/port is set but empty.') + if not os.path.isfile(SERVICE_CERT_FILENAME): + raise self.ConfigException('Service certificate file does not exists.') + with open(SERVICE_CERT_FILENAME) as f: + if not f.read(): + raise self.ConfigException('Cert file exists but empty.') + if not os.path.isfile(SERVICE_TOKEN_FILENAME): + raise self.ConfigException('Service token file does not exists.') + with open(SERVICE_TOKEN_FILENAME) as f: + token = f.read() + if not token: + raise self.ConfigException('Token file exists but empty.') + self._make_headers(token=token) + self.pool_config['ca_certs'] = SERVICE_CERT_FILENAME + self._server = uri('https', (os.environ[SERVICE_HOST_ENV_NAME], os.environ[SERVICE_PORT_ENV_NAME])) + + @staticmethod + def _get_by_name(config, section, name): + for c in config[section + 's']: + if c['name'] == name: + return c[section] + + def load_kube_config(self, context=None): + with open(os.path.expanduser(KUBE_CONFIG_DEFAULT_LOCATION)) as f: + config = yaml.safe_load(f) + + context = self._get_by_name(config, 'context', context or config['current-context']) + cluster = self._get_by_name(config, 'cluster', context['cluster']) + user = self._get_by_name(config, 'user', context['user']) + + self._server = cluster['server'].rstrip('/') + if self._server.startswith('https'): + self.pool_config.update({v: user[k] for k, v in {'client-certificate': 'cert_file', + 'client-key': 'key_file'}.items() if k in user}) + if 'certificate-authority' in cluster: + self.pool_config['ca_certs'] = cluster['certificate-authority'] + self.pool_config['cert_reqs'] = 'CERT_NONE' if cluster.get('insecure-skip-tls-verify') else 'CERT_REQUIRED' + if user.get('token'): + self._make_headers(token=user['token']) + elif 'username' in user and 'password' 
in user: + self._headers = self._make_headers(basic_auth=':'.join((user['username'], user['password']))) + + @property + def server(self): + return self._server + + @property + def headers(self): + return self._headers.copy() + + +class K8sObject(object): + + def __init__(self, kwargs): + self._dict = {k: self._wrap(k, v) for k, v in kwargs.items()} + + def get(self, name, default=None): + return self._dict.get(name, default) + + def __getattr__(self, name): + return self.get(to_camel_case(name)) + + @classmethod + def _wrap(cls, parent, value): + if isinstance(value, dict): + # we know that `annotations` and `labels` are dicts and therefore don't want to convert them into K8sObject + return value if parent in {'annotations', 'labels'} and \ + all(isinstance(v, six.string_types) for v in value.values()) else cls(value) + elif isinstance(value, list): + return [cls._wrap(None, v) for v in value] + else: + return value + + def to_dict(self): + return self._dict + + def __repr__(self): + return json.dumps(self, indent=4, default=lambda o: o.to_dict()) + + +class K8sException(Exception): + pass + + +class K8sConnectionFailed(K8sException): + pass + + +class K8sClient(object): + + class rest(object): + + class ApiException(Exception): + def __init__(self, status=None, reason=None, http_resp=None): + self.status = http_resp.status if http_resp else status + self.reason = http_resp.reason if http_resp else reason + self.body = http_resp.data if http_resp else None + self.headers = http_resp.getheaders() if http_resp else None + + def __str__(self): + error_message = "({0})\nReason: {1}\n".format(self.status, self.reason) + if self.headers: + error_message += "HTTP response headers: {0}\n".format(self.headers) + if self.body: + error_message += "HTTP response body: {0}\n".format(self.body) + return error_message + + class ApiClient(object): + + _API_URL_PREFIX = '/api/v1/namespaces/' + + def __init__(self, bypass_api_service=False): + self._bypass_api_service = bypass_api_service + self.pool_manager = urllib3.PoolManager(**k8s_config.pool_config) + self._base_uri = k8s_config.server + self._api_servers_cache = [k8s_config.server] + self._api_servers_cache_updated = 0 + self.set_api_servers_cache_ttl(10) + self.set_read_timeout(10) + try: + self._load_api_servers_cache() + except K8sException: + pass + + def set_read_timeout(self, timeout): + self._read_timeout = timeout + + def set_api_servers_cache_ttl(self, ttl): + self._api_servers_cache_ttl = ttl - 0.5 + + def set_base_uri(self, value): + logger.info('Selected new K8s API server endpoint %s', value) + # We will connect by IP of the master node which is not listed as alternative name + self.pool_manager.connection_pool_kw['assert_hostname'] = False + self._base_uri = value + + @staticmethod + def _handle_server_response(response, _preload_content): + if response.status not in range(200, 206): + raise k8s_client.rest.ApiException(http_resp=response) + return K8sObject(json.loads(response.data.decode('utf-8'))) if _preload_content else response + + @staticmethod + def _make_headers(headers): + ret = k8s_config.headers + ret.update(headers or {}) + return ret + + @property + def api_servers_cache(self): + base_uri, cache = self._base_uri, self._api_servers_cache + return ([base_uri] if base_uri in cache else []) + [machine for machine in cache if machine != base_uri] + + def _get_api_servers(self, api_servers_cache): + _, per_node_timeout, per_node_retries = self._calculate_timeouts(len(api_servers_cache)) + kwargs = {'headers': 
self._make_headers({}), 'preload_content': True, 'retries': per_node_retries, + 'timeout': urllib3.Timeout(connect=max(1, per_node_timeout/2.0), total=per_node_timeout)} + path = self._API_URL_PREFIX + 'default/endpoints/kubernetes' + for base_uri in api_servers_cache: + try: + response = self.pool_manager.request('GET', base_uri + path, **kwargs) + endpoint = self._handle_server_response(response, True) + for subset in endpoint.subsets: + for port in subset.ports: + if port.name == 'https' and port.protocol == 'TCP': + addresses = [uri('https', (a.ip, port.port)) for a in subset.addresses] + if addresses: + random.shuffle(addresses) + return addresses + except Exception as e: + if isinstance(e, k8s_client.rest.ApiException) and e.status == 403: + raise + self.pool_manager.clear() + logger.error('Failed to get "kubernetes" endpoint from %s: %r', base_uri, e) + raise K8sConnectionFailed('No more K8s API server nodes in the cluster') + + def _refresh_api_servers_cache(self, updating_cache=False): + if self._bypass_api_service: + try: + api_servers_cache = [k8s_config.server] if updating_cache else self.api_servers_cache + self._api_servers_cache = self._get_api_servers(api_servers_cache) + if updating_cache: + self.pool_manager.clear() + except k8s_client.rest.ApiException: # 403 Permission denied + logger.warning("Kubernetes RBAC doesn't allow GET access to the 'kubernetes' " + "endpoint in the 'default' namespace. Disabling 'bypass_api_service'.") + self._bypass_api_service = False + self._api_servers_cache = [k8s_config.server] + if not updating_cache: + self.pool_manager.clear() + except K8sConnectionFailed: + if updating_cache: + raise K8sException("Could not get the list of K8s API server nodes") + return + else: + self._api_servers_cache = [k8s_config.server] + + if self._base_uri not in self._api_servers_cache: + self.set_base_uri(self._api_servers_cache[0]) + self._api_servers_cache_updated = time.time() + + def refresh_api_servers_cache(self): + if self._bypass_api_service and time.time() - self._api_servers_cache_updated > self._api_servers_cache_ttl: + self._refresh_api_servers_cache() + + def _load_api_servers_cache(self): + self._update_api_servers_cache = True + self._refresh_api_servers_cache(True) + self._update_api_servers_cache = False + + def _calculate_timeouts(self, api_servers, timeout=None): + """Calculate a request timeout and number of retries per single K8s API server node. + In case if the timeout per node is too small (less than one second) we will reduce the number of nodes. + For the cluster with only one API server node we will try to do 1 retry. + No retries for clusters with 2 or more API server nodes. 
We better rely on switching to a different node.""" + + per_node_timeout = timeout = float(timeout or self._read_timeout) + + max_retries = 3 - min(api_servers, 2) + per_node_retries = 1 + min_timeout = 1.0 + + while api_servers > 0: + per_node_timeout = float(timeout) / api_servers + if per_node_timeout >= min_timeout: + # for small clusters we will try to do more than one try on every node + while per_node_retries < max_retries and per_node_timeout / (per_node_retries + 1) >= min_timeout: + per_node_retries += 1 + per_node_timeout /= per_node_retries + break + # if the timeout per one node is to small try to reduce number of nodes + api_servers -= 1 + max_retries = 1 + + return api_servers, per_node_timeout, per_node_retries - 1 + + def _do_http_request(self, retry, api_servers_cache, method, path, **kwargs): + some_request_failed = False + for i, base_uri in enumerate(api_servers_cache): + if i > 0: + logger.info('Retrying on %s', base_uri) + try: + response = self.pool_manager.request(method, base_uri + path, **kwargs) + if some_request_failed: + self.set_base_uri(base_uri) + self._refresh_api_servers_cache() + return response + except (HTTPError, HTTPException, socket.error, socket.timeout) as e: + self.pool_manager.clear() + if not retry: + # switch to the next node if request failed and retry is not allowed + if i + 1 < len(api_servers_cache): + self.set_base_uri(api_servers_cache[i + 1]) + raise K8sException('{0} {1} request failed'.format(method, path)) + logger.error('Request to server %s failed: %r', base_uri, e) + some_request_failed = True + + raise K8sConnectionFailed('No more API server nodes in the cluster') + + def request(self, retry, method, path, timeout=None, **kwargs): + if self._update_api_servers_cache: + self._load_api_servers_cache() + + api_servers_cache = self.api_servers_cache + api_servers = len(api_servers_cache) + + if timeout: + if isinstance(timeout, six.integer_types + (float,)): + timeout = urllib3.Timeout(total=timeout) + elif isinstance(timeout, tuple) and len(timeout) == 2: + timeout = urllib3.Timeout(connect=timeout[0], read=timeout[1]) + retries = 0 + else: + _, timeout, retries = self._calculate_timeouts(api_servers) + timeout = urllib3.Timeout(connect=max(1, timeout/2.0), total=timeout) + kwargs.update(retries=retries, timeout=timeout) + + while True: + try: + return self._do_http_request(retry, api_servers_cache, method, path, **kwargs) + except K8sConnectionFailed as ex: + try: + self._load_api_servers_cache() + api_servers_cache = self.api_servers_cache + api_servers = len(api_servers) + except Exception as e: + logger.debug('Failed to update list of K8s master nodes: %r', e) + + sleeptime = retry.sleeptime + remaining_time = retry.stoptime - sleeptime - time.time() + nodes, timeout, retries = self._calculate_timeouts(api_servers, remaining_time) + if nodes == 0: + self._update_api_servers_cache = True + raise ex + retry.sleep_func(sleeptime) + retry.update_delay() + # We still have some time left. 
Partially reduce `api_servers_cache` and retry request + kwargs.update(timeout=urllib3.Timeout(connect=max(1, timeout/2.0), total=timeout), retries=retries) + api_servers_cache = api_servers_cache[:nodes] + + def call_api(self, method, path, headers=None, body=None, _retry=None, + _preload_content=True, _request_timeout=None, **kwargs): + headers = self._make_headers(headers) + fields = {to_camel_case(k): v for k, v in kwargs.items()} # resource_version => resourceVersion + body = json.dumps(body, default=lambda o: o.to_dict()) if body is not None else None + + response = self.request(_retry, method, self._API_URL_PREFIX + path, headers=headers, fields=fields, + body=body, preload_content=_preload_content, timeout=_request_timeout) + + return self._handle_server_response(response, _preload_content) + + class CoreV1Api(object): + + def __init__(self, api_client=None): + self._api_client = api_client or k8s_client.ApiClient() + + def __getattr__(self, func): # `func` name pattern: (action)_namespaced_(kind) + action, kind = func.split('_namespaced_') # (read|list|create|patch|replace|delete|delete_collection) + kind = kind.replace('_', '') + ('s' * int(kind[-1] != 's')) # plural, single word + + def wrapper(*args, **kwargs): + method = {'read': 'GET', 'list': 'GET', 'create': 'POST', + 'replace': 'PUT'}.get(action, action.split('_')[0]).upper() + + if action == 'create' or len(args) == 1: # namespace is a first argument and name in not in arguments + path = '/'.join([args[0], kind]) + else: # name, namespace followed by optional body + path = '/'.join([args[1], kind, args[0]]) + + headers = {'Content-Type': 'application/strategic-merge-patch+json'} if action == 'patch' else {} + + if len(args) == 3: # name, namespace, body + body = args[2] + elif action == 'create': # namespace, body + body = args[1] + elif action == 'delete': # name, namespace + body = kwargs.pop('body', None) + else: + body = None + + return self._api_client.call_api(method, path, headers, body, **kwargs) + return wrapper + + class _K8sObjectTemplate(K8sObject): + """The template for objects which we create locally, e.g. k8s_client.V1ObjectMeta & co""" + def __init__(self, **kwargs): + self._dict = {to_camel_case(k): v for k, v in kwargs.items()} + + def __init__(self): + self.__cls_cache = {} + self.__cls_lock = Lock() + + def __getattr__(self, name): + with self.__cls_lock: + if name not in self.__cls_cache: + self.__cls_cache[name] = type(name, (self._K8sObjectTemplate,), {}) + return self.__cls_cache[name] + + +k8s_client = K8sClient() +k8s_config = K8sConfig() + + +class KubernetesRetriableException(k8s_client.rest.ApiException): + + def __init__(self, orig): + super(KubernetesRetriableException, self).__init__(orig.status, orig.reason) + self.body = orig.body + self.headers = orig.headers + + @property + def sleeptime(self): + try: + return int(self.headers['retry-after']) + except Exception: + return None + + +class CoreV1ApiProxy(object): + + def __init__(self, use_endpoints=False, bypass_api_service=False): + self._api_client = k8s_client.ApiClient(bypass_api_service) + self._core_v1_api = k8s_client.CoreV1Api(self._api_client) + self._use_endpoints = bool(use_endpoints) + + def configure_timeouts(self, loop_wait, retry_timeout, ttl): + # Normally every loop_wait seconds we should have receive something from the socket. + # If we didn't received anything after the loop_wait + retry_timeout it is a time + # to start worrying (send keepalive messages). 
Finally, the connection should be + # considered as dead if we received nothing from the socket after the ttl seconds. + self._api_client.pool_manager.connection_pool_kw['socket_options'] = \ + list(keepalive_socket_options(ttl, int(loop_wait + retry_timeout))) + self._api_client.set_read_timeout(retry_timeout) + self._api_client.set_api_servers_cache_ttl(loop_wait) + + def refresh_api_servers_cache(self): + self._api_client.refresh_api_servers_cache() + + def __getattr__(self, func): + if func.endswith('_kind'): + func = func[:-4] + ('endpoints' if self._use_endpoints else 'config_map') + + def wrapper(*args, **kwargs): + try: + return getattr(self._core_v1_api, func)(*args, **kwargs) + except k8s_client.rest.ApiException as e: + if e.status in (500, 503, 504) or e.headers and 'retry-after' in e.headers: # XXX + raise KubernetesRetriableException(e) + raise + return wrapper + + @property + def use_endpoints(self): + return self._use_endpoints + + +def catch_kubernetes_errors(func): + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except k8s_client.rest.ApiException as e: + if e.status == 403: + logger.exception('Permission denied') + elif e.status != 409: # Object exists or conflict in resource_version + logger.exception('Unexpected error from Kubernetes API') + return False + except (RetryFailedError, K8sException): + return False + return wrapper + + +class ObjectCache(Thread): + + def __init__(self, dcs, func, retry, condition, name=None): + Thread.__init__(self) + self.daemon = True + self._dcs = dcs + self._func = func + self._retry = retry + self._condition = condition + self._name = name # name of this pod + self._is_ready = False + self._object_cache = {} + self._object_cache_lock = Lock() + self._annotations_map = {self._dcs.leader_path: self._dcs._LEADER, self._dcs.config_path: self._dcs._CONFIG} + self.start() + + def _list(self): + try: + return self._func(_retry=self._retry.copy()) + except Exception: + time.sleep(1) + raise + + def _watch(self, resource_version): + return self._func(_request_timeout=(self._retry.deadline, Timeout.DEFAULT_TIMEOUT), + _preload_content=False, watch=True, resource_version=resource_version) + + def set(self, name, value): + with self._object_cache_lock: + old_value = self._object_cache.get(name) + ret = not old_value or int(old_value.metadata.resource_version) < int(value.metadata.resource_version) + if ret: + self._object_cache[name] = value + return ret, old_value + + def delete(self, name, resource_version): + with self._object_cache_lock: + old_value = self._object_cache.get(name) + ret = old_value and int(old_value.metadata.resource_version) < int(resource_version) + if ret: + del self._object_cache[name] + return not old_value or ret, old_value + + def copy(self): + with self._object_cache_lock: + return self._object_cache.copy() + + def get(self, name): + with self._object_cache_lock: + return self._object_cache.get(name) + + def _build_cache(self): + objects = self._list() + return_type = 'V1' + objects.kind[:-4] + with self._object_cache_lock: + self._object_cache = {item.metadata.name: item for item in objects.items} + with self._condition: + self._is_ready = True + self._condition.notify() + + response = self._watch(objects.metadata.resource_version) + try: + for event in iter_response_objects(response): + obj = event['object'] + if obj.get('code') == 410: + break + + ev_type = event['type'] + name = obj['metadata']['name'] + + if ev_type in ('ADDED', 'MODIFIED'): + obj = K8sObject(obj) + success, old_value = 
self.set(name, obj) + if success: + new_value = (obj.metadata.annotations or {}).get(self._annotations_map.get(name)) + elif ev_type == 'DELETED': + success, old_value = self.delete(name, obj['metadata']['resourceVersion']) + new_value = None + else: + logger.warning('Unexpected event type: %s', ev_type) + continue + + if success and return_type != 'V1Pod': + if old_value: + old_value = (old_value.metadata.annotations or {}).get(self._annotations_map.get(name)) + + value_changed = old_value != new_value and \ + (name != self._dcs.config_path or old_value is not None and new_value is not None) + + if value_changed: + logger.debug('%s changed from %s to %s', name, old_value, new_value) + + # Do not wake up HA loop if we run as leader and received leader object update event + if value_changed or name == self._dcs.leader_path and self._name != new_value: + self._dcs.event.set() + finally: + with self._condition: + self._is_ready = False + response.close() + response.release_conn() + + def run(self): + while True: + try: + self._build_cache() + except Exception as e: + with self._condition: + self._is_ready = False + logger.error('ObjectCache.run %r', e) + + def is_ready(self): + """Must be called only when holding the lock on `_condition`""" + return self._is_ready + + +class Kubernetes(AbstractDCS): + + def __init__(self, config): + self._labels = config['labels'] + self._labels[config.get('scope_label', 'cluster-name')] = config['scope'] + self._label_selector = ','.join('{0}={1}'.format(k, v) for k, v in self._labels.items()) + self._namespace = config.get('namespace') or 'default' + self._role_label = config.get('role_label', 'role') + config['namespace'] = '' + super(Kubernetes, self).__init__(config) + self._retry = Retry(deadline=config['retry_timeout'], max_delay=1, max_tries=-1, + retry_exceptions=KubernetesRetriableException) + self._ttl = None + try: + k8s_config.load_incluster_config() + except k8s_config.ConfigException: + k8s_config.load_kube_config(context=config.get('context', 'local')) + + self.__my_pod = None + self.__ips = [] if config.get('patronictl') else [config.get('pod_ip')] + self.__ports = [] + for p in config.get('ports', [{}]): + port = {'port': int(p.get('port', '5432'))} + port.update({n: p[n] for n in ('name', 'protocol') if p.get(n)}) + self.__ports.append(k8s_client.V1EndpointPort(**port)) + + bypass_api_service = not config.get('patronictl') and config.get('bypass_api_service') + self._api = CoreV1ApiProxy(config.get('use_endpoints'), bypass_api_service) + self._should_create_config_service = self._api.use_endpoints + self.reload_config(config) + # leader_observed_record, leader_resource_version, and leader_observed_time are used only for leader race! 
+ self._leader_observed_record = {} + self._leader_observed_time = None + self._leader_resource_version = None + self.__do_not_watch = False + + self._condition = Condition() + + pods_func = functools.partial(self._api.list_namespaced_pod, self._namespace, + label_selector=self._label_selector) + self._pods = ObjectCache(self, pods_func, self._retry, self._condition) + + kinds_func = functools.partial(self._api.list_namespaced_kind, self._namespace, + label_selector=self._label_selector) + self._kinds = ObjectCache(self, kinds_func, self._retry, self._condition, self._name) + + def retry(self, *args, **kwargs): + retry = self._retry.copy() + kwargs['_retry'] = retry + return retry(*args, **kwargs) + + def client_path(self, path): + return super(Kubernetes, self).client_path(path)[1:].replace('/', '-') + + @property + def leader_path(self): + return self._base_path[1:] if self._api.use_endpoints else super(Kubernetes, self).leader_path + + def set_ttl(self, ttl): + ttl = int(ttl) + self.__do_not_watch = self._ttl != ttl + self._ttl = ttl + + @property + def ttl(self): + return self._ttl + + def set_retry_timeout(self, retry_timeout): + self._retry.deadline = retry_timeout + + def reload_config(self, config): + super(Kubernetes, self).reload_config(config) + self._api.configure_timeouts(self.loop_wait, self._retry.deadline, self.ttl) + + @staticmethod + def member(pod): + annotations = pod.metadata.annotations or {} + member = Member.from_node(pod.metadata.resource_version, pod.metadata.name, None, annotations.get('status', '')) + member.data['pod_labels'] = pod.metadata.labels + return member + + def _wait_caches(self, stop_time): + while not (self._pods.is_ready() and self._kinds.is_ready()): + timeout = stop_time - time.time() + if timeout <= 0: + raise RetryFailedError('Exceeded retry deadline') + self._condition.wait(timeout) + + def _load_cluster(self): + stop_time = time.time() + self._retry.deadline + self._api.refresh_api_servers_cache() + try: + with self._condition: + self._wait_caches(stop_time) + + members = [self.member(pod) for pod in self._pods.copy().values()] + nodes = self._kinds.copy() + + config = nodes.get(self.config_path) + metadata = config and config.metadata + annotations = metadata and metadata.annotations or {} + + # get initialize flag + initialize = annotations.get(self._INITIALIZE) + + # get global dynamic configuration + config = ClusterConfig.from_node(metadata and metadata.resource_version, + annotations.get(self._CONFIG) or '{}', + metadata.resource_version if self._CONFIG in annotations else 0) + + # get timeline history + history = TimelineHistory.from_node(metadata and metadata.resource_version, + annotations.get(self._HISTORY) or '[]') + + leader = nodes.get(self.leader_path) + metadata = leader and leader.metadata + self._leader_resource_version = metadata.resource_version if metadata else None + annotations = metadata and metadata.annotations or {} + + # get last leader operation + last_leader_operation = annotations.get(self._OPTIME) + last_leader_operation = 0 if last_leader_operation is None else int(last_leader_operation) + + # get leader + leader_record = {n: annotations.get(n) for n in (self._LEADER, 'acquireTime', + 'ttl', 'renewTime', 'transitions') if n in annotations} + if (leader_record or self._leader_observed_record) and leader_record != self._leader_observed_record: + self._leader_observed_record = leader_record + self._leader_observed_time = time.time() + + leader = leader_record.get(self._LEADER) + try: + ttl = 
int(leader_record.get('ttl')) or self._ttl + except (TypeError, ValueError): + ttl = self._ttl + + if not metadata or not self._leader_observed_time or self._leader_observed_time + ttl < time.time(): + leader = None + + if metadata: + member = Member(-1, leader, None, {}) + member = ([m for m in members if m.name == leader] or [member])[0] + leader = Leader(metadata.resource_version, None, member) + + # failover key + failover = nodes.get(self.failover_path) + metadata = failover and failover.metadata + failover = Failover.from_node(metadata and metadata.resource_version, + metadata and (metadata.annotations or {}).copy()) + + # get synchronization state + sync = nodes.get(self.sync_path) + metadata = sync and sync.metadata + sync = SyncState.from_node(metadata and metadata.resource_version, metadata and metadata.annotations) + + return Cluster(initialize, config, leader, last_leader_operation, members, failover, sync, history) + except Exception: + logger.exception('get_cluster') + raise KubernetesError('Kubernetes API is not responding properly') + + @staticmethod + def compare_ports(p1, p2): + return p1.name == p2.name and p1.port == p2.port and (p1.protocol or 'TCP') == (p2.protocol or 'TCP') + + @staticmethod + def subsets_changed(last_observed_subsets, ip, ports): + """ + >>> Kubernetes.subsets_changed([], None, []) + True + >>> ip = '1.2.3.4' + >>> a = [k8s_client.V1EndpointAddress(ip=ip)] + >>> s = [k8s_client.V1EndpointSubset(addresses=a)] + >>> Kubernetes.subsets_changed(s, '1.2.3.5', []) + True + >>> s = [k8s_client.V1EndpointSubset(addresses=a, ports=[k8s_client.V1EndpointPort(protocol='TCP', port=1)])] + >>> Kubernetes.subsets_changed(s, '1.2.3.4', [k8s_client.V1EndpointPort(port=5432)]) + True + >>> p1 = k8s_client.V1EndpointPort(name='port1', port=1) + >>> p2 = k8s_client.V1EndpointPort(name='port2', port=2) + >>> p3 = k8s_client.V1EndpointPort(name='port3', port=3) + >>> s = [k8s_client.V1EndpointSubset(addresses=a, ports=[p1, p2])] + >>> Kubernetes.subsets_changed(s, ip, [p2, p3]) + True + >>> s2 = [k8s_client.V1EndpointSubset(addresses=a, ports=[p2, p1])] + >>> Kubernetes.subsets_changed(s, ip, [p2, p1]) + False + """ + + if len(last_observed_subsets) != 1: + return True + if len(last_observed_subsets[0].addresses or []) != 1 or \ + last_observed_subsets[0].addresses[0].ip != ip or \ + len(last_observed_subsets[0].ports) != len(ports): + return True + if len(ports) == 1: + return not Kubernetes.compare_ports(last_observed_subsets[0].ports[0], ports[0]) + observed_ports = {p.name: p for p in last_observed_subsets[0].ports} + for p in ports: + if p.name not in observed_ports or not Kubernetes.compare_ports(p, observed_ports.pop(p.name)): + return True + return False + + def __target_ref(self, leader_ip, latest_subsets, pod): + # we want to re-use existing target_ref if possible + for subset in latest_subsets: + for address in subset.addresses or []: + if address.ip == leader_ip and address.target_ref and address.target_ref.name == self._name: + return address.target_ref + return k8s_client.V1ObjectReference(kind='Pod', uid=pod.metadata.uid, namespace=self._namespace, + name=self._name, resource_version=pod.metadata.resource_version) + + def _map_subsets(self, endpoints, ips): + leader = self._kinds.get(self.leader_path) + latest_subsets = leader and leader.subsets or [] + if not ips: + # We want to have subsets empty + if latest_subsets: + endpoints['subsets'] = [] + return + + pod = self._pods.get(self._name) + leader_ip = ips[0] or pod and pod.status.pod_ip + # don't 
touch subsets if our (leader) ip is unknown or subsets is valid + if leader_ip and self.subsets_changed(latest_subsets, leader_ip, self.__ports): + kwargs = {'hostname': pod.spec.hostname, 'node_name': pod.spec.node_name, + 'target_ref': self.__target_ref(leader_ip, latest_subsets, pod)} if pod else {} + address = k8s_client.V1EndpointAddress(ip=leader_ip, **kwargs) + endpoints['subsets'] = [k8s_client.V1EndpointSubset(addresses=[address], ports=self.__ports)] + + def _patch_or_create(self, name, annotations, resource_version=None, patch=False, retry=None, ips=None): + metadata = {'namespace': self._namespace, 'name': name, 'labels': self._labels, 'annotations': annotations} + if patch or resource_version: + if resource_version is not None: + metadata['resource_version'] = resource_version + func = functools.partial(self._api.patch_namespaced_kind, name) + else: + func = functools.partial(self._api.create_namespaced_kind) + # skip annotations with null values + metadata['annotations'] = {k: v for k, v in metadata['annotations'].items() if v is not None} + + metadata = k8s_client.V1ObjectMeta(**metadata) + if ips is not None and self._api.use_endpoints: + endpoints = {'metadata': metadata} + self._map_subsets(endpoints, ips) + body = k8s_client.V1Endpoints(**endpoints) + else: + body = k8s_client.V1ConfigMap(metadata=metadata) + ret = retry(func, self._namespace, body) if retry else func(self._namespace, body) + if ret: + self._kinds.set(name, ret) + return ret + + @catch_kubernetes_errors + def patch_or_create(self, name, annotations, resource_version=None, patch=False, retry=True, ips=None): + if retry is True: + retry = self.retry + return self._patch_or_create(name, annotations, resource_version, patch, retry, ips) + + def patch_or_create_config(self, annotations, resource_version=None, patch=False, retry=True): + # SCOPE-config endpoint requires corresponding service otherwise it might be "cleaned" by k8s master + if self._api.use_endpoints and not patch and not resource_version: + self._should_create_config_service = True + self._create_config_service() + return self.patch_or_create(self.config_path, annotations, resource_version, patch, retry) + + def _create_config_service(self): + metadata = k8s_client.V1ObjectMeta(namespace=self._namespace, name=self.config_path, labels=self._labels) + body = k8s_client.V1Service(metadata=metadata, spec=k8s_client.V1ServiceSpec(cluster_ip='None')) + try: + if not self._api.create_namespaced_service(self._namespace, body): + return + except Exception as e: + if not isinstance(e, k8s_client.rest.ApiException) or e.status != 409: # Service already exists + return logger.exception('create_config_service failed') + self._should_create_config_service = False + + def _write_leader_optime(self, last_operation): + """Unused""" + + def _update_leader(self): + """Unused""" + + def _update_leader_with_retry(self, annotations, resource_version, ips): + retry = self._retry.copy() + + def _retry(*args, **kwargs): + kwargs['_retry'] = retry + return retry(*args, **kwargs) + + try: + return self._patch_or_create(self.leader_path, annotations, resource_version, ips=ips, retry=_retry) + except k8s_client.rest.ApiException as e: + if e.status == 409: + logger.warning('Concurrent update of %s', self.leader_path) + else: + logger.exception('Permission denied' if e.status == 403 else 'Unexpected error from Kubernetes API') + return False + except (RetryFailedError, K8sException): + return False + + retry.deadline = retry.stoptime - time.time() + if retry.deadline < 1: + 
return False + + # Try to get the latest version directly from K8s API instead of relying on async cache + try: + kind = retry(self._api.read_namespaced_kind, self.leader_path, self._namespace) + except Exception as e: + logger.error('Failed to get the leader object "%s": %r', self.leader_path, e) + return False + + self._kinds.set(self.leader_path, kind) + + retry.deadline = retry.stoptime - time.time() + if retry.deadline < 0.5: + return False + + kind_annotations = kind and kind.metadata.annotations or {} + kind_resource_version = kind and kind.metadata.resource_version + + # There is different leader or resource_version in cache didn't change + if kind and (kind_annotations.get(self._LEADER) != self._name or kind_resource_version == resource_version): + return False + + return self.patch_or_create(self.leader_path, annotations, kind_resource_version, ips=ips, retry=_retry) + + def update_leader(self, last_operation, access_is_restricted=False): + kind = self._kinds.get(self.leader_path) + kind_annotations = kind and kind.metadata.annotations or {} + + if kind and kind_annotations.get(self._LEADER) != self._name: + return False + + now = datetime.datetime.now(tzutc).isoformat() + leader_observed_record = kind_annotations or self._leader_observed_record + annotations = {self._LEADER: self._name, 'ttl': str(self._ttl), 'renewTime': now, + 'acquireTime': leader_observed_record.get('acquireTime') or now, + 'transitions': leader_observed_record.get('transitions') or '0'} + if last_operation: + annotations[self._OPTIME] = last_operation + + resource_version = kind and kind.metadata.resource_version + ips = [] if access_is_restricted else self.__ips + return self._update_leader_with_retry(annotations, resource_version, ips) + + def attempt_to_acquire_leader(self, permanent=False): + now = datetime.datetime.now(tzutc).isoformat() + annotations = {self._LEADER: self._name, 'ttl': str(sys.maxsize if permanent else self._ttl), + 'renewTime': now, 'acquireTime': now, 'transitions': '0'} + if self._leader_observed_record: + try: + transitions = int(self._leader_observed_record.get('transitions')) + except (TypeError, ValueError): + transitions = 0 + + if self._leader_observed_record.get(self._LEADER) != self._name: + transitions += 1 + else: + annotations['acquireTime'] = self._leader_observed_record.get('acquireTime') or now + annotations['transitions'] = str(transitions) + ips = [] if self._api.use_endpoints else None + ret = self.patch_or_create(self.leader_path, annotations, self._leader_resource_version, ips=ips) + if not ret: + logger.info('Could not take out TTL lock') + return ret + + def take_leader(self): + return self.attempt_to_acquire_leader() + + def set_failover_value(self, value, index=None): + """Unused""" + + def manual_failover(self, leader, candidate, scheduled_at=None, index=None): + annotations = {'leader': leader or None, 'member': candidate or None, + 'scheduled_at': scheduled_at and scheduled_at.isoformat()} + patch = bool(self.cluster and isinstance(self.cluster.failover, Failover) and self.cluster.failover.index) + return self.patch_or_create(self.failover_path, annotations, index, bool(index or patch), False) + + @property + def _config_resource_version(self): + config = self._kinds.get(self.config_path) + return config and config.metadata.resource_version + + def set_config_value(self, value, index=None): + return self.patch_or_create_config({self._CONFIG: value}, index, bool(self._config_resource_version), False) + + @catch_kubernetes_errors + def touch_member(self, 
data, permanent=False): + cluster = self.cluster + if cluster and cluster.leader and cluster.leader.name == self._name: + role = 'promoted' if data['role'] in ('replica', 'promoted') else 'master' + elif data['state'] == 'running' and data['role'] != 'master': + role = data['role'] + else: + role = None + + member = cluster and cluster.get_member(self._name, fallback_to_leader=False) + pod_labels = member and member.data.pop('pod_labels', None) + ret = pod_labels is not None and pod_labels.get(self._role_label) == role and deep_compare(data, member.data) + + if not ret: + metadata = {'namespace': self._namespace, 'name': self._name, 'labels': {self._role_label: role}, + 'annotations': {'status': json.dumps(data, separators=(',', ':'))}} + body = k8s_client.V1Pod(metadata=k8s_client.V1ObjectMeta(**metadata)) + ret = self._api.patch_namespaced_pod(self._name, self._namespace, body) + if ret: + self._pods.set(self._name, ret) + if self._should_create_config_service: + self._create_config_service() + return ret + + def initialize(self, create_new=True, sysid=""): + cluster = self.cluster + resource_version = cluster.config.index if cluster and cluster.config and cluster.config.index else None + return self.patch_or_create_config({self._INITIALIZE: sysid}, resource_version) + + def _delete_leader(self): + """Unused""" + + def delete_leader(self, last_operation=None): + kind = self._kinds.get(self.leader_path) + if kind and (kind.metadata.annotations or {}).get(self._LEADER) == self._name: + annotations = {self._LEADER: None} + if last_operation: + annotations[self._OPTIME] = last_operation + self.patch_or_create(self.leader_path, annotations, kind.metadata.resource_version, True, False, []) + self.reset_cluster() + + def cancel_initialization(self): + self.patch_or_create_config({self._INITIALIZE: None}, self._config_resource_version, True) + + @catch_kubernetes_errors + def delete_cluster(self): + self.retry(self._api.delete_collection_namespaced_kind, self._namespace, label_selector=self._label_selector) + + def set_history_value(self, value): + return self.patch_or_create_config({self._HISTORY: value}, None, bool(self._config_resource_version), False) + + def set_sync_state_value(self, value, index=None): + """Unused""" + + def write_sync_state(self, leader, sync_standby, index=None): + return self.patch_or_create(self.sync_path, self.sync_state(leader, sync_standby), index, False) + + def delete_sync_state(self, index=None): + return self.write_sync_state(None, None, index) + + def watch(self, leader_index, timeout): + if self.__do_not_watch: + self.__do_not_watch = False + return True + + try: + return super(Kubernetes, self).watch(None, timeout + 0.5) + finally: + self.event.clear() diff --git a/patroni-for-openGauss/dcs/raft.py b/patroni-for-openGauss/dcs/raft.py new file mode 100644 index 0000000000000000000000000000000000000000..811e3ff6f4dd675cbd60fad4503ed7a4c9e763c2 --- /dev/null +++ b/patroni-for-openGauss/dcs/raft.py @@ -0,0 +1,422 @@ +import json +import logging +import os +import threading +import time + +from patroni.dcs import AbstractDCS, ClusterConfig, Cluster, Failover, Leader, Member, SyncState, TimelineHistory +from ..utils import validate_directory +from pysyncobj import SyncObj, SyncObjConf, replicated, FAIL_REASON +from pysyncobj.transport import Node, TCPTransport, CONNECTION_STATE + +logger = logging.getLogger(__name__) + + +class MessageNode(Node): + + def __init__(self, address): + self.address = address + + +class UtilityTransport(TCPTransport): + + def 
__init__(self, syncObj, selfNode, otherNodes): + super(UtilityTransport, self).__init__(syncObj, selfNode, otherNodes) + self._selfIsReadonlyNode = False + + def _connectIfNecessarySingle(self, node): + pass + + def connectionState(self, node): + return self._connections[node].state + + def isDisconnected(self, node): + return self.connectionState(node) == CONNECTION_STATE.DISCONNECTED + + def connectIfRequiredSingle(self, node): + if self.isDisconnected(node): + return self._connections[node].connect(node.ip, node.port) + + def disconnectSingle(self, node): + self._connections[node].disconnect() + + +class SyncObjUtility(SyncObj): + + def __init__(self, otherNodes, conf): + autoTick = conf.autoTick + conf.autoTick = False + super(SyncObjUtility, self).__init__(None, otherNodes, conf, transportClass=UtilityTransport) + conf.autoTick = autoTick + self._SyncObj__transport.setOnMessageReceivedCallback(self._onMessageReceived) + self.__result = None + + def setPartnerNode(self, partner): + self.__node = partner + + def sendMessage(self, message): + # Abuse the fact that node address is send as a first message + self._SyncObj__transport._selfNode = MessageNode(message) + self._SyncObj__transport.connectIfRequiredSingle(self.__node) + while not self._SyncObj__transport.isDisconnected(self.__node): + self._poller.poll(0.5) + return self.__result + + def _onMessageReceived(self, _, message): + self.__result = message + self._SyncObj__transport.disconnectSingle(self.__node) + + +class MyTCPTransport(TCPTransport): + + def _onIncomingMessageReceived(self, conn, message): + if self._syncObj.encryptor and not conn.sendRandKey: + conn.sendRandKey = message + conn.recvRandKey = os.urandom(32) + conn.send(conn.recvRandKey) + return + + # Utility messages + if isinstance(message, list) and message[0] == 'members': + conn.send(self._syncObj._get_members()) + return True + + return super(MyTCPTransport, self)._onIncomingMessageReceived(conn, message) + + +class DynMemberSyncObj(SyncObj): + + def __init__(self, selfAddress, partnerAddrs, conf): + add_self = False + utility = SyncObjUtility(partnerAddrs, conf) + for node in utility._SyncObj__otherNodes: + utility.setPartnerNode(node) + response = utility.sendMessage(['members']) + if response: + partnerAddrs = [member['addr'] for member in response if member['addr'] != selfAddress] + add_self = selfAddress and len(partnerAddrs) == len(response) + break + + super(DynMemberSyncObj, self).__init__(selfAddress, partnerAddrs, conf, transportClass=MyTCPTransport) + if add_self: + threading.Thread(target=utility.sendMessage, args=(['add', selfAddress],)).start() + + def _get_members(self): + ret = [{'addr': node.id, 'leader': node == self._getLeader(), + 'status': CONNECTION_STATE.CONNECTED if node in self._SyncObj__connectedNodes + else CONNECTION_STATE.DISCONNECTED} for node in self._SyncObj__otherNodes] + ret.append({'addr': self._SyncObj__selfNode.id, 'leader': self._isLeader(), + 'status': CONNECTION_STATE.CONNECTED}) + return ret + + def _SyncObj__doChangeCluster(self, request, reverse=False): + ret = False + if not self._SyncObj__selfNode or request[0] != 'add' or reverse or request[1] != self._SyncObj__selfNode.id: + ret = super(DynMemberSyncObj, self)._SyncObj__doChangeCluster(request, reverse) + if ret: + self.forceLogCompaction() + return ret + + +class KVStoreTTL(DynMemberSyncObj): + + def __init__(self, selfAddress, partnerAddrs, conf, on_set=None, on_delete=None): + self.__on_set = on_set + self.__on_delete = on_delete + self.__limb = {} + 
self.__retry_timeout = None + self.__early_apply_local_log = selfAddress is not None + self.applied_local_log = False + super(KVStoreTTL, self).__init__(selfAddress, partnerAddrs, conf) + self.__data = {} + + @staticmethod + def __check_requirements(old_value, **kwargs): + return ('prevExist' not in kwargs or bool(kwargs['prevExist']) == bool(old_value)) and \ + ('prevValue' not in kwargs or old_value and old_value['value'] == kwargs['prevValue']) and \ + (not kwargs.get('prevIndex') or old_value and old_value['index'] == kwargs['prevIndex']) + + def set_retry_timeout(self, retry_timeout): + self.__retry_timeout = retry_timeout + + def retry(self, func, *args, **kwargs): + event = threading.Event() + ret = {'result': None, 'error': -1} + + def callback(result, error): + ret.update(result=result, error=error) + event.set() + + kwargs['callback'] = callback + timeout = kwargs.pop('timeout', None) or self.__retry_timeout + deadline = timeout and time.time() + timeout + + while True: + event.clear() + func(*args, **kwargs) + event.wait(timeout) + if ret['error'] == FAIL_REASON.SUCCESS: + return ret['result'] + elif ret['error'] == FAIL_REASON.REQUEST_DENIED: + break + elif deadline: + timeout = deadline - time.time() + if timeout <= 0: + break + time.sleep(1) + return False + + @replicated + def _set(self, key, value, **kwargs): + old_value = self.__data.get(key, {}) + if not self.__check_requirements(old_value, **kwargs): + return False + + if old_value and old_value['created'] != value['created']: + value['created'] = value['updated'] + value['index'] = self._SyncObj__raftLastApplied + 1 + + self.__data[key] = value + if self.__on_set: + self.__on_set(key, value) + return True + + def set(self, key, value, ttl=None, **kwargs): + old_value = self.__data.get(key, {}) + if not self.__check_requirements(old_value, **kwargs): + return False + + value = {'value': value, 'updated': time.time()} + value['created'] = old_value.get('created', value['updated']) + if ttl: + value['expire'] = value['updated'] + ttl + return self.retry(self._set, key, value, **kwargs) + + def __pop(self, key): + self.__data.pop(key) + if self.__on_delete: + self.__on_delete(key) + + @replicated + def _delete(self, key, recursive=False, **kwargs): + if recursive: + for k in list(self.__data.keys()): + if k.startswith(key): + self.__pop(k) + elif not self.__check_requirements(self.__data.get(key, {}), **kwargs): + return False + else: + self.__pop(key) + return True + + def delete(self, key, recursive=False, **kwargs): + if not recursive and not self.__check_requirements(self.__data.get(key, {}), **kwargs): + return False + return self.retry(self._delete, key, recursive=recursive, **kwargs) + + @staticmethod + def __values_match(old, new): + return all(old.get(n) == new.get(n) for n in ('created', 'updated', 'expire', 'value')) + + @replicated + def _expire(self, key, value, callback=None): + current = self.__data.get(key) + if current and self.__values_match(current, value): + self.__pop(key) + + def __expire_keys(self): + for key, value in self.__data.items(): + if value and 'expire' in value and value['expire'] <= time.time() and \ + not (key in self.__limb and self.__values_match(self.__limb[key], value)): + self.__limb[key] = value + + def callback(*args): + if key in self.__limb and self.__values_match(self.__limb[key], value): + self.__limb.pop(key) + self._expire(key, value, callback=callback) + + def get(self, key, recursive=False): + if not recursive: + return self.__data.get(key) + return {k: v for k, v in 
self.__data.items() if k.startswith(key)} + + def _onTick(self, timeToWait=0.0): + # The SyncObj starts applying the local log only when there is at least one node connected. + # We want to change this behavior and apply the local log even when there is nobody except us. + # It gives us at least some picture about the last known cluster state. + if self.__early_apply_local_log and not self.applied_local_log and self._SyncObj__needLoadDumpFile: + self._SyncObj__raftCommitIndex = self._SyncObj__getCurrentLogIndex() + self._SyncObj__raftCurrentTerm = self._SyncObj__getCurrentLogTerm() + + super(KVStoreTTL, self)._onTick(timeToWait) + + # The SyncObj calls onReady callback only when cluster got the leader and is ready for writes. + # In some cases for us it is safe to "signal" the Raft object when the local log is fully applied. + # We are using the `applied_local_log` property for that, but not calling the callback function. + if self.__early_apply_local_log and not self.applied_local_log and self._SyncObj__raftCommitIndex != 1 and \ + self._SyncObj__raftLastApplied == self._SyncObj__raftCommitIndex: + self.applied_local_log = True + + if self._isLeader(): + self.__expire_keys() + else: + self.__limb.clear() + + +class Raft(AbstractDCS): + + def __init__(self, config): + super(Raft, self).__init__(config) + self._ttl = int(config.get('ttl') or 30) + + self_addr = config.get('self_addr') + partner_addrs = config.get('partner_addrs', []) + if self._ctl: + if self_addr: + partner_addrs.append(self_addr) + self_addr = None + + # Create raft data_dir if necessary + raft_data_dir = config.get('data_dir', '') + if raft_data_dir != '': + validate_directory(raft_data_dir) + + ready_event = threading.Event() + file_template = os.path.join(config.get('data_dir', ''), (self_addr or '')) + conf = SyncObjConf(password=config.get('password'), appendEntriesUseBatch=False, + bindAddress=config.get('bind_addr'), commandsWaitLeader=False, + fullDumpFile=(file_template + '.dump' if self_addr else None), + journalFile=(file_template + '.journal' if self_addr else None), + onReady=ready_event.set, dynamicMembershipChange=True) + + self._sync_obj = KVStoreTTL(self_addr, partner_addrs, conf, self._on_set, self._on_delete) + while True: + ready_event.wait(5) + if ready_event.isSet() or self._sync_obj.applied_local_log: + break + else: + logger.info('waiting on raft') + self._sync_obj.forceLogCompaction() + self.set_retry_timeout(int(config.get('retry_timeout') or 10)) + + def _on_set(self, key, value): + leader = (self._sync_obj.get(self.leader_path) or {}).get('value') + if key == value['created'] == value['updated'] and \ + (key.startswith(self.members_path) or key == self.leader_path and leader != self._name) or \ + key == self.leader_optime_path and leader != self._name or key in (self.config_path, self.sync_path): + self.event.set() + + def _on_delete(self, key): + if key == self.leader_path: + self.event.set() + + def set_ttl(self, ttl): + self._ttl = ttl + + @property + def ttl(self): + return self._ttl + + def set_retry_timeout(self, retry_timeout): + self._sync_obj.set_retry_timeout(retry_timeout) + + @staticmethod + def member(key, value): + return Member.from_node(value['index'], os.path.basename(key), None, value['value']) + + def _load_cluster(self): + prefix = self.client_path('') + response = self._sync_obj.get(prefix, recursive=True) + if not response: + return Cluster(None, None, None, None, [], None, None, None) + nodes = {os.path.relpath(key, prefix).replace('\\', '/'): value for key, value in 
response.items()} + + # get initialize flag + initialize = nodes.get(self._INITIALIZE) + initialize = initialize and initialize['value'] + + # get global dynamic configuration + config = nodes.get(self._CONFIG) + config = config and ClusterConfig.from_node(config['index'], config['value']) + + # get timeline history + history = nodes.get(self._HISTORY) + history = history and TimelineHistory.from_node(history['index'], history['value']) + + # get last leader operation + last_leader_operation = nodes.get(self._LEADER_OPTIME) + last_leader_operation = 0 if last_leader_operation is None else int(last_leader_operation['value']) + + # get list of members + members = [self.member(k, n) for k, n in nodes.items() if k.startswith(self._MEMBERS) and k.count('/') == 1] + + # get leader + leader = nodes.get(self._LEADER) + if leader: + member = Member(-1, leader['value'], None, {}) + member = ([m for m in members if m.name == leader['value']] or [member])[0] + leader = Leader(leader['index'], None, member) + + # failover key + failover = nodes.get(self._FAILOVER) + if failover: + failover = Failover.from_node(failover['index'], failover['value']) + + # get synchronization state + sync = nodes.get(self._SYNC) + sync = SyncState.from_node(sync and sync['index'], sync and sync['value']) + + return Cluster(initialize, config, leader, last_leader_operation, members, failover, sync, history) + + def _write_leader_optime(self, last_operation): + return self._sync_obj.set(self.leader_optime_path, last_operation, timeout=1) + + def _update_leader(self): + ret = self._sync_obj.set(self.leader_path, self._name, ttl=self._ttl, prevValue=self._name) + if not ret and self._sync_obj.get(self.leader_path) is None: + ret = self.attempt_to_acquire_leader() + return ret + + def attempt_to_acquire_leader(self, permanent=False): + return self._sync_obj.set(self.leader_path, self._name, prevExist=False, + ttl=None if permanent else self._ttl) + + def set_failover_value(self, value, index=None): + return self._sync_obj.set(self.failover_path, value, prevIndex=index) + + def set_config_value(self, value, index=None): + return self._sync_obj.set(self.config_path, value, prevIndex=index) + + def touch_member(self, data, permanent=False): + data = json.dumps(data, separators=(',', ':')) + return self._sync_obj.set(self.member_path, data, None if permanent else self._ttl, timeout=2) + + def take_leader(self): + return self._sync_obj.set(self.leader_path, self._name, ttl=self._ttl) + + def initialize(self, create_new=True, sysid=''): + return self._sync_obj.set(self.initialize_path, sysid, prevExist=(not create_new)) + + def _delete_leader(self): + return self._sync_obj.delete(self.leader_path, prevValue=self._name, timeout=1) + + def cancel_initialization(self): + return self._sync_obj.delete(self.initialize_path) + + def delete_cluster(self): + return self._sync_obj.delete(self.client_path(''), recursive=True) + + def set_history_value(self, value): + return self._sync_obj.set(self.history_path, value) + + def set_sync_state_value(self, value, index=None): + return self._sync_obj.set(self.sync_path, value, prevIndex=index) + + def delete_sync_state(self, index=None): + return self._sync_obj.delete(self.sync_path, prevIndex=index) + + def watch(self, leader_index, timeout): + try: + return super(Raft, self).watch(leader_index, timeout) + finally: + self.event.clear() diff --git a/patroni-for-openGauss/dcs/zookeeper.py b/patroni-for-openGauss/dcs/zookeeper.py new file mode 100644 index 
0000000000000000000000000000000000000000..a7732cd4f26cab20a64996e3ccb509a905efe272 --- /dev/null +++ b/patroni-for-openGauss/dcs/zookeeper.py @@ -0,0 +1,379 @@ +import json +import logging +import select +import time + +from kazoo.client import KazooClient, KazooState, KazooRetry +from kazoo.exceptions import NoNodeError, NodeExistsError +from kazoo.handlers.threading import SequentialThreadingHandler +from patroni.dcs import AbstractDCS, ClusterConfig, Cluster, Failover, Leader, Member, SyncState, TimelineHistory +from patroni.exceptions import DCSError +from patroni.utils import deep_compare + +logger = logging.getLogger(__name__) + + +class ZooKeeperError(DCSError): + pass + + +class PatroniSequentialThreadingHandler(SequentialThreadingHandler): + + def __init__(self, connect_timeout): + super(PatroniSequentialThreadingHandler, self).__init__() + self.set_connect_timeout(connect_timeout) + + def set_connect_timeout(self, connect_timeout): + self._connect_timeout = max(1.0, connect_timeout/2.0) # try to connect to zookeeper node during loop_wait/2 + + def create_connection(self, *args, **kwargs): + """This method is trying to establish connection with one of the zookeeper nodes. + Somehow strategy "fail earlier and retry more often" works way better comparing to + the original strategy "try to connect with specified timeout". + Since we want to try connect to zookeeper more often (with the smaller connect_timeout), + he have to override `create_connection` method in the `SequentialThreadingHandler` + class (which is used by `kazoo.Client`). + + :param args: always contains `tuple(host, port)` as the first element and could contain + `connect_timeout` (negotiated session timeout) as the second element.""" + + args = list(args) + if len(args) == 0: # kazoo 2.6.0 slightly changed the way how it calls create_connection method + kwargs['timeout'] = max(self._connect_timeout, kwargs.get('timeout', self._connect_timeout*10)/10.0) + elif len(args) == 1: + args.append(self._connect_timeout) + else: + args[1] = max(self._connect_timeout, args[1]/10.0) + return super(PatroniSequentialThreadingHandler, self).create_connection(*args, **kwargs) + + def select(self, *args, **kwargs): + """Python3 raises `ValueError` if socket is closed, because fd == -1""" + try: + return super(PatroniSequentialThreadingHandler, self).select(*args, **kwargs) + except ValueError as e: + raise select.error(9, str(e)) + + +class ZooKeeper(AbstractDCS): + + def __init__(self, config): + super(ZooKeeper, self).__init__(config) + + hosts = config.get('hosts', []) + if isinstance(hosts, list): + hosts = ','.join(hosts) + + mapping = {'use_ssl': 'use_ssl', 'verify': 'verify_certs', 'cacert': 'ca', + 'cert': 'certfile', 'key': 'keyfile', 'key_password': 'keyfile_password'} + kwargs = {v: config[k] for k, v in mapping.items() if k in config} + + self._client = KazooClient(hosts, handler=PatroniSequentialThreadingHandler(config['retry_timeout']), + timeout=config['ttl'], connection_retry=KazooRetry(max_delay=1, max_tries=-1, + sleep_func=time.sleep), command_retry=KazooRetry(deadline=config['retry_timeout'], + max_delay=1, max_tries=-1, sleep_func=time.sleep), **kwargs) + self._client.add_listener(self.session_listener) + + self._fetch_cluster = True + self._fetch_optime = True + + self._orig_kazoo_connect = self._client._connection._connect + self._client._connection._connect = self._kazoo_connect + + self._client.start() + + def _kazoo_connect(self, *args): + """Kazoo is using Ping's to determine health of connection to 
zookeeper. If there is no + response on Ping after Ping interval (1/2 from read_timeout) it will consider current + connection dead and try to connect to another node. Without this "magic" it was taking + up to 2/3 from session timeout (ttl) to figure out that connection was dead and we had + only small time for reconnect and retry. + + This method is needed to return different value of read_timeout, which is not calculated + from negotiated session timeout but from value of `loop_wait`. And it is 2 sec smaller + than loop_wait, because we can spend up to 2 seconds when calling `touch_member()` and + `write_leader_optime()` methods, which also may hang...""" + + ret = self._orig_kazoo_connect(*args) + return max(self.loop_wait - 2, 2)*1000, ret[1] + + def session_listener(self, state): + if state in [KazooState.SUSPENDED, KazooState.LOST]: + self.cluster_watcher(None) + + def optime_watcher(self, event): + self._fetch_optime = True + self.event.set() + + def cluster_watcher(self, event): + self._fetch_cluster = True + self.optime_watcher(event) + + def reload_config(self, config): + self.set_retry_timeout(config['retry_timeout']) + + loop_wait = config['loop_wait'] + + loop_wait_changed = self._loop_wait != loop_wait + self._loop_wait = loop_wait + self._client.handler.set_connect_timeout(loop_wait) + + # We need to reestablish connection to zookeeper if we want to change + # read_timeout (and Ping interval respectively), because read_timeout + # is calculated in `_kazoo_connect` method. If we are changing ttl at + # the same time, set_ttl method will reestablish connection and return + # `!True`, otherwise we will close existing connection and let kazoo + # open the new one. + if not self.set_ttl(config['ttl']) and loop_wait_changed: + self._client._connection._socket.close() + + def set_ttl(self, ttl): + """It is not possible to change ttl (session_timeout) in zookeeper without + destroying old session and creating the new one. 
This method returns `!True` + if session_timeout has been changed (`restart()` has been called).""" + ttl = int(ttl * 1000) + if self._client._session_timeout != ttl: + self._client._session_timeout = ttl + self._client.restart() + return True + + @property + def ttl(self): + return self._client._session_timeout / 1000.0 + + def set_retry_timeout(self, retry_timeout): + retry = self._client.retry if isinstance(self._client.retry, KazooRetry) else self._client._retry + retry.deadline = retry_timeout + + def get_node(self, key, watch=None): + try: + ret = self._client.get(key, watch) + return (ret[0].decode('utf-8'), ret[1]) + except NoNodeError: + return None + + def get_leader_optime(self, leader): + watch = self.optime_watcher if not leader or leader.name != self._name else None + optime = self.get_node(self.leader_optime_path, watch) + self._fetch_optime = False + return optime and int(optime[0]) or 0 + + @staticmethod + def member(name, value, znode): + return Member.from_node(znode.version, name, znode.ephemeralOwner, value) + + def get_children(self, key, watch=None): + try: + return self._client.get_children(key, watch) + except NoNodeError: + return [] + + def load_members(self, sync_standby): + members = [] + for member in self.get_children(self.members_path, self.cluster_watcher): + watch = member in sync_standby and self.cluster_watcher or None + data = self.get_node(self.members_path + member, watch) + if data is not None: + members.append(self.member(member, *data)) + return members + + def _inner_load_cluster(self): + self._fetch_cluster = False + self.event.clear() + nodes = set(self.get_children(self.client_path(''), self.cluster_watcher)) + if not nodes: + self._fetch_cluster = True + + # get initialize flag + initialize = (self.get_node(self.initialize_path) or [None])[0] if self._INITIALIZE in nodes else None + + # get global dynamic configuration + config = self.get_node(self.config_path, watch=self.cluster_watcher) if self._CONFIG in nodes else None + config = config and ClusterConfig.from_node(config[1].version, config[0], config[1].mzxid) + + # get timeline history + history = self.get_node(self.history_path, watch=self.cluster_watcher) if self._HISTORY in nodes else None + history = history and TimelineHistory.from_node(history[1].mzxid, history[0]) + + # get synchronization state + sync = self.get_node(self.sync_path, watch=self.cluster_watcher) if self._SYNC in nodes else None + sync = SyncState.from_node(sync and sync[1].version, sync and sync[0]) + + # get list of members + sync_standby = sync.leader == self._name and sync.members or [] + members = self.load_members(sync_standby) if self._MEMBERS[:-1] in nodes else [] + + # get leader + leader = self.get_node(self.leader_path) if self._LEADER in nodes else None + if leader: + client_id = self._client.client_id + if not self._ctl and leader[0] == self._name and client_id is not None \ + and client_id[0] != leader[1].ephemeralOwner: + logger.info('I am leader but not owner of the session. 
Removing leader node') + self._client.delete(self.leader_path) + leader = None + + if leader: + member = Member(-1, leader[0], None, {}) + member = ([m for m in members if m.name == leader[0]] or [member])[0] + leader = Leader(leader[1].version, leader[1].ephemeralOwner, member) + self._fetch_cluster = member.index == -1 + + # get last leader operation + last_leader_operation = self._OPTIME in nodes and self.get_leader_optime(leader) + + # failover key + failover = self.get_node(self.failover_path, watch=self.cluster_watcher) if self._FAILOVER in nodes else None + failover = failover and Failover.from_node(failover[1].version, failover[0]) + + return Cluster(initialize, config, leader, last_leader_operation, members, failover, sync, history) + + def _load_cluster(self): + cluster = self.cluster + if self._fetch_cluster or cluster is None: + try: + cluster = self._client.retry(self._inner_load_cluster) + except Exception: + logger.exception('get_cluster') + self.cluster_watcher(None) + raise ZooKeeperError('ZooKeeper in not responding properly') + # Optime ZNode was updated or doesn't exist and we are not leader + elif (self._fetch_optime and not self._fetch_cluster or not cluster.last_leader_operation) and\ + not (cluster.leader and cluster.leader.name == self._name): + try: + optime = self.get_leader_optime(cluster.leader) + cluster = Cluster(cluster.initialize, cluster.config, cluster.leader, optime, + cluster.members, cluster.failover, cluster.sync, cluster.history) + except Exception: + pass + return cluster + + def _bypass_caches(self): + self._fetch_cluster = True + + def _create(self, path, value, retry=False, ephemeral=False): + try: + if retry: + self._client.retry(self._client.create, path, value, makepath=True, ephemeral=ephemeral) + else: + self._client.create_async(path, value, makepath=True, ephemeral=ephemeral).get(timeout=1) + return True + except Exception: + logger.exception('Failed to create %s', path) + return False + + def attempt_to_acquire_leader(self, permanent=False): + ret = self._create(self.leader_path, self._name.encode('utf-8'), retry=True, ephemeral=not permanent) + if not ret: + logger.info('Could not take out TTL lock') + return ret + + def _set_or_create(self, key, value, index=None, retry=False, do_not_create_empty=False): + value = value.encode('utf-8') + try: + if retry: + self._client.retry(self._client.set, key, value, version=index or -1) + else: + self._client.set_async(key, value, version=index or -1).get(timeout=1) + return True + except NoNodeError: + if do_not_create_empty and not value: + return True + elif index is None: + return self._create(key, value, retry) + else: + return False + except Exception: + logger.exception('Failed to update %s', key) + return False + + def set_failover_value(self, value, index=None): + return self._set_or_create(self.failover_path, value, index) + + def set_config_value(self, value, index=None): + return self._set_or_create(self.config_path, value, index, retry=True) + + def initialize(self, create_new=True, sysid=""): + sysid = sysid.encode('utf-8') + return self._create(self.initialize_path, sysid, retry=True) if create_new \ + else self._client.retry(self._client.set, self.initialize_path, sysid) + + def touch_member(self, data, permanent=False): + cluster = self.cluster + member = cluster and cluster.get_member(self._name, fallback_to_leader=False) + encoded_data = json.dumps(data, separators=(',', ':')).encode('utf-8') + if member and (self._client.client_id is not None and member.session != 
self._client.client_id[0] or + not (deep_compare(member.data.get('tags', {}), data.get('tags', {})) and + member.data.get('version') == data.get('version') and + member.data.get('checkpoint_after_promote') == data.get('checkpoint_after_promote'))): + try: + self._client.delete_async(self.member_path).get(timeout=1) + except NoNodeError: + pass + except Exception: + return False + member = None + + if member: + if deep_compare(data, member.data): + return True + else: + try: + self._client.create_async(self.member_path, encoded_data, makepath=True, + ephemeral=not permanent).get(timeout=1) + return True + except Exception as e: + if not isinstance(e, NodeExistsError): + logger.exception('touch_member') + return False + try: + self._client.set_async(self.member_path, encoded_data).get(timeout=1) + return True + except Exception: + logger.exception('touch_member') + + return False + + def take_leader(self): + return self.attempt_to_acquire_leader() + + def _write_leader_optime(self, last_operation): + return self._set_or_create(self.leader_optime_path, last_operation) + + def _update_leader(self): + return True + + def _delete_leader(self): + self._client.restart() + return True + + def _cancel_initialization(self): + node = self.get_node(self.initialize_path) + if node: + self._client.delete(self.initialize_path, version=node[1].version) + + def cancel_initialization(self): + try: + self._client.retry(self._cancel_initialization) + except Exception: + logger.exception("Unable to delete initialize key") + + def delete_cluster(self): + try: + return self._client.retry(self._client.delete, self.client_path(''), recursive=True) + except NoNodeError: + return True + + def set_history_value(self, value): + return self._set_or_create(self.history_path, value) + + def set_sync_state_value(self, value, index=None): + return self._set_or_create(self.sync_path, value, index, retry=True, do_not_create_empty=True) + + def delete_sync_state(self, index=None): + return self.set_sync_state_value("{}", index) + + def watch(self, leader_index, timeout): + ret = super(ZooKeeper, self).watch(leader_index, timeout) + if ret and not self._fetch_optime: + self._fetch_cluster = True + return ret or self._fetch_cluster diff --git a/patroni-for-openGauss/exceptions.py b/patroni-for-openGauss/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..21df1f7b5f3682672d75049fe79feb3940918233 --- /dev/null +++ b/patroni-for-openGauss/exceptions.py @@ -0,0 +1,37 @@ +class PatroniException(Exception): + + """Parent class for all kind of exceptions related to selected distributed configuration store""" + + def __init__(self, value): + self.value = value + + def __str__(self): + """ + >>> str(PatroniException('foo')) + "'foo'" + """ + return repr(self.value) + + +class PatroniFatalException(PatroniException): + pass + + +class PostgresException(PatroniException): + pass + + +class DCSError(PatroniException): + pass + + +class PostgresConnectionException(PostgresException): + pass + + +class WatchdogError(PatroniException): + pass + + +class ConfigParseError(PatroniException): + pass diff --git a/patroni-for-openGauss/ha.py b/patroni-for-openGauss/ha.py new file mode 100644 index 0000000000000000000000000000000000000000..dd0c0b9582e007c93bb176b8f5d7af811edd1fdb --- /dev/null +++ b/patroni-for-openGauss/ha.py @@ -0,0 +1,1560 @@ +import datetime +import functools +import json +import logging +import psycopg2 +import sys +import time +import uuid + +from collections import namedtuple +from 
multiprocessing.pool import ThreadPool +from patroni.async_executor import AsyncExecutor, CriticalTask +from patroni.exceptions import DCSError, PostgresConnectionException, PatroniFatalException +from patroni.postgresql import ACTION_ON_START, ACTION_ON_ROLE_CHANGE +from patroni.postgresql.misc import postgres_version_to_int +from patroni.postgresql.rewind import Rewind +from patroni.utils import polling_loop, tzutc, is_standby_cluster as _is_standby_cluster, parse_int +from patroni.dcs import RemoteMember +from threading import RLock + +logger = logging.getLogger(__name__) + + +class _MemberStatus(namedtuple('_MemberStatus', ['member', 'reachable', 'in_recovery', 'timeline', + 'wal_position', 'tags', 'watchdog_failed'])): + """Node status distilled from API response: + + member - dcs.Member object of the node + reachable - `!False` if the node is not reachable or is not responding with correct JSON + in_recovery - `!True` if pg_is_in_recovery() == true + timeline - timeline value from JSON + wal_position - maximum value of `replayed_location` or `received_location` from JSON + tags - dictionary with values of different tags (i.e. nofailover) + watchdog_failed - indicates that watchdog is required by configuration but not available or failed + """ + @classmethod + def from_api_response(cls, member, json): + is_master = json['role'] == 'master' + timeline = json.get('timeline', 0) + wal = not is_master and max(json['xlog'].get('received_location', 0), json['xlog'].get('replayed_location', 0)) + return cls(member, True, not is_master, timeline, wal, json.get('tags', {}), json.get('watchdog_failed', False)) + + @classmethod + def unknown(cls, member): + return cls(member, False, None, 0, 0, {}, False) + + def failover_limitation(self): + """Returns reason why this node can't promote or None if everything is ok.""" + if not self.reachable: + return 'not reachable' + if self.tags.get('nofailover', False): + return 'not allowed to promote' + if self.watchdog_failed: + return 'not watchdog capable' + return None + + +class Ha(object): + + def __init__(self, patroni): + self.patroni = patroni + self.state_handler = patroni.postgresql + self._rewind = Rewind(self.state_handler) + self.dcs = patroni.dcs + self.cluster = None + self.old_cluster = None + self._is_leader = False + self._is_leader_lock = RLock() + self._leader_access_is_restricted = False + self._was_paused = False + self._leader_timeline = None + self.recovering = False + self._async_response = CriticalTask() + self._crash_recovery_executed = False + self._crash_recovery_started = None + self._start_timeout = None + self._async_executor = AsyncExecutor(self.state_handler.cancellable, self.wakeup) + self.watchdog = patroni.watchdog + + # Each member publishes various pieces of information to the DCS using touch_member. This lock protects + # the state and publishing procedure to have consistent ordering and avoid publishing stale values. + self._member_state_lock = RLock() + # Count of concurrent sync disabling requests. Value above zero means that we don't want to be synchronous + # standby. Changes protected by _member_state_lock. + self._disable_sync = 0 + + # We need following property to avoid shutdown of postgres when join of Patroni to the postgres + # already running as replica was aborted due to cluster not beeing initialized in DCS. 
+ self._join_aborted = False + + # used only in backoff after failing a pre_promote script + self._released_leader_key_timestamp = 0 + + def check_mode(self, mode): + # Try to protect from the case when DCS was wiped out during pause + if self.cluster and self.cluster.config and self.cluster.config.modify_index: + return self.cluster.check_mode(mode) + else: + return self.patroni.config.check_mode(mode) + + def master_stop_timeout(self): + """ Master stop timeout """ + ret = parse_int(self.patroni.config['master_stop_timeout']) + return ret if ret and ret > 0 and self.is_synchronous_mode() else None + + def is_paused(self): + return self.check_mode('pause') + + def check_timeline(self): + return self.check_mode('check_timeline') + + def get_standby_cluster_config(self): + if self.cluster and self.cluster.config and self.cluster.config.modify_index: + config = self.cluster.config.data + else: + config = self.patroni.config.dynamic_configuration + return config.get('standby_cluster') + + def is_standby_cluster(self): + return _is_standby_cluster(self.get_standby_cluster_config()) + + def is_leader(self): + with self._is_leader_lock: + return self._is_leader > time.time() and not self._leader_access_is_restricted + + def set_is_leader(self, value): + with self._is_leader_lock: + self._is_leader = time.time() + self.dcs.ttl if value else 0 + + def set_leader_access_is_restricted(self, value): + with self._is_leader_lock: + self._leader_access_is_restricted = value + + def load_cluster_from_dcs(self): + cluster = self.dcs.get_cluster() + + # We want to keep the state of cluster when it was healthy + if not cluster.is_unlocked() or not self.old_cluster: + self.old_cluster = cluster + self.cluster = cluster + + if not self.has_lock(False): + self.set_is_leader(False) + + self._leader_timeline = None if cluster.is_unlocked() else cluster.leader.timeline + + def acquire_lock(self): + self.set_leader_access_is_restricted(self.cluster.has_permanent_logical_slots(self.state_handler.name)) + ret = self.dcs.attempt_to_acquire_leader() + self.set_is_leader(ret) + return ret + + def update_lock(self, write_leader_optime=False): + last_operation = None + if write_leader_optime: + try: + last_operation = self.state_handler.last_operation() + except Exception: + logger.exception('Exception when called state_handler.last_operation()') + try: + ret = self.dcs.update_leader(last_operation, self._leader_access_is_restricted) + except Exception: + logger.exception('Unexpected exception raised from update_leader, please report it as a BUG') + ret = False + self.set_is_leader(ret) + if ret: + self.watchdog.keepalive() + return ret + + def has_lock(self, info=True): + lock_owner = self.cluster.leader and self.cluster.leader.name + if info: + logger.info('Lock owner: %s; I am %s', lock_owner, self.state_handler.name) + return lock_owner == self.state_handler.name + + def get_effective_tags(self): + """Return configuration tags merged with dynamically applied tags.""" + tags = self.patroni.tags.copy() + # _disable_sync could be modified concurrently, but we don't care as attribute get and set are atomic. 
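+        # Illustrative example (not in the original source): with configuration tags {'clonefrom': True}
+        # and _disable_sync > 0, this method returns {'clonefrom': True, 'nosync': True}, which
+        # touch_member() then publishes to the DCS.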
+ if self._disable_sync > 0: + tags['nosync'] = True + return tags + + def touch_member(self): + with self._member_state_lock: + data = { + 'conn_url': self.state_handler.connection_string, + 'api_url': self.patroni.api.connection_string, + 'state': self.state_handler.state, + 'role': self.state_handler.role, + 'version': self.patroni.version + } + + # following two lines are mainly necessary for consul, to avoid creation of master service + if data['role'] == 'master' and not self.is_leader(): + data['role'] = 'promoted' + if self.is_leader() and not self._rewind.checkpoint_after_promote(): + data['checkpoint_after_promote'] = False + tags = self.get_effective_tags() + if tags: + data['tags'] = tags + if self.state_handler.pending_restart: + data['pending_restart'] = True + if self._async_executor.scheduled_action in (None, 'promote') \ + and data['state'] in ['running', 'restarting', 'starting']: + try: + timeline, wal_position, pg_control_timeline = self.state_handler.timeline_wal_position() + data['xlog_location'] = wal_position + if not timeline: # try pg_stat_wal_receiver to get the timeline + timeline = self.state_handler.received_timeline() + if not timeline: + # So far the only way to get the current timeline on the standby is from + # the replication connection. In order to avoid opening the replication + # connection on every iteration of HA loop we will do it only when noticed + # that the timeline on the primary has changed. + # Unfortunately such optimization isn't possible on the standby_leader, + # therefore we will get the timeline from pg_control, either by calling + # pg_control_checkpoint() on 9.6+ or by parsing the output of pg_controldata. + if self.state_handler.role == 'standby_leader': + timeline = pg_control_timeline or self.state_handler.pg_control_timeline() + else: + timeline = self.state_handler.replica_cached_timeline(self._leader_timeline) + if timeline: + data['timeline'] = timeline + except Exception: + pass + if self.patroni.scheduled_restart: + scheduled_restart_data = self.patroni.scheduled_restart.copy() + scheduled_restart_data['schedule'] = scheduled_restart_data['schedule'].isoformat() + data['scheduled_restart'] = scheduled_restart_data + + if self.is_paused(): + data['pause'] = True + + return self.dcs.touch_member(data) + + def clone(self, clone_member=None, msg='(without leader)'): + if self.is_standby_cluster() and not isinstance(clone_member, RemoteMember): + clone_member = self.get_remote_member(clone_member) + + self._rewind.reset_state() + if self.state_handler.bootstrap.clone(clone_member): + logger.info('bootstrapped %s', msg) + cluster = self.dcs.get_cluster() + node_to_follow = self._get_node_to_follow(cluster) + return self.state_handler.follow(node_to_follow) + else: + logger.error('failed to bootstrap %s', msg) + self.state_handler.remove_data_directory() + + def bootstrap(self): + if not self.cluster.is_unlocked(): # cluster already has leader + clone_member = self.cluster.get_clone_member(self.state_handler.name) + member_role = 'leader' if clone_member == self.cluster.leader else 'replica' + msg = "from {0} '{1}'".format(member_role, clone_member.name) + ret = self._async_executor.try_run_async('bootstrap {0}'.format(msg), self.clone, args=(clone_member, msg)) + return ret or 'trying to bootstrap {0}'.format(msg) + + # no initialize key and node is allowed to be master and has 'bootstrap' section in a configuration file + elif self.cluster.initialize is None and not self.patroni.nofailover and 'bootstrap' in self.patroni.config: + 
if self.dcs.initialize(create_new=True): # race for initialization + self.state_handler.bootstrapping = True + with self._async_response: + self._async_response.reset() + + if self.is_standby_cluster(): + ret = self._async_executor.try_run_async('bootstrap_standby_leader', self.bootstrap_standby_leader) + return ret or 'trying to bootstrap a new standby leader' + else: + ret = self._async_executor.try_run_async('bootstrap', self.state_handler.bootstrap.bootstrap, + args=(self.patroni.config['bootstrap'],)) + return ret or 'trying to bootstrap a new cluster' + else: + return 'failed to acquire initialize lock' + else: + create_replica_methods = self.get_standby_cluster_config().get('create_replica_methods', []) \ + if self.is_standby_cluster() else None + if self.state_handler.can_create_replica_without_replication_connection(create_replica_methods): + msg = 'bootstrap (without leader)' + return self._async_executor.try_run_async(msg, self.clone) or 'trying to ' + msg + return 'waiting for {0}leader to bootstrap'.format('standby_' if self.is_standby_cluster() else '') + + def bootstrap_standby_leader(self): + """ If we found 'standby' key in the configuration, we need to bootstrap + not a real master, but a 'standby leader', that will take base backup + from a remote master and start follow it. + """ + clone_source = self.get_remote_master() + msg = 'clone from remote master {0}'.format(clone_source.conn_url) + result = self.clone(clone_source, msg) + with self._async_response: # pretend that post_bootstrap was already executed + self._async_response.complete(result) + if result: + self.state_handler.set_role('standby_leader') + + return result + + def _handle_rewind_or_reinitialize(self): + leader = self.get_remote_master() if self.is_standby_cluster() else self.cluster.leader + if not self._rewind.rewind_or_reinitialize_needed_and_possible(leader): + return None + + if self._rewind.can_rewind: + msg = 'running pg_rewind from ' + leader.name + return self._async_executor.try_run_async(msg, self._rewind.execute, args=(leader,)) or msg + + # remove_data_directory_on_diverged_timelines is set + if not self.is_standby_cluster(): + msg = 'reinitializing due to diverged timelines' + return self._async_executor.try_run_async(msg, self._do_reinitialize, args=(self.cluster,)) or msg + + def recover(self): + # Postgres is not running and we will restart in standby mode. Watchdog is not needed until we promote. + self.watchdog.disable() + + if self.has_lock() and self.update_lock(): + timeout = self.patroni.config['master_start_timeout'] + if timeout == 0: + # We are requested to prefer failing over to restarting master. But see first if there + # is anyone to fail over to. + members = self.cluster.members + if self.is_synchronous_mode(): + members = [m for m in members if self.cluster.sync.matches(m.name)] + if self.is_failover_possible(members): + logger.info("Master crashed. 
Failing over.") + self.demote('immediate') + return 'stopped PostgreSQL to fail over after a crash' + else: + timeout = None + + data = self.state_handler.controldata() + logger.info('pg_controldata:\n%s\n', '\n'.join(' {0}: {1}'.format(k, v) for k, v in data.items())) + if data.get('Database cluster state') in ('in production', 'shutting down', 'in crash recovery') \ + and not self._crash_recovery_executed and \ + (self.cluster.is_unlocked() or self._rewind.can_rewind): + self._crash_recovery_executed = True + self._crash_recovery_started = time.time() + msg = 'doing crash recovery in a single user mode' + return self._async_executor.try_run_async(msg, self._rewind.ensure_clean_shutdown) or msg + + self.load_cluster_from_dcs() + + role = 'replica' + if self.is_standby_cluster() or not self.has_lock(): + if not self._rewind.executed: + self._rewind.trigger_check_diverged_lsn() + msg = self._handle_rewind_or_reinitialize() + if msg: + return msg + + if self.has_lock(): # in standby cluster + msg = "starting as a standby leader because i had the session lock" + role = 'standby_leader' + node_to_follow = self._get_node_to_follow(self.cluster) + elif self.is_standby_cluster() and self.cluster.is_unlocked(): + msg = "trying to follow a remote master because standby cluster is unhealthy" + node_to_follow = self.get_remote_master() + else: + msg = "starting as a secondary" + node_to_follow = self._get_node_to_follow(self.cluster) + elif self.has_lock(): + msg = "starting as readonly because i had the session lock" + node_to_follow = None + lock_owner = self.cluster.leader and self.cluster.leader.name + is_standby = False + if lock_owner != None and lock_owner != self.state_handler.name: + logger.info("starting as standby") + is_standby = True + if self._async_executor.try_run_async('restarting after failure', self.state_handler.follow, + args=(node_to_follow, role, timeout, False, is_standby)) is None: + self.recovering = True + return msg + + def _get_node_to_follow(self, cluster): + # determine the node to follow. If replicatefrom tag is set, + # try to follow the node mentioned there, otherwise, follow the leader. + standby_config = self.get_standby_cluster_config() + is_standby_cluster = _is_standby_cluster(standby_config) + if is_standby_cluster and (self.cluster.is_unlocked() or self.has_lock(False)): + node_to_follow = self.get_remote_master() + elif self.patroni.replicatefrom and self.patroni.replicatefrom != self.state_handler.name: + node_to_follow = cluster.get_member(self.patroni.replicatefrom) + else: + node_to_follow = cluster.leader if cluster.leader and cluster.leader.name else None + + node_to_follow = node_to_follow if node_to_follow and node_to_follow.name != self.state_handler.name else None + + if node_to_follow and not isinstance(node_to_follow, RemoteMember): + # we are going to abuse Member.data to pass following parameters + params = ('restore_command', 'archive_cleanup_command') + for param in params: # It is highly unlikely to happen, but we want to protect from the case + node_to_follow.data.pop(param, None) # when above-mentioned params came from outside. 
+ if is_standby_cluster: + node_to_follow.data.update({p: standby_config[p] for p in params if standby_config.get(p)}) + + return node_to_follow + + def follow(self, demote_reason, follow_reason, refresh=True): + if refresh: + self.load_cluster_from_dcs() + + is_leader = self.state_handler.is_leader() + + node_to_follow = self._get_node_to_follow(self.cluster) + + if self.is_paused(): + if not (self._rewind.is_needed and self._rewind.can_rewind_or_reinitialize_allowed)\ + or self.cluster.is_unlocked(): + if is_leader: + self.state_handler.set_role('master') + return 'continue to run as master without lock' + elif self.state_handler.role != 'standby_leader': + self.state_handler.set_role('replica') + + if not node_to_follow: + return 'no action' + elif is_leader: + self.demote('immediate-nolock') + return demote_reason + + if self.is_standby_cluster() and self._leader_timeline and \ + self.state_handler.get_history(self._leader_timeline + 1): + self._rewind.trigger_check_diverged_lsn() + + msg = self._handle_rewind_or_reinitialize() + if msg: + return msg + + role = 'standby_leader' if isinstance(node_to_follow, RemoteMember) and self.has_lock(False) else 'replica' + # It might happen that leader key in the standby cluster references non-exiting member. + # In this case it is safe to continue running without changing recovery.conf + if self.is_standby_cluster() and role == 'replica' and not (node_to_follow and node_to_follow.conn_url): + return 'continue following the old known standby leader' + else: + action = self.state_handler.config.check_db_state() + if action == 'restart': + if role == 'replica': + logger.info('the replica is abnormal and restarting') + self._async_executor.try_run_async('the replica is abnormal and restarting', + self.state_handler.follow, args=(node_to_follow, role, None, False, True)) + else: + logger.info('the database is abnormal and restarting') + self._async_executor.try_run_async('the database is abnormal and restarting', + self.state_handler.follow, args=(node_to_follow, role)) + self._rewind.trigger_check_diverged_lsn() + elif action == 'rebuild': + logger.info('the standby is abnormal and rebuilding') + self.state_handler.rebuild() + time.sleep(5) + status, output = self.state_handler.gs_query() + logger.info(output) + elif role == 'standby_leader' and self.state_handler.role != role: + self.state_handler.set_role(role) + self.state_handler.call_nowait(ACTION_ON_ROLE_CHANGE) + + return follow_reason + + def is_synchronous_mode(self): + return self.check_mode('synchronous_mode') + + def is_synchronous_mode_strict(self): + return self.check_mode('synchronous_mode_strict') + + def process_sync_replication(self): + """Process synchronous standby beahvior. + + Synchronous standbys are registered in two places postgresql.conf and DCS. The order of updating them must + be right. The invariant that should be kept is that if a node is master and sync_standby is set in DCS, + then that node must have synchronous_standby set to that value. Or more simple, first set in postgresql.conf + and then in DCS. When removing, first remove in DCS, then in postgresql.conf. This is so we only consider + promoting standbys that were guaranteed to be replicating synchronously. 
+ """ + if self.is_synchronous_mode(): + sync_node_count = self.patroni.config['synchronous_node_count'] + current = self.cluster.sync.leader and self.cluster.sync.members or [] + picked, allow_promote = self.state_handler.pick_synchronous_standby(self.cluster, sync_node_count, + self.patroni.config[ + 'maximum_lag_on_syncnode']) + if set(picked) != set(current): + # update synchronous standby list in dcs temporarily to point to common nodes in current and picked + sync_common = list(set(current).intersection(set(allow_promote))) + if set(sync_common) != set(current): + logger.info("Updating synchronous privilege temporarily from %s to %s", current, sync_common) + if not self.dcs.write_sync_state(self.state_handler.name, + sync_common or None, + index=self.cluster.sync.index): + logger.info('Synchronous replication key updated by someone else.') + return + + # Update db param and wait for x secs + if self.is_synchronous_mode_strict() and not picked: + picked = ['*'] + logger.warning("No standbys available!") + + logger.info("Assigning synchronous standby status to %s", picked) + self.state_handler.config.set_synchronous_standby(picked) + + if picked and picked[0] != '*' and set(allow_promote) != set(picked) and not allow_promote: + # Wait for PostgreSQL to enable synchronous mode and see if we can immediately set sync_standby + time.sleep(2) + _, allow_promote = self.state_handler.pick_synchronous_standby(self.cluster, + sync_node_count, + self.patroni.config[ + 'maximum_lag_on_syncnode']) + if allow_promote and set(allow_promote) != set(sync_common): + try: + cluster = self.dcs.get_cluster() + except DCSError: + return logger.warning("Could not get cluster state from DCS during process_sync_replication()") + if cluster.sync.leader and cluster.sync.leader != self.state_handler.name: + logger.info("Synchronous replication key updated by someone else") + return + if not self.dcs.write_sync_state(self.state_handler.name, allow_promote, index=cluster.sync.index): + logger.info("Synchronous replication key updated by someone else") + return + logger.info("Synchronous standby status assigned to %s", allow_promote) + else: + if self.cluster.sync.leader and self.dcs.delete_sync_state(index=self.cluster.sync.index): + logger.info("Disabled synchronous replication") + self.state_handler.config.set_synchronous_standby([]) + + def is_sync_standby(self, cluster): + return cluster.leader and cluster.sync.leader == cluster.leader.name \ + and self.state_handler.name in cluster.sync.members + + def while_not_sync_standby(self, func): + """Runs specified action while trying to make sure that the node is not assigned synchronous standby status. + + Tags us as not allowed to be a sync standby as we are going to go away, if we currently are wait for + leader to notice and pick an alternative one or if the leader changes or goes away we are also free. + + If the connection to DCS fails we run the action anyway, as this is only a hint. + + There is a small race window where this function runs between a master picking us the sync standby and + publishing it to the DCS. As the window is rather tiny consequences are holding up commits for one cycle + period we don't worry about it here.""" + + if not self.is_synchronous_mode() or self.patroni.nosync: + return func() + + with self._member_state_lock: + self._disable_sync += 1 + try: + if self.touch_member(): + # Master should notice the updated value during the next cycle. 
We will wait double that, if master + # hasn't noticed the value by then not disabling sync replication is not likely to matter. + for _ in polling_loop(timeout=self.dcs.loop_wait*2, interval=2): + try: + if not self.is_sync_standby(self.dcs.get_cluster()): + break + except DCSError: + logger.warning("Could not get cluster state, skipping synchronous standby disable") + break + logger.info("Waiting for master to release us from synchronous standby") + else: + logger.warning("Updating member state failed, skipping synchronous standby disable") + + return func() + finally: + with self._member_state_lock: + self._disable_sync -= 1 + + def update_cluster_history(self): + master_timeline = self.state_handler.get_master_timeline() + cluster_history = self.cluster.history and self.cluster.history.lines + if master_timeline == 1: + if cluster_history: + self.dcs.set_history_value('[]') + elif not cluster_history or cluster_history[-1][0] != master_timeline - 1 or len(cluster_history[-1]) != 4: + cluster_history = {line[0]: line for line in cluster_history or []} + history = self.state_handler.get_history(master_timeline) + if history and self.cluster.config: + history = history[-self.cluster.config.max_timelines_history:] + for line in history: + # enrich current history with promotion timestamps stored in DCS + if len(line) == 3 and line[0] in cluster_history \ + and len(cluster_history[line[0]]) == 4 \ + and cluster_history[line[0]][1] == line[1]: + line.append(cluster_history[line[0]][3]) + self.dcs.set_history_value(json.dumps(history, separators=(',', ':'))) + + def enforce_follow_remote_master(self, message): + demote_reason = 'cannot be a real master in standby cluster' + return self.follow(demote_reason, message) + + def check_and_repair_leader(self): + """ + check the state of leader, if it is not normal, restart to repair it + """ + action = '' + retry_times = 3 + while action != 'normal' and retry_times > 0: + action = self.state_handler.config.check_db_state(is_leader=True) + retry_times -= 1 + time.sleep(1) + if action == 'restart': + logger.info('leader is abnormal and repairing') + self.state_handler.pg_ctl('stop') + self.state_handler.pg_ctl('start') + time.sleep(1) + status, output = self.state_handler.gs_query() + logger.info(output) + + def enforce_master_role(self, message, promote_message): + """ + Ensure the node that has won the race for the leader key meets criteria + for promoting its PG server to the 'master' role. + """ + if not self.is_paused(): + if not self.watchdog.is_running and not self.watchdog.activate(): + if self.state_handler.is_leader(): + self.demote('immediate') + return 'Demoting self because watchdog could not be activated' + else: + self.release_leader_key_voluntarily() + return 'Not promoting self because watchdog could not be activated' + + with self._async_response: + if self._async_response.result is False: + logger.warning("Releasing the leader key voluntarily because the pre-promote script failed") + self._released_leader_key_timestamp = time.time() + self.release_leader_key_voluntarily() + # discard the result of the failed pre-promote script to be able to re-try promote + self._async_response.reset() + return 'Promotion cancelled because the pre-promote script failed' + + if self.state_handler.is_leader(): + # Inform the state handler about its master role. + # It may be unaware of it if postgres is promoted manually. 
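+            # Clarifying note: check_and_repair_leader() (defined above) polls check_db_state(is_leader=True)
+            # up to three times and, if it reports 'restart', stops and starts the instance via pg_ctl before
+            # the master role is confirmed here.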
+ logger.info('enforce_master_role is_leader') + self.check_and_repair_leader() + self.state_handler.set_role('master') + self.process_sync_replication() + self.update_cluster_history() + return message + elif self.state_handler.role == 'master': + logger.info('enforce_master_role master') + self.check_and_repair_leader() + self.process_sync_replication() + return message + else: + if self.is_synchronous_mode(): + # Just set ourselves as the authoritative source of truth for now. We don't want to wait for standbys + # to connect. We will try finding a synchronous standby in the next cycle. + if not self.dcs.write_sync_state(self.state_handler.name, None, index=self.cluster.sync.index): + # Somebody else updated sync state, it may be due to us losing the lock. To be safe, postpone + # promotion until next cycle. TODO: trigger immediate retry of run_cycle + return 'Postponing promotion because synchronous replication state was updated by somebody else' + self.state_handler.config.set_synchronous_standby(['*'] if self.is_synchronous_mode_strict() else []) + if self.state_handler.role != 'master': + self.set_leader_access_is_restricted(self.cluster.has_permanent_logical_slots(self.state_handler.name)) + + def on_success(): + self._rewind.reset_state() + logger.info("cleared rewind state after becoming the leader") + + with self._async_response: + self._async_response.reset() + self._async_executor.try_run_async('promote', self.state_handler.promote, + args=(self.dcs.loop_wait, self._async_response, on_success, + self._leader_access_is_restricted)) + return promote_message + + def fetch_node_status(self, member): + """This function perform http get request on member.api_url and fetches its status + :returns: `_MemberStatus` object + """ + + try: + response = self.patroni.request(member, timeout=2, retries=0) + data = response.data.decode('utf-8') + logger.info('Got response from %s %s: %s', member.name, member.api_url, data) + return _MemberStatus.from_api_response(member, json.loads(data)) + except Exception as e: + logger.warning("Request failed to %s: GET %s (%s)", member.name, member.api_url, e) + return _MemberStatus.unknown(member) + + def fetch_nodes_statuses(self, members): + pool = ThreadPool(len(members)) + results = pool.map(self.fetch_node_status, members) # Run API calls on members in parallel + pool.close() + pool.join() + return results + + def is_lagging(self, wal_position): + """Returns if instance with an wal should consider itself unhealthy to be promoted due to replication lag. + + :param wal_position: Current wal position. 
+ :returns True when node is lagging + """ + lag = (self.cluster.last_leader_operation or 0) - wal_position + return lag > self.patroni.config.get('maximum_lag_on_failover', 0) + + def _is_healthiest_node(self, members, check_replication_lag=True): + """This method tries to determine whether I am healthy enough to became a new leader candidate or not.""" + + # We don't call `last_operation()` here because it returns a string + _, my_wal_position, _ = self.state_handler.timeline_wal_position() + if check_replication_lag and self.is_lagging(my_wal_position): + logger.info('My wal position exceeds maximum replication lag') + return False # Too far behind last reported wal position on master + + if not self.is_standby_cluster() and self.check_timeline(): + cluster_timeline = self.cluster.timeline + my_timeline = self.state_handler.replica_cached_timeline(cluster_timeline) + if my_timeline < cluster_timeline: + logger.info('My timeline %s is behind last known cluster timeline %s', my_timeline, cluster_timeline) + return False + + # Prepare list of nodes to run check against + members = [m for m in members if m.name != self.state_handler.name and not m.nofailover and m.api_url] + + if members: + for st in self.fetch_nodes_statuses(members): + if st.failover_limitation() is None: + if not st.in_recovery: + logger.warning('Master (%s) is still alive', st.member.name) + return False + if my_wal_position < st.wal_position: + logger.info('Wal position of %s is ahead of my wal position', st.member.name) + # In synchronous mode the former leader might be still accessible and even be ahead of us. + # We should not disqualify himself from the leader race in such a situation. + if not self.is_synchronous_mode() or st.member.name != self.cluster.sync.leader: + return False + logger.info('Ignoring the former leader being ahead of us') + return True + + def is_failover_possible(self, members): + ret = False + cluster_timeline = self.cluster.timeline + members = [m for m in members if m.name != self.state_handler.name and not m.nofailover and m.api_url] + if members: + for st in self.fetch_nodes_statuses(members): + not_allowed_reason = st.failover_limitation() + if not_allowed_reason: + logger.info('Member %s is %s', st.member.name, not_allowed_reason) + elif self.is_lagging(st.wal_position): + logger.info('Member %s exceeds maximum replication lag', st.member.name) + elif self.check_timeline() and (not st.timeline or st.timeline < cluster_timeline): + logger.info('Timeline %s of member %s is behind the cluster timeline %s', + st.timeline, st.member.name, cluster_timeline) + else: + ret = True + else: + logger.warning('manual failover: members list is empty') + return ret + + def manual_failover_process_no_leader(self): + failover = self.cluster.failover + if failover.candidate: # manual failover to specific member + if failover.candidate == self.state_handler.name: # manual failover to me + return True + elif self.is_paused(): + # Remove failover key if the node to failover has terminated to avoid waiting for it indefinitely + # In order to avoid attempts to delete this key from all nodes only the master is allowed to do it. 
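+                # Illustrative failover key payload (format assumed, values hypothetical):
+                #   {"leader": "", "member": "node2", "scheduled_at": null}
+                # i.e. failover.candidate == 'node2'; the check below drops the key once that member is gone.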
+ if (not self.cluster.get_member(failover.candidate, fallback_to_leader=False) and + self.state_handler.is_leader()): + logger.warning("manual failover: removing failover key because failover candidate is not running") + self.dcs.manual_failover('', '', index=self.cluster.failover.index) + return None + return False + + # find specific node and check that it is healthy + member = self.cluster.get_member(failover.candidate, fallback_to_leader=False) + if member: + st = self.fetch_node_status(member) + not_allowed_reason = st.failover_limitation() + if not_allowed_reason is None: # node is healthy + logger.info('manual failover: to %s, i am %s', st.member.name, self.state_handler.name) + return False + # we wanted to failover to specific member but it is not healthy + logger.warning('manual failover: member %s is %s', st.member.name, not_allowed_reason) + + # at this point we should consider all members as a candidates for failover + # i.e. we assume that failover.candidate is None + elif self.is_paused(): + return False + + # try to pick some other members to failover and check that they are healthy + if failover.leader: + if self.state_handler.name == failover.leader: # I was the leader + # exclude me and desired member which is unhealthy (failover.candidate can be None) + members = [m for m in self.cluster.members if m.name not in (failover.candidate, failover.leader)] + if self.is_failover_possible(members): # check that there are healthy members + return False + else: # I was the leader and it looks like currently I am the only healthy member + return True + + # at this point we assume that our node is a candidate for a failover among all nodes except former leader + + # exclude former leader from the list (failover.leader can be None) + members = [m for m in self.cluster.members if m.name != failover.leader] + return self._is_healthiest_node(members, check_replication_lag=False) + + def is_healthiest_node(self): + if time.time() - self._released_leader_key_timestamp < self.dcs.ttl: + logger.info('backoff: skip leader race after pre_promote script failure and releasing the lock voluntarily') + return False + + if self.is_paused() and not self.patroni.nofailover and \ + self.cluster.failover and not self.cluster.failover.scheduled_at: + ret = self.manual_failover_process_no_leader() + if ret is not None: # continue if we just deleted the stale failover key as a master + return ret + + if self.state_handler.is_starting(): # postgresql still starting up is unhealthy + return False + + if self.state_handler.is_leader(): + # in pause leader is the healthiest only when no initialize or sysid matches with initialize! + return not self.is_paused() or not self.cluster.initialize\ + or self.state_handler.sysid == self.cluster.initialize + + if self.is_paused(): + return False + + if self.patroni.nofailover: # nofailover tag makes node always unhealthy + return False + + if self.cluster.failover: + return self.manual_failover_process_no_leader() + + if not self.watchdog.is_healthy: + logger.warning('Watchdog device is not usable') + return False + + # When in sync mode, only last known master and sync standby are allowed to promote automatically. 
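+        # Illustrative /sync key value (format assumed): {"leader": "node1", "sync_standby": "node2"};
+        # with that state only node1 and node2 satisfy cluster.sync.matches() in the filtering below.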
+ all_known_members = self.cluster.members + self.old_cluster.members + if self.is_synchronous_mode() and self.cluster.sync.leader: + if not self.cluster.sync.matches(self.state_handler.name): + return False + # pick between synchronous candidates so we minimize unnecessary failovers/demotions + members = {m.name: m for m in all_known_members if self.cluster.sync.matches(m.name)} + else: + # run usual health check + members = {m.name: m for m in all_known_members} + + return self._is_healthiest_node(members.values()) + + def _delete_leader(self, last_operation=None): + self.set_is_leader(False) + self.dcs.delete_leader(last_operation) + self.dcs.reset_cluster() + + def release_leader_key_voluntarily(self, last_operation=None): + self._delete_leader(last_operation) + self.touch_member() + logger.info("Leader key released") + + def demote(self, mode): + """Demote PostgreSQL running as master. + + :param mode: One of offline, graceful or immediate. + offline is used when connection to DCS is not available. + graceful is used when failing over to another node due to user request. May only be called running async. + immediate is used when we determine that we are not suitable for master and want to failover quickly + without regard for data durability. May only be called synchronously. + immediate-nolock is used when find out that we have lost the lock to be master. Need to bring down + PostgreSQL as quickly as possible without regard for data durability. May only be called synchronously. + """ + mode_control = { + 'offline': dict(stop='fast', checkpoint=False, release=False, offline=True, async_req=False), + 'graceful': dict(stop='fast', checkpoint=True, release=True, offline=False, async_req=False), + 'immediate': dict(stop='immediate', checkpoint=False, release=True, offline=False, async_req=True), + 'immediate-nolock': dict(stop='immediate', checkpoint=False, release=False, offline=False, async_req=True), + }[mode] + + self._rewind.trigger_check_diverged_lsn() + self.state_handler.stop(mode_control['stop'], checkpoint=mode_control['checkpoint'], + on_safepoint=self.watchdog.disable if self.watchdog.is_running else None, + stop_timeout=self.master_stop_timeout()) + self.state_handler.set_role('demoted') + self.set_is_leader(False) + + if mode_control['release']: + checkpoint_location = self.state_handler.latest_checkpoint_location() if mode == 'graceful' else None + with self._async_executor: + self.release_leader_key_voluntarily(checkpoint_location) + time.sleep(2) # Give a time to somebody to take the leader lock + if mode_control['offline']: + node_to_follow, leader = None, None + else: + cluster = self.dcs.get_cluster() + node_to_follow, leader = self._get_node_to_follow(cluster), cluster.leader + + # FIXME: with mode offline called from DCS exception handler and handle_long_action_in_progress + # there could be an async action already running, calling follow from here will lead + # to racy state handler state updates. 
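+        # Example reading of mode_control above: demote('graceful') resolves to stop='fast', checkpoint=True,
+        # release=True, offline=False, async_req=False, so the leader key is released and follow() runs
+        # synchronously in the else branch below.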
+ if mode_control['async_req']: + self._async_executor.try_run_async('starting after demotion', + self.state_handler.follow, args=(node_to_follow, 'replica', None, False, True)) + else: + if self.is_synchronous_mode(): + self.state_handler.config.set_synchronous_standby([]) + if self._rewind.rewind_or_reinitialize_needed_and_possible(leader): + return False # do not start postgres, but run pg_rewind on the next iteration + self.state_handler.follow(node_to_follow, is_standby=True) + + def should_run_scheduled_action(self, action_name, scheduled_at, cleanup_fn): + if scheduled_at and not self.is_paused(): + # If the scheduled action is in the far future, we shouldn't do anything and just return. + # If the scheduled action is in the past, we consider the value to be stale and we remove + # the value. + # If the value is close to now, we initiate the scheduled action + # Additionally, if the scheduled action cannot be executed altogether, i.e. there is an error + # or the action is in the past - we take care of cleaning it up. + now = datetime.datetime.now(tzutc) + try: + delta = (scheduled_at - now).total_seconds() + + if delta > self.dcs.loop_wait: + logger.info('Awaiting %s at %s (in %.0f seconds)', + action_name, scheduled_at.isoformat(), delta) + return False + elif delta < - int(self.dcs.loop_wait * 1.5): + # This means that if run_cycle gets delayed for 2.5x loop_wait we skip the + # scheduled action. Probably not a problem, if things are that bad we don't + # want to be restarting or failing over anyway. + logger.warning('Found a stale %s value, cleaning up: %s', + action_name, scheduled_at.isoformat()) + cleanup_fn() + return False + + # The value is very close to now + time.sleep(max(delta, 0)) + logger.info('Manual scheduled {0} at %s'.format(action_name), scheduled_at.isoformat()) + return True + except TypeError: + logger.warning('Incorrect value of scheduled_at: %s', scheduled_at) + cleanup_fn() + return False + + def process_manual_failover_from_leader(self): + """Checks if manual failover is requested and takes action if appropriate. + + Cleans up failover key if failover conditions are not matched. 
+ + :returns: action message if demote was initiated, None if no action was taken""" + failover = self.cluster.failover + if not failover or (self.is_paused() and not self.state_handler.is_leader()): + return + + if (failover.scheduled_at and not + self.should_run_scheduled_action("failover", failover.scheduled_at, lambda: + self.dcs.manual_failover('', '', index=failover.index))): + return + + if not failover.leader or failover.leader == self.state_handler.name: + if not failover.candidate or failover.candidate != self.state_handler.name: + if not failover.candidate and self.is_paused(): + logger.warning('Failover is possible only to a specific candidate in a paused state') + else: + if self.is_synchronous_mode(): + if failover.candidate and not self.cluster.sync.matches(failover.candidate): + logger.warning('Failover candidate=%s does not match with sync_standbys=%s', + failover.candidate, self.cluster.sync.sync_standby) + members = [] + else: + members = [m for m in self.cluster.members if self.cluster.sync.matches(m.name)] + else: + members = [m for m in self.cluster.members + if not failover.candidate or m.name == failover.candidate] + if self.is_failover_possible(members): # check that there are healthy members + ret = self._async_executor.try_run_async('manual failover: demote', self.demote, ('graceful',)) + return ret or 'manual failover: demoting myself' + else: + logger.warning('manual failover: no healthy members found, failover is not possible') + else: + logger.warning('manual failover: I am already the leader, no need to failover') + else: + logger.warning('manual failover: leader name does not match: %s != %s', + failover.leader, self.state_handler.name) + + logger.info('Cleaning up failover key') + self.dcs.manual_failover('', '', index=failover.index) + + def process_unhealthy_cluster(self): + """Cluster has no leader key""" + + if self.is_healthiest_node(): + if self.acquire_lock(): + failover = self.cluster.failover + if failover: + if self.is_paused() and failover.leader and failover.candidate: + logger.info('Updating failover key after acquiring leader lock...') + self.dcs.manual_failover('', failover.candidate, failover.scheduled_at, failover.index) + else: + logger.info('Cleaning up failover key after acquiring leader lock...') + self.dcs.manual_failover('', '') + self.load_cluster_from_dcs() + + if self.is_standby_cluster(): + # standby leader disappeared, and this is the healthiest + # replica, so it should become a new standby leader. 
+ # This implies we need to start following a remote master + msg = 'promoted self to a standby leader by acquiring session lock' + return self.enforce_follow_remote_master(msg) + else: + return self.enforce_master_role( + 'acquired session lock as a leader', + 'promoted self to leader by acquiring session lock' + ) + else: + return self.follow('demoted self after trying and failing to obtain lock', + 'following new leader after trying and failing to obtain lock') + else: + # when we are doing manual failover there is no guaranty that new leader is ahead of any other node + # node tagged as nofailover can be ahead of the new leader either, but it is always excluded from elections + if bool(self.cluster.failover) or self.patroni.nofailover: + self._rewind.trigger_check_diverged_lsn() + time.sleep(2) # Give a time to somebody to take the leader lock + + if self.patroni.nofailover: + return self.follow('demoting self because I am not allowed to become master', + 'following a different leader because I am not allowed to promote') + return self.follow('demoting self because i am not the healthiest node', + 'following a different leader because i am not the healthiest node') + + def process_healthy_cluster(self): + if self.has_lock(): + if self.is_paused() and not self.state_handler.is_leader(): + if self.cluster.failover and self.cluster.failover.candidate == self.state_handler.name: + return 'waiting to become master after promote...' + + self._delete_leader() + return 'removed leader lock because postgres is not running as master' + + if self.state_handler.is_leader() and self._leader_access_is_restricted: + self.state_handler.slots_handler.sync_replication_slots(self.cluster) + self.state_handler.call_nowait(ACTION_ON_ROLE_CHANGE) + self.set_leader_access_is_restricted(False) + + if self.update_lock(True): + msg = self.process_manual_failover_from_leader() + if msg is not None: + return msg + + # check if the node is ready to be used by pg_rewind + self._rewind.ensure_checkpoint_after_promote(self.wakeup) + + if self.is_standby_cluster(): + # in case of standby cluster we don't really need to + # enforce anything, since the leader is not a master. + # So just remind the role. + msg = 'no action. i am the standby leader with the lock' \ + if self.state_handler.role == 'standby_leader' else \ + 'promoted self to a standby leader because i had the session lock' + return self.enforce_follow_remote_master(msg) + else: + return self.enforce_master_role( + 'no action. i am the leader with the lock', + 'promoted self to leader because i had the session lock' + ) + else: + # Either there is no connection to DCS or someone else acquired the lock + logger.error('failed to update leader lock') + if self.state_handler.is_leader(): + if self.is_paused(): + return 'continue to run as master after failing to update leader lock in DCS' + self.demote('immediate-nolock') + return 'demoted self because failed to update leader lock in DCS' + else: + return 'not promoting because failed to update leader lock in DCS' + else: + logger.info('does not have lock') + if self.is_standby_cluster(): + return self.follow('cannot be a real master in standby cluster', + 'no action. i am a secondary and i am following a standby leader', refresh=False) + return self.follow('demoting self because i do not have the lock and i was a leader', + 'no action. 
i am a secondary and i am following a leader', refresh=False) + + def evaluate_scheduled_restart(self): + if self._async_executor.busy: # Restart already in progress + return None + + # restart if we need to + restart_data = self.future_restart_scheduled() + if restart_data: + recent_time = self.state_handler.postmaster_start_time() + request_time = restart_data['postmaster_start_time'] + # check if postmaster start time has changed since the last restart + if recent_time and request_time and recent_time != request_time: + logger.info("Cancelling scheduled restart: postgres restart has already happened at %s", recent_time) + self.delete_future_restart() + return None + + if (restart_data and + self.should_run_scheduled_action('restart', restart_data['schedule'], self.delete_future_restart)): + try: + ret, message = self.restart(restart_data, run_async=True) + if not ret: + logger.warning("Scheduled restart: %s", message) + return None + return message + finally: + self.delete_future_restart() + + def restart_matches(self, role, postgres_version, pending_restart): + reason_to_cancel = "" + # checking the restart filters here seem to be less ugly than moving them into the + # run_scheduled_action. + if role and role != self.state_handler.role: + reason_to_cancel = "host role mismatch" + + if postgres_version and postgres_version_to_int(postgres_version) <= int(self.state_handler.server_version): + reason_to_cancel = "postgres version mismatch" + + if pending_restart and not self.state_handler.pending_restart: + reason_to_cancel = "pending restart flag is not set" + + if not reason_to_cancel: + return True + else: + logger.info("not proceeding with the restart: %s", reason_to_cancel) + return False + + def schedule_future_restart(self, restart_data): + with self._async_executor: + restart_data['postmaster_start_time'] = self.state_handler.postmaster_start_time() + if not self.patroni.scheduled_restart: + self.patroni.scheduled_restart = restart_data + self.touch_member() + return True + return False + + def delete_future_restart(self): + ret = False + with self._async_executor: + if self.patroni.scheduled_restart: + self.patroni.scheduled_restart = {} + self.touch_member() + ret = True + return ret + + def future_restart_scheduled(self): + return self.patroni.scheduled_restart.copy() if (self.patroni.scheduled_restart and + isinstance(self.patroni.scheduled_restart, dict)) else None + + def restart_scheduled(self): + return self._async_executor.scheduled_action == 'restart' + + def restart(self, restart_data, run_async=False): + """ conditional and unconditional restart """ + assert isinstance(restart_data, dict) + + if (not self.restart_matches(restart_data.get('role'), + restart_data.get('postgres_version'), + ('restart_pending' in restart_data))): + return (False, "restart conditions are not satisfied") + + with self._async_executor: + prev = self._async_executor.schedule('restart') + if prev is not None: + return (False, prev + ' already in progress') + + # Make the main loop to think that we were recovering dead postgres. If we fail + # to start postgres after a specified timeout (see below), we need to remove + # leader key (if it belong to us) rather than trying to start postgres once again. + self.recovering = True + + # Now that restart is scheduled we can set timeout for startup, it will get reset + # once async executor runs and main loop notices PostgreSQL as up. 
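+        # The optional per-request 'timeout' value (if one was supplied with the restart request)
+        # takes precedence over the configured master_start_timeout.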
+ timeout = restart_data.get('timeout', self.patroni.config['master_start_timeout']) + self.set_start_timeout(timeout) + + # For non async cases we want to wait for restart to complete or timeout before returning. + do_restart = functools.partial(self.state_handler.restart, timeout, self._async_executor.critical_task) + if self.is_synchronous_mode() and not self.has_lock(): + do_restart = functools.partial(self.while_not_sync_standby, do_restart) + + if run_async: + self._async_executor.run_async(do_restart) + return (True, 'restart initiated') + else: + res = self._async_executor.run(do_restart) + if res: + return (True, 'restarted successfully') + elif res is None: + return (False, 'postgres is still starting') + else: + return (False, 'restart failed') + + def _do_reinitialize(self, cluster): + self.state_handler.stop('immediate', stop_timeout=self.patroni.config['retry_timeout']) + # Commented redundant data directory cleanup here + # self.state_handler.remove_data_directory() + + clone_member = self.cluster.get_clone_member(self.state_handler.name) + member_role = 'leader' if clone_member == self.cluster.leader else 'replica' + return self.clone(clone_member, "from {0} '{1}'".format(member_role, clone_member.name)) + + def reinitialize(self, force=False): + with self._async_executor: + self.load_cluster_from_dcs() + + if self.cluster.is_unlocked(): + return 'Cluster has no leader, can not reinitialize' + + if self.has_lock(False): + return 'I am the leader, can not reinitialize' + + if force: + self._async_executor.cancel() + + with self._async_executor: + action = self._async_executor.schedule('reinitialize') + if action is not None: + return '{0} already in progress'.format(action) + + self._async_executor.run_async(self._do_reinitialize, args=(self.cluster, )) + + def handle_long_action_in_progress(self): + """ + Figure out what to do with the task AsyncExecutor is performing. + """ + if self.has_lock() and self.update_lock(): + if self._async_executor.scheduled_action == 'doing crash recovery in a single user mode': + time_left = self.patroni.config['master_start_timeout'] - (time.time() - self._crash_recovery_started) + if time_left <= 0 and self.is_failover_possible(self.cluster.members): + logger.info("Demoting self because crash recovery is taking too long") + self.state_handler.cancellable.cancel(True) + self.demote('immediate') + return 'terminated crash recovery because of startup timeout' + + return 'updated leader lock during ' + self._async_executor.scheduled_action + elif not self.state_handler.bootstrapping and not self.is_paused(): + # Don't have lock, make sure we are not promoting or starting up a master in the background + if self._async_executor.scheduled_action == 'promote': + with self._async_response: + cancel = self._async_response.cancel() + if cancel: + self.state_handler.cancellable.cancel() + return 'lost leader before promote' + + if self.state_handler.role == 'master': + logger.info("Demoting master during " + self._async_executor.scheduled_action) + if self._async_executor.scheduled_action == 'restart': + # Restart needs a special interlocking cancel because postmaster may be just started in a + # background thread and has not even written a pid file yet. 
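+                        # If the critical task can no longer be cancelled, the postmaster it has already
+                        # spawned is terminated directly instead.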
+ with self._async_executor.critical_task as task: + if not task.cancel(): + self.state_handler.terminate_starting_postmaster(postmaster=task.result) + self.demote('immediate-nolock') + return 'lost leader lock during ' + self._async_executor.scheduled_action + if self.cluster.is_unlocked(): + logger.info('not healthy enough for leader race') + + return self._async_executor.scheduled_action + ' in progress' + + @staticmethod + def sysid_valid(sysid): + # sysid does tv_sec << 32, where tv_sec is the number of seconds sine 1970, + # so even 1 << 32 would have 10 digits. + sysid = str(sysid) + return len(sysid) >= 10 and sysid.isdigit() + + def post_recover(self): + if not self.state_handler.is_running(): + self.watchdog.disable() + if self.has_lock(): + if self.state_handler.role in ('master', 'standby_leader'): + self.state_handler.set_role('demoted') + self._delete_leader() + return 'removed leader key after trying and failing to start postgres' + return 'failed to start postgres' + self._crash_recovery_executed = False + if self._rewind.executed and not self._rewind.failed: + self._rewind.reset_state() + return None + + def cancel_initialization(self): + logger.info('removing initialize key after failed attempt to bootstrap the cluster') + self.dcs.cancel_initialization() + self.state_handler.stop('immediate', stop_timeout=self.patroni.config['retry_timeout']) + self.state_handler.move_data_directory() + raise PatroniFatalException('Failed to bootstrap cluster') + + def post_bootstrap(self): + with self._async_response: + result = self._async_response.result + # bootstrap has failed if postgres is not running + if not self.state_handler.is_running() or result is False: + self.cancel_initialization() + + if result is None: + if not self.state_handler.is_leader(): + return 'waiting for end of recovery after bootstrap' + + self.state_handler.set_role('master') + ret = self._async_executor.try_run_async('post_bootstrap', self.state_handler.bootstrap.post_bootstrap, + args=(self.patroni.config['bootstrap'], self._async_response)) + return ret or 'running post_bootstrap' + + self.state_handler.bootstrapping = False + if not self.watchdog.activate(): + logger.error('Cancelling bootstrap because watchdog activation failed') + self.cancel_initialization() + self.dcs.initialize(create_new=(self.cluster.initialize is None), sysid=self.state_handler.sysid) + self.dcs.set_config_value(json.dumps(self.patroni.config.dynamic_configuration, separators=(',', ':'))) + self.state_handler.slots_handler.sync_replication_slots(self.cluster) + self.dcs.take_leader() + self.set_is_leader(True) + self.state_handler.call_nowait(ACTION_ON_START) + self.load_cluster_from_dcs() + + return 'initialized a new cluster' + + def handle_starting_instance(self): + """Starting up PostgreSQL may take a long time. In case we are the leader we may want to + fail over to.""" + + # Check if we are in startup, when paused defer to main loop for manual failovers. + if not self.state_handler.check_for_startup() or self.is_paused(): + self.set_start_timeout(None) + if self.is_paused(): + self.state_handler.set_state(self.state_handler.is_running() and 'running' or 'stopped') + return None + + # state_handler.state == 'starting' here + if self.has_lock(): + if not self.update_lock(): + logger.info("Lost lock while starting up. 
Demoting self.") + self.demote('immediate-nolock') + return 'stopped PostgreSQL while starting up because leader key was lost' + + timeout = self._start_timeout or self.patroni.config['master_start_timeout'] + time_left = timeout - self.state_handler.time_in_state() + + if time_left <= 0: + if self.is_failover_possible(self.cluster.members): + logger.info("Demoting self because master startup is taking too long") + self.demote('immediate') + return 'stopped PostgreSQL because of startup timeout' + else: + return 'master start has timed out, but continuing to wait because failover is not possible' + else: + msg = self.process_manual_failover_from_leader() + if msg is not None: + return msg + + return 'PostgreSQL is still starting up, {0:.0f} seconds until timeout'.format(time_left) + else: + # Use normal processing for standbys + logger.info("Still starting up as a standby.") + return None + + def set_start_timeout(self, value): + """Sets timeout for starting as master before eligible for failover. + + Must be called when async_executor is busy or in the main thread.""" + self._start_timeout = value + + def _run_cycle(self): + dcs_failed = False + try: + self.state_handler.reset_cluster_info_state() + self.load_cluster_from_dcs() + + if self.is_paused(): + self.watchdog.disable() + self._was_paused = True + else: + if self._was_paused: + self.state_handler.schedule_sanity_checks_after_pause() + self._was_paused = False + + if not self.cluster.has_member(self.state_handler.name): + self.touch_member() + + # cluster has leader key but not initialize key + if not (self.cluster.is_unlocked() or self.sysid_valid(self.cluster.initialize)) and self.has_lock(): + self.dcs.initialize(create_new=(self.cluster.initialize is None), sysid=self.state_handler.sysid) + + if not (self.cluster.is_unlocked() or self.cluster.config and self.cluster.config.data) and self.has_lock(): + self.dcs.set_config_value(json.dumps(self.patroni.config.dynamic_configuration, separators=(',', ':'))) + self.cluster = self.dcs.get_cluster() + + if self._async_executor.busy: + return self.handle_long_action_in_progress() + + msg = self.handle_starting_instance() + if msg is not None: + return msg + + # we've got here, so any async action has finished. + if self.state_handler.bootstrapping: + return self.post_bootstrap() + + if self.recovering and not self._rewind.is_needed: + self.recovering = False + # Check if we tried to recover and failed + msg = self.post_recover() + if msg is not None: + return msg + + # is data directory empty? + if self.state_handler.data_directory_empty(): + self.state_handler.set_role('uninitialized') + self.state_handler.stop('immediate', stop_timeout=self.patroni.config['retry_timeout']) + # In case datadir went away while we were master. + self.watchdog.disable() + + # is this instance the leader? 
+ if self.has_lock(): + self.release_leader_key_voluntarily() + return 'released leader key voluntarily as data dir empty and currently leader' + + if self.is_paused(): + return 'running with empty data directory' + return self.bootstrap() # new node + else: + # check if we are allowed to join + data_sysid = self.state_handler.sysid + if not self.sysid_valid(data_sysid): + # data directory is not empty, but no valid sysid, cluster must be broken, suggest reinit + return ("data dir for the cluster is not empty, " + "but system ID is invalid; consider doing reinitialize") + + if self.sysid_valid(self.cluster.initialize): + if self.cluster.initialize != data_sysid: + if self.is_paused(): + logger.warning('system ID has changed while in paused mode. Patroni will exit when resuming' + ' unless system ID is reset: %s != %s', self.cluster.initialize, data_sysid) + if self.has_lock(): + self.release_leader_key_voluntarily() + return 'released leader key voluntarily due to the system ID mismatch' + else: + logger.fatal('system ID mismatch, node %s belongs to a different cluster: %s != %s', + self.state_handler.name, self.cluster.initialize, data_sysid) + sys.exit(1) + elif self.cluster.is_unlocked() and not self.is_paused(): + # "bootstrap", but data directory is not empty + if not self.state_handler.cb_called and self.state_handler.is_running() \ + and not self.state_handler.is_leader(): + self._join_aborted = True + logger.error('No initialize key in DCS and PostgreSQL is running as replica, aborting start') + logger.error('Please first start Patroni on the node running as master') + sys.exit(1) + self.dcs.initialize(create_new=(self.cluster.initialize is None), sysid=data_sysid) + + if not self.state_handler.is_healthy(): + if self.is_paused(): + self.state_handler.set_state('stopped') + if self.has_lock(): + self._delete_leader() + return 'removed leader lock because postgres is not running' + # Normally we don't start Postgres in a paused state. We make an exception for the demoted primary + # that needs to be started after it had been stopped by demote. When there is no need to call rewind + # the demote code follows through to starting Postgres right away, however, in the rewind case + # it returns from demote and reaches this point to start PostgreSQL again after rewind. In that + # case it makes no sense to continue to recover() unless rewind has finished successfully. 
+ elif self._rewind.failed or not self._rewind.executed and not \ + (self._rewind.is_needed and self._rewind.can_rewind_or_reinitialize_allowed): + return 'postgres is not running' + + if self.state_handler.state in ('running', 'starting'): + self.state_handler.set_state('crashed') + # try to start dead postgres + return self.recover() + + try: + if self.cluster.is_unlocked(): + return self.process_unhealthy_cluster() + else: + msg = self.process_healthy_cluster() + return self.evaluate_scheduled_restart() or msg + finally: + # we might not have a valid PostgreSQL connection here if another thread + # stops PostgreSQL, therefore, we only reload replication slots if no + # asynchronous processes are running (should be always the case for the master) + if not self._async_executor.busy and not self.state_handler.is_starting(): + self.state_handler.slots_handler.sync_replication_slots(self.cluster) + if not self.state_handler.cb_called: + if not self.state_handler.is_leader(): + self._rewind.trigger_check_diverged_lsn() + self.state_handler.call_nowait(ACTION_ON_START) + except DCSError: + dcs_failed = True + logger.error('Error communicating with DCS') + if not self.is_paused() and self.state_handler.is_running() and self.state_handler.is_leader(): + self.demote('offline') + return 'demoted self because DCS is not accessible and i was a leader' + return 'DCS is not accessible' + except (psycopg2.Error, PostgresConnectionException): + return 'Error communicating with PostgreSQL. Will try again later' + finally: + if not dcs_failed: + self.touch_member() + + def run_cycle(self): + with self._async_executor: + try: + info = self._run_cycle() + return (self.is_paused() and 'PAUSE: ' or '') + info + except PatroniFatalException: + raise + except Exception: + logger.exception('Unexpected exception') + return 'Unexpected exception raised, please report it as a BUG' + + def shutdown(self): + if self.is_paused(): + logger.info('Leader key is not deleted and Postgresql is not stopped due paused state') + self.watchdog.disable() + elif not self._join_aborted: + # FIXME: If stop doesn't reach safepoint quickly enough keepalive is triggered. If shutdown checkpoint + # takes longer than ttl, then leader key is lost and replication might not have sent out all xlog. + # This might not be the desired behavior of users, as a graceful shutdown of the host can mean lost data. + # We probably need to something smarter here. + disable_wd = self.watchdog.disable if self.watchdog.is_running else None + self.while_not_sync_standby(lambda: self.state_handler.stop(checkpoint=False, on_safepoint=disable_wd, + stop_timeout=self.master_stop_timeout())) + if not self.state_handler.is_running(): + if self.is_leader(): + checkpoint_location = self.state_handler.latest_checkpoint_location() + self.dcs.delete_leader(checkpoint_location) + self.touch_member() + else: + # XXX: what about when Patroni is started as the wrong user that has access to the watchdog device + # but cannot shut down PostgreSQL. Root would be the obvious example. Would be nice to not kill the + # system due to a bad config. + logger.error("PostgreSQL shutdown failed, leader key not removed." + + (" Leaving watchdog running." 
if self.watchdog.is_running else "")) + + def watch(self, timeout): + # watch on leader key changes if the postgres is running and leader is known and current node is not lock owner + if self._async_executor.busy or not self.cluster or self.cluster.is_unlocked() or self.has_lock(False): + leader_index = None + else: + leader_index = self.cluster.leader.index + + return self.dcs.watch(leader_index, timeout) + + def wakeup(self): + """Call of this method will trigger the next run of HA loop if there is + no "active" leader watch request in progress. + This usually happens on the master or if the node is running async action""" + self.dcs.event.set() + + def get_remote_member(self, member=None): + """ In case of standby cluster this will tel us from which remote + master to stream. Config can be both patroni config or + cluster.config.data + """ + cluster_params = self.get_standby_cluster_config() + + if cluster_params: + name = member.name if member else 'remote_master:{}'.format(uuid.uuid1()) + + data = {k: v for k, v in cluster_params.items() if k in RemoteMember.allowed_keys()} + data['no_replication_slot'] = 'primary_slot_name' not in cluster_params + conn_kwargs = member.conn_kwargs() if member else \ + {k: cluster_params[k] for k in ('host', 'port') if k in cluster_params} + if conn_kwargs: + data['conn_kwargs'] = conn_kwargs + + return RemoteMember(name, data) + + def get_remote_master(self): + return self.get_remote_member() diff --git a/patroni-for-openGauss/log.py b/patroni-for-openGauss/log.py new file mode 100644 index 0000000000000000000000000000000000000000..ce808bb1d5fdc6df48bb716364f28d6fb8f04f9e --- /dev/null +++ b/patroni-for-openGauss/log.py @@ -0,0 +1,194 @@ +import logging +import os +import sys + +from copy import deepcopy +from logging.handlers import RotatingFileHandler +from patroni.utils import deep_compare +from six.moves.queue import Queue, Full +from threading import Lock, Thread + +_LOGGER = logging.getLogger(__name__) + + +def debug_exception(logger_obj, msg, *args, **kwargs): + kwargs.pop("exc_info", False) + if logger_obj.isEnabledFor(logging.DEBUG): + logger_obj.debug(msg, *args, exc_info=True, **kwargs) + else: + msg = "{0}, DETAIL: '{1}'".format(msg, sys.exc_info()[1]) + logger_obj.error(msg, *args, exc_info=False, **kwargs) + + +def error_exception(logger_obj, msg, *args, **kwargs): + exc_info = kwargs.pop("exc_info", True) + logger_obj.error(msg, *args, exc_info=exc_info, **kwargs) + + +class QueueHandler(logging.Handler): + + def __init__(self): + logging.Handler.__init__(self) + self.queue = Queue() + self._records_lost = 0 + + def _put_record(self, record): + self.format(record) + record.msg = record.message + record.args = None + record.exc_info = None + self.queue.put_nowait(record) + + def _try_to_report_lost_records(self): + if self._records_lost: + try: + record = _LOGGER.makeRecord(_LOGGER.name, logging.WARNING, __file__, 0, + 'QueueHandler has lost %s log records', + (self._records_lost,), None, 'emit') + self._put_record(record) + self._records_lost = 0 + except Exception: + pass + + def emit(self, record): + try: + self._put_record(record) + self._try_to_report_lost_records() + except Exception: + self._records_lost += 1 + + @property + def records_lost(self): + return self._records_lost + + +class ProxyHandler(logging.Handler): + + def __init__(self, patroni_logger): + logging.Handler.__init__(self) + self.patroni_logger = patroni_logger + + def emit(self, record): + self.patroni_logger.log_handler.handle(record) + + +class 
PatroniLogger(Thread): + + DEFAULT_LEVEL = 'INFO' + DEFAULT_TRACEBACK_LEVEL = 'ERROR' + DEFAULT_FORMAT = '%(asctime)s %(levelname)s: %(message)s' + + NORMAL_LOG_QUEUE_SIZE = 2 # When everything goes normal Patroni writes only 2 messages per HA loop + DEFAULT_MAX_QUEUE_SIZE = 1000 + LOGGING_BROKEN_EXIT_CODE = 5 + + def __init__(self): + super(PatroniLogger, self).__init__() + self._queue_handler = QueueHandler() + self._root_logger = logging.getLogger() + self._config = None + self.log_handler = None + self.log_handler_lock = Lock() + self._old_handlers = [] + self.reload_config({'level': 'DEBUG'}) + # We will switch to the QueueHandler only when thread was started. + # This is necessary to protect from the cases when Patroni constructor + # failed and PatroniLogger thread remain running and prevent shutdown. + self._proxy_handler = ProxyHandler(self) + self._root_logger.addHandler(self._proxy_handler) + + def update_loggers(self): + loggers = deepcopy(self._config.get('loggers') or {}) + for name, logger in self._root_logger.manager.loggerDict.items(): + if not isinstance(logger, logging.PlaceHolder): + level = loggers.pop(name, logging.NOTSET) + logger.setLevel(level) + + for name, level in loggers.items(): + logger = self._root_logger.manager.getLogger(name) + logger.setLevel(level) + + def reload_config(self, config): + if self._config is None or not deep_compare(self._config, config): + with self._queue_handler.queue.mutex: + self._queue_handler.queue.maxsize = config.get('max_queue_size', self.DEFAULT_MAX_QUEUE_SIZE) + + self._root_logger.setLevel(config.get('level', PatroniLogger.DEFAULT_LEVEL)) + if config.get('traceback_level', PatroniLogger.DEFAULT_TRACEBACK_LEVEL).lower() == 'debug': + logging.Logger.exception = debug_exception + else: + logging.Logger.exception = error_exception + + new_handler = None + if 'dir' in config: + if not isinstance(self.log_handler, RotatingFileHandler): + new_handler = RotatingFileHandler(os.path.join(config['dir'], __name__)) + handler = new_handler or self.log_handler + handler.maxBytes = int(config.get('file_size', 25000000)) + handler.backupCount = int(config.get('file_num', 4)) + else: + if self.log_handler is None or isinstance(self.log_handler, RotatingFileHandler): + new_handler = logging.StreamHandler() + handler = new_handler or self.log_handler + + oldlogformat = (self._config or {}).get('format', PatroniLogger.DEFAULT_FORMAT) + logformat = config.get('format', PatroniLogger.DEFAULT_FORMAT) + + olddateformat = (self._config or {}).get('dateformat') or None + dateformat = config.get('dateformat') or None # Convert empty string to `None` + + if oldlogformat != logformat or olddateformat != dateformat or new_handler: + handler.setFormatter(logging.Formatter(logformat, dateformat)) + + if new_handler: + with self.log_handler_lock: + if self.log_handler: + self._old_handlers.append(self.log_handler) + self.log_handler = new_handler + + self._config = config.copy() + self.update_loggers() + + def _close_old_handlers(self): + while True: + with self.log_handler_lock: + if not self._old_handlers: + break + handler = self._old_handlers.pop() + try: + handler.close() + except Exception: + _LOGGER.exception('Failed to close the old log handler %s', handler) + + def run(self): + # switch to QueueHandler only when the thread was started + with self.log_handler_lock: + self._root_logger.addHandler(self._queue_handler) + self._root_logger.removeHandler(self._proxy_handler) + + while True: + self._close_old_handlers() + + record = 
self._queue_handler.queue.get(True) + if record is None: + break + + self.log_handler.handle(record) + self._queue_handler.queue.task_done() + + def shutdown(self): + try: + self._queue_handler.queue.put_nowait(None) + except Full: # Queue is full. + # It seems that logging is not working, exiting with non-standard exit-code is the best we can do. + sys.exit(self.LOGGING_BROKEN_EXIT_CODE) + self.join() + logging.shutdown() + + @property + def queue_size(self): + return self._queue_handler.queue.qsize() + + @property + def records_lost(self): + return self._queue_handler.records_lost diff --git a/patroni-for-openGauss/postgresql/__init__.py b/patroni-for-openGauss/postgresql/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10004ced5ceb8584a6eb5bdf6bf4d18a83ef1feb --- /dev/null +++ b/patroni-for-openGauss/postgresql/__init__.py @@ -0,0 +1,1051 @@ +import logging +import os +import psycopg2 +import shlex +import shutil +import subprocess +import time +import threading + +from contextlib import contextmanager +from copy import deepcopy +from dateutil import tz +from datetime import datetime +from patroni.postgresql.callback_executor import CallbackExecutor +from patroni.postgresql.bootstrap import Bootstrap +from patroni.postgresql.cancellable import CancellableSubprocess +from patroni.postgresql.config import ConfigHandler, mtime +from patroni.postgresql.connection import Connection, get_connection_cursor +from patroni.postgresql.misc import parse_history, parse_lsn, postgres_major_version_to_int +from patroni.postgresql.postmaster import PostmasterProcess +from patroni.postgresql.slots import SlotsHandler +from patroni.exceptions import PostgresConnectionException +from patroni.utils import Retry, RetryFailedError, polling_loop, data_directory_is_empty, parse_int +from psutil import TimeoutExpired +from threading import current_thread, Lock + + +logger = logging.getLogger(__name__) + +ACTION_ON_START = "on_start" +ACTION_ON_STOP = "on_stop" +ACTION_ON_RESTART = "on_restart" +ACTION_ON_RELOAD = "on_reload" +ACTION_ON_ROLE_CHANGE = "on_role_change" +ACTION_NOOP = "noop" + +STATE_RUNNING = 'running' +STATE_REJECT = 'rejecting connections' +STATE_NO_RESPONSE = 'not responding' +STATE_UNKNOWN = 'unknown' + +STOP_POLLING_INTERVAL = 1 + + +@contextmanager +def null_context(): + yield + + +class Postgresql(object): + + POSTMASTER_START_TIME = "pg_catalog.to_char(pg_catalog.pg_postmaster_start_time(), 'YYYY-MM-DD HH24:MI:SS.MS TZ')" + TL_LSN = ("CASE WHEN pg_catalog.pg_is_in_recovery() THEN 0 " + "ELSE ('x' || pg_catalog.substr(pg_catalog.pg_{0}file_name(" + "pg_catalog.pg_current_{0}_{1}()), 1, 8))::bit(32)::int END, " # master timeline + "CASE WHEN pg_catalog.pg_is_in_recovery() THEN 0 " + "ELSE pg_catalog.pg_{0}_{1}_diff(pg_catalog.pg_current_{0}_{1}(), '0/0')::bigint END, " # write_lsn + "pg_catalog.pg_{0}_{1}_diff(split_part(left(pg_catalog.pg_last_{0}_replay_{1}()::text,-1),',',2), '0/0')::bigint, " + "pg_catalog.pg_{0}_{1}_diff(COALESCE(pg_catalog.pg_last_{0}_receive_{1}(), '0/0'), '0/0')::bigint, " + "pg_catalog.pg_is_in_recovery() AND pg_catalog.pg_is_{0}_replay_paused()") + + def __init__(self, config): + self.name = config['name'] + self.scope = config['scope'] + self._data_dir = config['data_dir'] + self._database = config.get('database', 'postgres') + self._version_file = os.path.join(self._data_dir, 'PG_VERSION') + self._pg_control = os.path.join(self._data_dir, 'global', 'pg_control') + self._major_version = self.get_major_version() + + self._state_lock 
= Lock() + self.set_state('stopped') + + self._pending_restart = False + self._connection = Connection() + self.config = ConfigHandler(self, config) + self.config.check_directories() + + self._bin_dir = config.get('bin_dir') or '' + self.bootstrap = Bootstrap(self) + self.bootstrapping = False + self.__thread_ident = current_thread().ident + + self.slots_handler = SlotsHandler(self) + + self._callback_executor = CallbackExecutor() + self.__cb_called = False + self.__cb_pending = None + + self.cancellable = CancellableSubprocess() + + self._sysid = None + self.retry = Retry(max_tries=-1, deadline=config['retry_timeout']/2.0, max_delay=1, + retry_exceptions=PostgresConnectionException) + + # Retry 'pg_is_in_recovery()' only once + self._is_leader_retry = Retry(max_tries=1, deadline=config['retry_timeout']/2.0, max_delay=1, + retry_exceptions=PostgresConnectionException) + + self._role_lock = Lock() + self.set_role(self.get_postgres_role_from_data_directory()) + self._state_entry_timestamp = None + + self._cluster_info_state = {} + self._cached_replica_timeline = None + + # Last known running process + self._postmaster_proc = None + + if self.is_running(): + self.set_state('running') + self.set_role('master' if self.is_leader() else 'replica') + self.config.write_postgresql_conf() # we are "joining" already running postgres + hba_saved = self.config.replace_pg_hba() + ident_saved = self.config.replace_pg_ident() + if hba_saved or ident_saved: + self.reload() + elif self.role == 'master': + self.set_role('demoted') + + @property + def create_replica_methods(self): + return self.config.get('create_replica_methods', []) or self.config.get('create_replica_method', []) + + @property + def major_version(self): + return self._major_version + + @property + def database(self): + return self._database + + @property + def data_dir(self): + return self._data_dir + + @property + def callback(self): + return self.config.get('callbacks') or {} + + @property + def wal_dir(self): + return os.path.join(self._data_dir, 'pg_' + self.wal_name) + + @property + def wal_name(self): + return 'wal' if self._major_version >= 100000 else 'xlog' + + @property + def lsn_name(self): + return 'lsn' if self._major_version >= 100000 else 'location' + + @property + def cluster_info_query(self): + if self._major_version >= 90600: + extra = (", CASE WHEN latest_end_lsn IS NULL THEN NULL ELSE received_tli END," + " slot_name, conninfo FROM pg_catalog.pg_stat_get_wal_receiver()") + if self.role == 'standby_leader': + extra = "timeline_id" + extra + ", pg_catalog.pg_control_checkpoint()" + else: + extra = "0" + extra + else: + extra = "0, NULL, NULL, NULL" + + return ("SELECT " + self.TL_LSN + ", {2}").format(self.wal_name, self.lsn_name, extra) + + def _version_file_exists(self): + return not self.data_directory_empty() and os.path.isfile(self._version_file) + + def get_major_version(self): + if self._version_file_exists(): + try: + with open(self._version_file) as f: + return postgres_major_version_to_int(f.read().strip()) + except Exception: + logger.exception('Failed to read PG_VERSION from %s', self._data_dir) + return 0 + + def pgcommand(self, cmd): + """Returns path to the specified PostgreSQL command""" + return os.path.join(self._bin_dir, cmd) + + def pg_ctl(self, cmd, *args, **kwargs): + """Builds and executes pg_ctl command + + :returns: `!True` when return_code == 0, otherwise `!False`""" + + pg_ctl = [self.pgcommand('gs_ctl'), cmd] + return subprocess.call(pg_ctl + ['-D', self._data_dir] + list(args), **kwargs) == 0 + 
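+    # Illustrative example only (assumes the image defaults GAUSSHOME=/usr/local/opengauss and
+    # PGDATA=/var/lib/opengauss/data; gs_ctl is resolved via the configured bin_dir):
+    #     self.pg_ctl('stop', '-m', 'fast')
+    # runs roughly
+    #     gs_ctl stop -D /var/lib/opengauss/data -m fast
+    # and returns True only when gs_ctl exits with status 0.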
+    def gs_query(self):
+        """
+        query the state of the database
+        """
+        status, output = subprocess.getstatusoutput(self.pgcommand('gs_ctl') + ' query -D ' + self._data_dir)
+        return status, output
+
+    def pg_isready(self):
+        """Runs pg_isready to see if PostgreSQL is accepting connections.
+
+        :returns: 'ok' if PostgreSQL is up, 'reject' if starting up, 'no_response' if not up."""
+
+        r = self.config.local_connect_kwargs
+        cmd = [self.pgcommand('gs_isready'), '-p', r['port'], '-d', self._database]
+
+        # Host is not set if we are connecting via default unix socket
+        if 'host' in r:
+            cmd.extend(['-h', r['host']])
+
+        # We only need the username because pg_isready does not try to authenticate
+        if 'user' in r:
+            cmd.extend(['-U', r['user']])
+
+        ret = subprocess.call(cmd)
+        return_codes = {0: STATE_RUNNING,
+                        1: STATE_REJECT,
+                        2: STATE_NO_RESPONSE,
+                        3: STATE_UNKNOWN}
+        return return_codes.get(ret, STATE_UNKNOWN)
+
+    def reload_config(self, config, sighup=False):
+        self.config.reload_config(config, sighup)
+        self._is_leader_retry.deadline = self.retry.deadline = config['retry_timeout']/2.0
+
+    @property
+    def pending_restart(self):
+        return self._pending_restart
+
+    def set_pending_restart(self, value):
+        self._pending_restart = value
+
+    @property
+    def sysid(self):
+        if not self._sysid and not self.bootstrapping:
+            data = self.controldata()
+            self._sysid = data.get('Database system identifier', "")
+        return self._sysid
+
+    def get_postgres_role_from_data_directory(self):
+        if self.data_directory_empty() or not self.controldata():
+            return 'uninitialized'
+        elif self.config.recovery_conf_exists():
+            return 'replica'
+        else:
+            return 'master'
+
+    @property
+    def server_version(self):
+        return self._connection.server_version
+
+    def connection(self):
+        return self._connection.get()
+
+    def set_connection_kwargs(self, kwargs):
+        self._connection.set_conn_kwargs(kwargs)
+
+    def _query(self, sql, *params):
+        """We are always using the same cursor, therefore this method is not thread-safe!!!
+        You can call it from different threads only if you are holding explicit `AsyncExecutor` lock,
+        because the main thread is always holding this lock when running HA cycle."""
+        cursor = None
+        try:
+            cursor = self._connection.cursor()
+            cursor.execute(sql, params)
+            return cursor
+        except psycopg2.Error as e:
+            if cursor and cursor.connection.closed == 0:
+                # When connected via unix socket, psycopg2 can't recognize 'connection lost'
+                # and leaves `_cursor_holder.connection.closed == 0`, but psycopg2.OperationalError
+                # is still raised (which is correct). It doesn't make sense to continue with the existing
+                # connection, so we close it to avoid its reuse by the `cursor` method.
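+                # Only an OperationalError makes us drop the cached connection (so the next query
+                # reconnects); any other error on a live connection is re-raised unchanged.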
+ if isinstance(e, psycopg2.OperationalError): + self._connection.close() + else: + raise e + if self.state == 'restarting': + raise RetryFailedError('cluster is being restarted') + raise PostgresConnectionException('connection problems') + + def query(self, sql, *args, **kwargs): + if not kwargs.get('retry', True): + return self._query(sql, *args) + try: + return self.retry(self._query, sql, *args) + except RetryFailedError as e: + raise PostgresConnectionException(str(e)) + + def pg_control_exists(self): + return os.path.isfile(self._pg_control) + + def data_directory_empty(self): + if self.pg_control_exists(): + return False + return data_directory_is_empty(self._data_dir) + + def replica_method_options(self, method): + return deepcopy(self.config.get(method, {})) + + def replica_method_can_work_without_replication_connection(self, method): + return method != 'basebackup' and self.replica_method_options(method).get('no_master') + + def can_create_replica_without_replication_connection(self, replica_methods=None): + """ go through the replication methods to see if there are ones + that does not require a working replication connection. + """ + if replica_methods is None: + replica_methods = self.create_replica_methods + return any(self.replica_method_can_work_without_replication_connection(m) for m in replica_methods) + + def reset_cluster_info_state(self): + self._cluster_info_state = {} + + def _cluster_info_state_get(self, name): + if not self._cluster_info_state: + try: + result = self._is_leader_retry(self._query, self.cluster_info_query).fetchone() + self._cluster_info_state = dict(zip(['timeline', 'wal_position', 'replayed_location', + 'received_location', 'replay_paused', 'pg_control_timeline', + 'received_tli', 'slot_name', 'conninfo'], result)) + except RetryFailedError as e: # SELECT failed two times + self._cluster_info_state = {'error': str(e)} + if not self.is_starting() and self.pg_isready() == STATE_REJECT: + self.set_state('starting') + + if 'error' in self._cluster_info_state: + raise PostgresConnectionException(self._cluster_info_state['error']) + + return self._cluster_info_state.get(name) + + def replayed_location(self): + return self._cluster_info_state_get('replayed_location') + + def received_location(self): + return self._cluster_info_state_get('received_location') + + def primary_slot_name(self): + return self._cluster_info_state_get('slot_name') + + def primary_conninfo(self): + return self._cluster_info_state_get('conninfo') + + def received_timeline(self): + return self._cluster_info_state_get('received_tli') + + def is_leader(self): + return bool(self._cluster_info_state_get('timeline')) + + def pg_control_timeline(self): + try: + return int(self.controldata().get("Latest checkpoint's TimeLineID")) + except (TypeError, ValueError): + logger.exception('Failed to parse timeline from pg_controldata output') + + def latest_checkpoint_location(self): + """Returns checkpoint location for the cleanly shut down primary""" + + data = self.controldata() + lsn = data.get('Latest checkpoint location') + if data.get('Database cluster state') == 'shut down' and lsn: + try: + return str(parse_lsn(lsn)) + except (IndexError, ValueError) as e: + logger.error('Exception when parsing lsn %s: %r', lsn, e) + + def is_running(self): + """Returns PostmasterProcess if one is running on the data directory or None. 
If most recently seen process + is running updates the cached process based on pid file.""" + if self._postmaster_proc: + if self._postmaster_proc.is_running(): + return self._postmaster_proc + self._postmaster_proc = None + + # we noticed that postgres was restarted, force syncing of replication + self.slots_handler.schedule() + + self._postmaster_proc = PostmasterProcess.from_pidfile(self._data_dir) + return self._postmaster_proc + + @property + def cb_called(self): + return self.__cb_called + + def call_nowait(self, cb_name): + """ pick a callback command and call it without waiting for it to finish """ + if self.bootstrapping: + return + if cb_name in (ACTION_ON_START, ACTION_ON_STOP, ACTION_ON_RESTART, ACTION_ON_ROLE_CHANGE): + self.__cb_called = True + + if self.callback and cb_name in self.callback: + cmd = self.callback[cb_name] + try: + cmd = shlex.split(self.callback[cb_name]) + [cb_name, self.role, self.scope] + self._callback_executor.call(cmd) + except Exception: + logger.exception('callback %s %s %s %s failed', cmd, cb_name, self.role, self.scope) + + @property + def role(self): + with self._role_lock: + return self._role + + def set_role(self, value): + with self._role_lock: + self._role = value + + @property + def state(self): + with self._state_lock: + return self._state + + def set_state(self, value): + with self._state_lock: + self._state = value + self._state_entry_timestamp = time.time() + + def time_in_state(self): + return time.time() - self._state_entry_timestamp + + def is_starting(self): + return self.state == 'starting' + + def wait_for_port_open(self, postmaster, timeout): + """Waits until PostgreSQL opens ports.""" + for _ in polling_loop(timeout): + if self.cancellable.is_cancelled: + return False + + if not postmaster.is_running(): + logger.error('postmaster is not running') + self.set_state('start failed') + return False + + isready = self.pg_isready() + if isready != STATE_NO_RESPONSE: + if isready not in [STATE_REJECT, STATE_RUNNING]: + logger.warning("Can't determine PostgreSQL startup status, assuming running") + return True + + logger.warning("Timed out waiting for PostgreSQL to start") + return False + + def start(self, timeout=None, task=None, block_callbacks=False, role=None, is_standby=False): + """Start PostgreSQL + + Waits for postmaster to open ports or terminate so pg_isready can be used to check startup completion + or failure. + + :returns: True if start was initiated and postmaster ports are open, False if start failed""" + # make sure we close all connections established against + # the former node, otherwise, we might get a stalled one + # after kill -9, which would report incorrect data to + # patroni. 
+ self._connection.close() + + if self.is_running(): + logger.error('Cannot start PostgreSQL because one is already running.') + self.set_state('starting') + return True + + if not block_callbacks: + self.__cb_pending = ACTION_ON_START + + self.set_role(role or self.get_postgres_role_from_data_directory()) + + self.set_state('starting') + self._pending_restart = False + + try: + if not self._major_version: + self.configure_server_parameters() + configuration = self.config.effective_configuration + except Exception: + return None + + self.config.check_directories() + self.config.write_postgresql_conf(configuration) + self.config.resolve_connection_addresses() + self.config.replace_pg_hba() + self.config.replace_pg_ident() + + options = [] + if is_standby: + options += ['-M', 'standby'] + + if self.cancellable.is_cancelled: + return False + + with task or null_context(): + if task and task.is_cancelled: + logger.info("openGauss start cancelled.") + return False + + self._postmaster_proc = PostmasterProcess.start(self.pgcommand('gaussdb'), + self._data_dir, + self.config.postgresql_conf, + options) + + if task: + task.complete(self._postmaster_proc) + + start_timeout = timeout + if not start_timeout: + try: + start_timeout = float(self.config.get('pg_ctl_timeout', 60)) + except ValueError: + start_timeout = 60 + + # We want postmaster to open ports before we continue + if not self._postmaster_proc or not self.wait_for_port_open(self._postmaster_proc, start_timeout): + return False + + ret = self.wait_for_startup(start_timeout) + if ret is not None: + return ret + elif timeout is not None: + return False + else: + return None + + def checkpoint(self, connect_kwargs=None, timeout=None): + check_not_is_in_recovery = connect_kwargs is not None + connect_kwargs = connect_kwargs or self.config.local_connect_kwargs + for p in ['connect_timeout', 'options']: + connect_kwargs.pop(p, None) + if timeout: + connect_kwargs['connect_timeout'] = timeout + try: + with get_connection_cursor(**connect_kwargs) as cur: + cur.execute("SET statement_timeout = 0") + if check_not_is_in_recovery: + cur.execute('SELECT pg_catalog.pg_is_in_recovery()') + if cur.fetchone()[0]: + return 'is_in_recovery=true' + return cur.execute('CHECKPOINT') + except psycopg2.Error: + logger.exception('Exception during CHECKPOINT') + return 'not accessible or not healty' + + def rebuild(self): + """ + standby is abnormal, rebuild + """ + if not self.is_running(): + self.pg_ctl('start', '-M', 'standby') + rebuild_thread = threading.Thread(target=self.pg_ctl, args=('build', '-M', 'standby')) + rebuild_thread.start() + rebuild_thread.join() + + def stop(self, mode='fast', block_callbacks=False, checkpoint=None, on_safepoint=None, stop_timeout=None): + """Stop PostgreSQL + + Supports a callback when a safepoint is reached. A safepoint is when no user backend can return a successful + commit to users. Currently this means we wait for user backends to close. But in the future alternate mechanisms + could be added. + + :param on_safepoint: This callback is called when no user backends are running. 
+ """ + if checkpoint is None: + checkpoint = False if mode == 'immediate' else True + + success, pg_signaled = self._do_stop(mode, block_callbacks, checkpoint, on_safepoint, stop_timeout) + if success: + # block_callbacks is used during restart to avoid + # running start/stop callbacks in addition to restart ones + if not block_callbacks: + self.set_state('stopped') + if pg_signaled: + self.call_nowait(ACTION_ON_STOP) + else: + logger.warning('pg_ctl stop failed') + self.set_state('stop failed') + return success + + def _do_stop(self, mode, block_callbacks, checkpoint, on_safepoint, stop_timeout): + postmaster = self.is_running() + if not postmaster: + if on_safepoint: + on_safepoint() + return True, False + + if checkpoint and not self.is_starting(): + self.checkpoint(timeout=stop_timeout) + + if not block_callbacks: + self.set_state('stopping') + + # Send signal to postmaster to stop + success = postmaster.signal_stop(mode, self.pgcommand('gs_ctl')) + if success is not None: + if success and on_safepoint: + on_safepoint() + return success, True + + # We can skip safepoint detection if we don't have a callback + if on_safepoint: + # Wait for our connection to terminate so we can be sure that no new connections are being initiated + self._wait_for_connection_close(postmaster) + postmaster.wait_for_user_backends_to_close() + on_safepoint() + + try: + postmaster.wait(timeout=stop_timeout) + except TimeoutExpired: + logger.warning("Timeout during postmaster stop, aborting Postgres.") + if not self.terminate_postmaster(postmaster, mode, stop_timeout): + postmaster.wait() + + return True, True + + def terminate_postmaster(self, postmaster, mode, stop_timeout): + if mode in ['fast', 'smart']: + try: + success = postmaster.signal_stop('immediate', self.pgcommand('gs_ctl')) + if success: + return True + postmaster.wait(timeout=stop_timeout) + return True + except TimeoutExpired: + pass + logger.warning("Sending SIGKILL to Postmaster and its children") + return postmaster.signal_kill() + + def terminate_starting_postmaster(self, postmaster): + """Terminates a postmaster that has not yet opened ports or possibly even written a pid file. Blocks + until the process goes away.""" + postmaster.signal_stop('immediate', self.pgcommand('gs_ctl')) + postmaster.wait() + + def _wait_for_connection_close(self, postmaster): + try: + with self.connection().cursor() as cur: + while postmaster.is_running(): # Need a timeout here? + cur.execute("SELECT 1") + time.sleep(STOP_POLLING_INTERVAL) + except psycopg2.Error: + pass + + def reload(self, block_callbacks=False): + ret = self.pg_ctl('reload') + if ret and not block_callbacks: + self.call_nowait(ACTION_ON_RELOAD) + return ret + + def check_for_startup(self): + """Checks PostgreSQL status and returns if PostgreSQL is in the middle of startup.""" + return self.is_starting() and not self.check_startup_state_changed() + + def check_startup_state_changed(self): + """Checks if PostgreSQL has completed starting up or failed or still starting. + + Should only be called when state == 'starting' + + :returns: True if state was changed from 'starting' + """ + ready = self.pg_isready() + + if ready == STATE_REJECT: + return False + elif ready == STATE_NO_RESPONSE: + ret = not self.is_running() + if ret: + self.set_state('start failed') + self.slots_handler.schedule(False) # TODO: can remove this? + self.config.save_configuration_files(True) # TODO: maybe remove this? + return ret + else: + if ready != STATE_RUNNING: + # Bad configuration or unexpected OS error. 
No idea of PostgreSQL status. + # Let the main loop of run cycle clean up the mess. + logger.warning("%s status returned from pg_isready", + "Unknown" if ready == STATE_UNKNOWN else "Invalid") + self.set_state('running') + self.slots_handler.schedule() + self.config.save_configuration_files(True) + # TODO: __cb_pending can be None here after PostgreSQL restarts on its own. Do we want to call the callback? + # Previously we didn't even notice. + action = self.__cb_pending or ACTION_ON_START + self.call_nowait(action) + self.__cb_pending = None + + return True + + def wait_for_startup(self, timeout=None): + """Waits for PostgreSQL startup to complete or fail. + + :returns: True if start was successful, False otherwise""" + if not self.is_starting(): + # Should not happen + logger.warning("wait_for_startup() called when not in starting state") + + while not self.check_startup_state_changed(): + if self.cancellable.is_cancelled or timeout and self.time_in_state() > timeout: + return None + time.sleep(1) + + return self.state == 'running' + + def restart(self, timeout=None, task=None, block_callbacks=False, role=None, is_standby=False): + """Restarts PostgreSQL. + + When timeout parameter is set the call will block either until PostgreSQL has started, failed to start or + timeout arrives. + + :returns: True when restart was successful and timeout did not expire when waiting. + """ + self.set_state('restarting') + if not block_callbacks: + self.__cb_pending = ACTION_ON_RESTART + ret = 0 + if is_standby: + ret = self.stop(block_callbacks=True) and self.start(timeout, task, True, role, is_standby) + else: + ret = self.stop(block_callbacks=True) and self.start(timeout, task, True, role) + if not ret and not self.is_starting(): + self.set_state('restart failed ({0})'.format(self.state)) + return ret + + def is_healthy(self): + if not self.is_running(): + logger.warning('openGauss is not running.') + return False + return True + + def get_guc_value(self, name): + cmd = [self.pgcommand('gaussdb'), '-D', self._data_dir, '-C', name] + try: + data = subprocess.check_output(cmd) + if data: + return data.decode('utf-8').strip() + except Exception as e: + logger.error('Failed to execute %s: %r', cmd, e) + + def controldata(self): + """ return the contents of pg_controldata, or non-True value if pg_controldata call failed """ + # Don't try to call pg_controldata during backup restore + if self._version_file_exists() and self.state != 'creating replica': + try: + env = os.environ.copy() + env.update(LANG='C', LC_ALL='C') + data = subprocess.check_output([self.pgcommand('pg_controldata'), self._data_dir], env=env) + if data: + data = filter(lambda e: ':' in e, data.decode('utf-8').splitlines()) + # pg_controldata output depends on major version. 
Some of parameters are prefixed by 'Current ' + return {k.replace('Current ', '', 1): v.strip() for k, v in map(lambda e: e.split(':', 1), data)} + except subprocess.CalledProcessError: + logger.exception("Error when calling pg_controldata") + return {} + + @contextmanager + def get_replication_connection_cursor(self, host='localhost', port=5432, **kwargs): + conn_kwargs = self.config.replication.copy() + conn_kwargs.update(host=host, port=int(port) if port else None, user=conn_kwargs.pop('username'), + connect_timeout=3, replication=1, options='-c statement_timeout=2000') + with get_connection_cursor(**conn_kwargs) as cur: + yield cur + + def get_replica_timeline(self): + try: + with self.get_replication_connection_cursor(**self.config.local_replication_address) as cur: + cur.execute('IDENTIFY_SYSTEM') + return cur.fetchone()[1] + except Exception: + logger.exception('Can not fetch local timeline and lsn from replication connection') + + def replica_cached_timeline(self, master_timeline): + if not self._cached_replica_timeline or not master_timeline or self._cached_replica_timeline != master_timeline: + self._cached_replica_timeline = self.get_replica_timeline() + return self._cached_replica_timeline + + def get_master_timeline(self): + return self._cluster_info_state_get('timeline') + + def get_history(self, timeline): + history_path = os.path.join(self.wal_dir, '{0:08X}.history'.format(timeline)) + history_mtime = mtime(history_path) + if history_mtime: + try: + with open(history_path, 'r') as f: + history = f.read() + history = list(parse_history(history)) + if history[-1][0] == timeline - 1: + history_mtime = datetime.fromtimestamp(history_mtime).replace(tzinfo=tz.tzlocal()) + history[-1].append(history_mtime.isoformat()) + return history + except Exception: + logger.exception('Failed to read and parse %s', (history_path,)) + + def follow(self, member, role='replica', timeout=None, do_reload=False, is_standby=False, rebuild=False): + # When we demoting the master or standby_leader to replica or promoting replica to a standby_leader + # and we know for sure that postgres was already running before, we will only execute on_role_change + # callback and prevent execution of on_restart/on_start callback. + # If the role remains the same (replica or standby_leader), we will execute on_start or on_restart + change_role = self.cb_called and (self.role in ('master', 'demoted') or + not {'standby_leader', 'replica'} - {self.role, role}) + if change_role: + self.__cb_pending = ACTION_NOOP + + if self.is_running(): + if is_standby and rebuild: + self.rebuild() + elif do_reload: + self.config.write_postgresql_conf() + if self.reload(block_callbacks=change_role) and change_role: + self.set_role(role) + else: + if is_standby: + self.restart(block_callbacks=change_role, role=role, is_standby=True) + else: + self.restart(block_callbacks=change_role, role=role) + else: + if is_standby: + self.start(timeout=timeout, block_callbacks=change_role, role=role, is_standby=True) + else: + self.start(timeout=timeout, block_callbacks=change_role, role=role) + + if change_role: + # TODO: postpone this until start completes, or maybe do even earlier + self.call_nowait(ACTION_ON_ROLE_CHANGE) + return True + + def _wait_promote(self, wait_seconds): + for _ in polling_loop(wait_seconds): + data = self.controldata() + if data.get('Database cluster state') == 'in production': + return True + + def _pre_promote(self): + """ + Runs a fencing script after the leader lock is acquired but before the replica is promoted. 
+ If the script exits with a non-zero code, promotion does not happen and the leader key is removed from DCS. + """ + + cmd = self.config.get('pre_promote') + if not cmd: + return True + + ret = self.cancellable.call(shlex.split(cmd)) + if ret is not None: + logger.info('pre_promote script `%s` exited with %s', cmd, ret) + return ret == 0 + + def promote(self, wait_seconds, task, on_success=None, access_is_restricted=False): + if self.role == 'master': + return True + + ret = self._pre_promote() + with task: + if task.is_cancelled: + return False + task.complete(ret) + + if ret is False: + return False + + if self.cancellable.is_cancelled: + logger.info("PostgreSQL promote cancelled.") + return False + + ret = self.pg_ctl('failover') + if ret: + self.set_role('master') + if on_success is not None: + on_success() + if not access_is_restricted: + self.call_nowait(ACTION_ON_ROLE_CHANGE) + ret = self._wait_promote(wait_seconds) + return ret + + @staticmethod + def _wal_position(is_leader, wal_position, received_location, replayed_location): + return wal_position if is_leader else max(received_location or 0, replayed_location or 0) + + def timeline_wal_position(self): + # This method could be called from different threads (simultaneously with some other `_query` calls). + # If it is called not from main thread we will create a new cursor to execute statement. + if current_thread().ident == self.__thread_ident: + timeline = self._cluster_info_state_get('timeline') + wal_position = self._cluster_info_state_get('wal_position') + replayed_location = self.replayed_location() + received_location = self.received_location() + pg_control_timeline = self._cluster_info_state_get('pg_control_timeline') + else: + with self.connection().cursor() as cursor: + cursor.execute(self.cluster_info_query) + (timeline, wal_position, replayed_location, + received_location, _, pg_control_timeline) = cursor.fetchone()[:6] + + wal_position = self._wal_position(timeline, wal_position, received_location, replayed_location) + return (timeline, wal_position, pg_control_timeline) + + def postmaster_start_time(self): + try: + query = "SELECT " + self.POSTMASTER_START_TIME + if current_thread().ident == self.__thread_ident: + return self.query(query).fetchone()[0] + with self.connection().cursor() as cursor: + cursor.execute(query) + return cursor.fetchone()[0] + except psycopg2.Error: + return None + + def last_operation(self): + return str(self._wal_position(self.is_leader(), self._cluster_info_state_get('wal_position'), + self.received_location(), self.replayed_location())) + + def configure_server_parameters(self): + self._major_version = self.get_major_version() + self.config.setup_server_parameters() + return True + + def pg_wal_realpath(self): + """Returns a dict containing the symlink (key) and target (value) for the wal directory""" + links = {} + for pg_wal_dir in ('pg_xlog', 'pg_wal'): + pg_wal_path = os.path.join(self._data_dir, pg_wal_dir) + if os.path.exists(pg_wal_path) and os.path.islink(pg_wal_path): + pg_wal_realpath = os.path.realpath(pg_wal_path) + links[pg_wal_path] = pg_wal_realpath + return links + + def pg_tblspc_realpaths(self): + """Returns a dict containing the symlink (key) and target (values) for the tablespaces""" + links = {} + pg_tblsp_dir = os.path.join(self._data_dir, 'pg_tblspc') + if os.path.exists(pg_tblsp_dir): + for tsdn in os.listdir(pg_tblsp_dir): + pg_tsp_path = os.path.join(pg_tblsp_dir, tsdn) + if parse_int(tsdn) and os.path.islink(pg_tsp_path): + pg_tsp_rpath = 
os.path.realpath(pg_tsp_path) + links[pg_tsp_path] = pg_tsp_rpath + return links + + def move_data_directory(self): + if os.path.isdir(self._data_dir) and not self.is_running(): + try: + postfix = time.strftime('%Y-%m-%d-%H-%M-%S') + + # let's see if the wal directory is a symlink, in this case we + # should move the target + for (source, pg_wal_realpath) in self.pg_wal_realpath().items(): + logger.info('renaming WAL directory and updating symlink: %s', pg_wal_realpath) + new_name = '{0}_{1}'.format(pg_wal_realpath, postfix) + os.rename(pg_wal_realpath, new_name) + os.unlink(source) + os.symlink(new_name, source) + + # Move user defined tablespace directory + for (source, pg_tsp_rpath) in self.pg_tblspc_realpaths().items(): + logger.info('renaming user defined tablespace directory and updating symlink: %s', pg_tsp_rpath) + new_name = '{0}_{1}'.format(pg_tsp_rpath, postfix) + os.rename(pg_tsp_rpath, new_name) + os.unlink(source) + os.symlink(new_name, source) + + new_name = '{0}_{1}'.format(self._data_dir, postfix) + logger.info('renaming data directory to %s', new_name) + os.rename(self._data_dir, new_name) + except OSError: + logger.exception("Could not rename data directory %s", self._data_dir) + + def remove_data_directory(self): + self.set_role('uninitialized') + logger.info('Removing data directory: %s', self._data_dir) + try: + if os.path.islink(self._data_dir): + os.unlink(self._data_dir) + elif not os.path.exists(self._data_dir): + return + elif os.path.isfile(self._data_dir): + os.remove(self._data_dir) + elif os.path.isdir(self._data_dir): + + # let's see if wal directory is a symlink, in this case we + # should clean the target + for pg_wal_realpath in self.pg_wal_realpath().values(): + logger.info('Removing WAL directory: %s', pg_wal_realpath) + shutil.rmtree(pg_wal_realpath) + + # Remove user defined tablespace directories + for pg_tsp_rpath in self.pg_tblspc_realpaths().values(): + logger.info('Removing user defined tablespace directory: %s', pg_tsp_rpath) + shutil.rmtree(pg_tsp_rpath, ignore_errors=True) + + shutil.rmtree(self._data_dir) + except (IOError, OSError): + logger.exception('Could not remove data directory %s', self._data_dir) + self.move_data_directory() + + def _get_synchronous_commit_param(self): + return self.query("SHOW synchronous_commit").fetchone()[0] + + def pick_synchronous_standby(self, cluster, sync_node_count=1, sync_node_maxlag=-1): + """Finds the best candidate to be the synchronous standby. + + Current synchronous standby is always preferred, unless it has disconnected or does not want to be a + synchronous standby any longer. + Parameter sync_node_maxlag(maximum_lag_on_syncnode) would help swapping unhealthy sync replica incase + if it stops responding (or hung). Please set the value high enough so it won't unncessarily swap sync + standbys during high loads. Any less or equal of 0 value keep the behavior backward compatible and + will not swap. Please note that it will not also swap sync standbys in case where all replicas are hung. + + :returns tuple of candidates list and synchronous standby list. + """ + if self._major_version < 90600: + sync_node_count = 1 + members = {m.name.lower(): m for m in cluster.members} + candidates = [] + sync_nodes = [] + replica_list = [] + # Pick candidates based on who has higher replay/remote_write/flush lsn. 
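+        # Illustrative note (assumed value, not from the original source): with maximum_lag_on_syncnode
+        # set to e.g. 1048576 (1 MB), a standby whose LSN trails the most advanced replica by more than
+        # 1 MB is skipped as a synchronous candidate; any value <= 0 leaves this filtering disabled.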
+ sync_commit_par = self._get_synchronous_commit_param() + sort_col = {'remote_apply': 'replay', 'remote_write': 'write'}.get(sync_commit_par, 'flush') + # pg_stat_replication.sync_state has 4 possible states - async, potential, quorum, sync. + # Sort clause "ORDER BY sync_state DESC" is to get the result in required order and to keep + # the result consistent in case if a synchronous standby member is slowed down OR async node + # receiving changes faster than the sync member (very rare but possible). Such cases would + # trigger sync standby member swapping frequently and the sort on sync_state desc should + # help in keeping the query result consistent. + for app_name, sync_state, replica_lsn in self.query( + "SELECT pg_catalog.lower(application_name), sync_state, pg_{2}_{1}_diff({0}_{1}, '0/0')::bigint" + " FROM pg_catalog.pg_stat_replication" + " WHERE state = 'streaming'" + " ORDER BY sync_state DESC, {0}_{1} DESC".format(sort_col, self.lsn_name, self.wal_name)): + member = members.get(app_name) + if member and not member.tags.get('nosync', False): + replica_list.append((member.name, sync_state, replica_lsn)) + + max_lsn = max(replica_list, key=lambda x: x[2])[2] if len(replica_list) > 1 else int(str(self.last_operation())) + + for app_name, sync_state, replica_lsn in replica_list: + if sync_node_maxlag <= 0 or max_lsn - replica_lsn <= sync_node_maxlag: + candidates.append(app_name) + if sync_state == 'sync': + sync_nodes.append(app_name) + if len(candidates) >= sync_node_count: + break + + return candidates, sync_nodes + + def schedule_sanity_checks_after_pause(self): + """ + After coming out of pause we have to: + 1. configure server parameters if necessary + 2. sync replication slots, because it might happen that slots were removed + 3. get new 'Database system identifier' to make sure that it wasn't changed + """ + if not self._major_version: + self.configure_server_parameters() + self.slots_handler.schedule() + self._sysid = None diff --git a/patroni-for-openGauss/postgresql/bootstrap.py b/patroni-for-openGauss/postgresql/bootstrap.py new file mode 100644 index 0000000000000000000000000000000000000000..ca72fc59f91e407e66de7722db8916a6cfdae279 --- /dev/null +++ b/patroni-for-openGauss/postgresql/bootstrap.py @@ -0,0 +1,381 @@ +import logging +import os +import shlex +import tempfile +import time + +from patroni.dcs import RemoteMember +from patroni.utils import deep_compare +from six import string_types + +logger = logging.getLogger(__name__) + + +class Bootstrap(object): + + def __init__(self, postgresql): + self._postgresql = postgresql + self._running_custom_bootstrap = False + + @property + def running_custom_bootstrap(self): + return self._running_custom_bootstrap + + @property + def keep_existing_recovery_conf(self): + return self._running_custom_bootstrap and self._keep_existing_recovery_conf + + @staticmethod + def process_user_options(tool, options, not_allowed_options, error_handler): + user_options = [] + + def option_is_allowed(name): + ret = name not in not_allowed_options + if not ret: + error_handler('{0} option for {1} is not allowed'.format(name, tool)) + return ret + + if isinstance(options, dict): + for k, v in options.items(): + if k and v: + user_options.append('--{0}={1}'.format(k, v)) + elif isinstance(options, list): + for opt in options: + if isinstance(opt, string_types) and option_is_allowed(opt): + user_options.append('--{0}'.format(opt)) + elif isinstance(opt, dict): + keys = list(opt.keys()) + if len(keys) != 1 or not isinstance(opt[keys[0]], 
string_types) or not option_is_allowed(keys[0]): + error_handler('Error when parsing {0} key-value option {1}: only one key-value is allowed' + ' and value should be a string'.format(tool, opt[keys[0]])) + user_options.append('--{0}={1}'.format(keys[0], opt[keys[0]])) + else: + error_handler('Error when parsing {0} option {1}: value should be string value' + ' or a single key-value pair'.format(tool, opt)) + else: + error_handler('{0} options must be list ot dict'.format(tool)) + return user_options + + def _initdb(self, config): + self._postgresql.set_state('initalizing new cluster') + not_allowed_options = ('pgdata', 'nosync', 'pwfile', 'sync-only', 'version') + + def error_handler(e): + raise Exception(e) + + options = self.process_user_options('initdb', config or [], not_allowed_options, error_handler) + pwfile = None + + if self._postgresql.config.superuser: + if 'username' in self._postgresql.config.superuser: + options.append('--username={0}'.format(self._postgresql.config.superuser['username'])) + if 'password' in self._postgresql.config.superuser: + (fd, pwfile) = tempfile.mkstemp() + os.write(fd, self._postgresql.config.superuser['password'].encode('utf-8')) + os.close(fd) + options.append('--pwfile={0}'.format(pwfile)) + options = ['-o', ' '.join(options)] if options else [] + + ret = self._postgresql.pg_ctl('initdb', *options) + if pwfile: + os.remove(pwfile) + if ret: + self._postgresql.configure_server_parameters() + else: + self._postgresql.set_state('initdb failed') + return ret + + def _post_restore(self): + self._postgresql.config.restore_configuration_files() + self._postgresql.configure_server_parameters() + + # make sure there is no trigger file or postgres will be automatically promoted + trigger_file = 'promote_trigger_file' if self._postgresql.major_version >= 120000 else 'trigger_file' + trigger_file = self._postgresql.config.get('recovery_conf', {}).get(trigger_file) or 'promote' + trigger_file = os.path.abspath(os.path.join(self._postgresql.data_dir, trigger_file)) + if os.path.exists(trigger_file): + os.unlink(trigger_file) + + def _custom_bootstrap(self, config): + self._postgresql.set_state('running custom bootstrap script') + params = [] if config.get('no_params') else ['--scope=' + self._postgresql.scope, + '--datadir=' + self._postgresql.data_dir] + try: + logger.info('Running custom bootstrap script: %s', config['command']) + if self._postgresql.cancellable.call(shlex.split(config['command']) + params) != 0: + self._postgresql.set_state('custom bootstrap failed') + return False + except Exception: + logger.exception('Exception during custom bootstrap') + return False + self._post_restore() + + if 'recovery_conf' in config: + self._postgresql.config.write_recovery_conf(config['recovery_conf']) + elif not self.keep_existing_recovery_conf: + self._postgresql.config.remove_recovery_conf() + return True + + def call_post_bootstrap(self, config): + """ + runs a script after initdb or custom bootstrap script is called and waits until completion. + """ + cmd = config.get('post_bootstrap') or config.get('post_init') + if cmd: + r = self._postgresql.config.local_connect_kwargs + connstring = self._postgresql.config.format_dsn(r, True) + if 'host' not in r: + # https://www.postgresql.org/docs/current/static/libpq-pgpass.html + # A host name of localhost matches both TCP (host name localhost) and Unix domain socket + # (pghost empty or the default socket directory) connections coming from the local machine. 
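+                # Illustrative example (assumed credentials): for {'host': 'localhost', 'port': 5432,
+                # 'user': 'postgres', 'password': 'secret'} the pgpass file written below would contain
+                # the line "localhost:5432:*:postgres:secret".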
+ r['host'] = 'localhost' # set it to localhost to write into pgpass + + env = self._postgresql.config.write_pgpass(r) + env['PGOPTIONS'] = '-c synchronous_commit=local' + + try: + ret = self._postgresql.cancellable.call(shlex.split(cmd) + [connstring], env=env) + except OSError: + logger.error('post_init script %s failed', cmd) + return False + if ret != 0: + logger.error('post_init script %s returned non-zero code %d', cmd, ret) + return False + return True + + def create_replica(self, clone_member): + """ + create the replica according to the replica_method + defined by the user. this is a list, so we need to + loop through all methods the user supplies + """ + + self._postgresql.set_state('creating replica') + self._postgresql.schedule_sanity_checks_after_pause() + + is_remote_master = isinstance(clone_member, RemoteMember) + + # get list of replica methods either from clone member or from + # the config. If there is no configuration key, or no value is + # specified, use basebackup + replica_methods = (clone_member.create_replica_methods if is_remote_master + else self._postgresql.create_replica_methods) or ['basebackup'] + + if clone_member and clone_member.conn_url: + r = clone_member.conn_kwargs(self._postgresql.config.replication) + # add the credentials to connect to the replica origin to pgpass. + env = self._postgresql.config.write_pgpass(r) + connstring = self._postgresql.config.format_dsn(r, True) + else: + connstring = '' + env = os.environ.copy() + # if we don't have any source, leave only replica methods that work without it + replica_methods = [r for r in replica_methods + if self._postgresql.replica_method_can_work_without_replication_connection(r)] + + # go through them in priority order + ret = 1 + for replica_method in replica_methods: + if self._postgresql.cancellable.is_cancelled: + break + + method_config = self._postgresql.replica_method_options(replica_method) + + # if the method is basebackup, then use the built-in + if replica_method == "basebackup": + ret = self.basebackup(connstring, env, method_config) + if ret == 0: + logger.info("replica has been created using basebackup") + # if basebackup succeeds, exit with success + break + else: + if not self._postgresql.data_directory_empty(): + if method_config.get('keep_data', False): + logger.info('Leaving data directory uncleaned') + else: + self._postgresql.remove_data_directory() + + cmd = replica_method + # user-defined method; check for configuration + # not required, actually + if method_config: + # look to see if the user has supplied a full command path + # if not, use the method name as the command + cmd = method_config.pop('command', cmd) + + # add the default parameters + if not method_config.get('no_params', False): + method_config.update({"scope": self._postgresql.scope, + "role": "replica", + "datadir": self._postgresql.data_dir, + "connstring": connstring}) + else: + for param in ('no_params', 'no_master', 'keep_data'): + method_config.pop(param, None) + params = ["--{0}={1}".format(arg, val) for arg, val in method_config.items()] + try: + # call script with the full set of parameters + ret = self._postgresql.cancellable.call(shlex.split(cmd) + params, env=env) + # if we succeeded, stop + if ret == 0: + logger.info('replica has been created using %s', replica_method) + break + else: + logger.error('Error creating replica using method %s: %s exited with code=%s', + replica_method, cmd, ret) + except Exception: + logger.exception('Error creating replica using method %s', replica_method) + ret = 1 + + 
self._postgresql.set_state('stopped') + return ret + + def basebackup(self, conn_url, env, options): + # creates a replica data dir using pg_basebackup. + # this is the default, built-in create_replica_methods + # tries twice, then returns failure (as 1) + # uses "stream" as the xlog-method to avoid sync issues + # supports additional user-supplied options, those are not validated + maxfailures = 2 + ret = 1 + not_allowed_options = ('pgdata', 'format', 'wal-method', 'xlog-method', 'gzip', + 'version', 'compress', 'dbname', 'host', 'port', 'username', 'password') + user_options = self.process_user_options('basebackup', options, not_allowed_options, logger.error) + + for bbfailures in range(0, maxfailures): + if self._postgresql.cancellable.is_cancelled: + break + if not self._postgresql.data_directory_empty(): + self._postgresql.remove_data_directory() + try: + ret = self._postgresql.cancellable.call([self._postgresql.pgcommand('pg_basebackup'), + '--pgdata=' + self._postgresql.data_dir, '-X', 'stream', + '--dbname=' + conn_url] + user_options, env=env) + if ret == 0: + break + else: + logger.error('Error when fetching backup: pg_basebackup exited with code=%s', ret) + + except Exception as e: + logger.error('Error when fetching backup with pg_basebackup: %s', e) + + if bbfailures < maxfailures - 1: + logger.warning('Trying again in 5 seconds') + time.sleep(5) + + return ret + + def clone(self, clone_member): + """ + - initialize the replica from an existing member (master or replica) + - initialize the replica using the replica creation method that + works without the replication connection (i.e. restore from on-disk + base backup) + """ + + ret = self.create_replica(clone_member) == 0 + if ret: + self._post_restore() + return ret + + def bootstrap(self, config): + """ Initialize a new node from scratch and start it. 
""" + pg_hba = config.get('pg_hba', []) + method = config.get('method') or 'initdb' + if method != 'initdb' and method in config and 'command' in config[method]: + self._keep_existing_recovery_conf = config[method].get('keep_existing_recovery_conf') + self._running_custom_bootstrap = True + do_initialize = self._custom_bootstrap + else: + method = 'initdb' + do_initialize = self._initdb + return do_initialize(config.get(method)) and self._postgresql.config.append_pg_hba(pg_hba) \ + and self._postgresql.config.save_configuration_files() and self._postgresql.start() + + def create_or_update_role(self, name, password, options): + options = list(map(str.upper, options)) + if 'NOLOGIN' not in options and 'LOGIN' not in options: + options.append('LOGIN') + + params = [name] + if password: + options.extend(['PASSWORD', '%s']) + params.extend([password, password]) + + sql = """DO $$ +BEGIN + SET local synchronous_commit = 'local'; + PERFORM * FROM pg_authid WHERE rolname = %s; + IF FOUND THEN + ALTER ROLE "{0}" WITH {1}; + ELSE + CREATE ROLE "{0}" WITH {1}; + END IF; +END;$$""".format(name, ' '.join(options)) + self._postgresql.query('SET log_statement TO none') + self._postgresql.query('SET log_min_duration_statement TO -1') + self._postgresql.query("SET log_min_error_statement TO 'log'") + try: + self._postgresql.query(sql, *params) + finally: + self._postgresql.query('RESET log_min_error_statement') + self._postgresql.query('RESET log_min_duration_statement') + self._postgresql.query('RESET log_statement') + + def post_bootstrap(self, config, task): + try: + postgresql = self._postgresql + superuser = postgresql.config.superuser + if 'username' in superuser and 'password' in superuser: + self.create_or_update_role(superuser['username'], superuser['password'], ['SUPERUSER']) + + task.complete(self.call_post_bootstrap(config)) + if task.result: + replication = postgresql.config.replication + self.create_or_update_role(replication['username'], replication.get('password'), ['REPLICATION']) + + rewind = postgresql.config.rewind_credentials + if not deep_compare(rewind, superuser): + self.create_or_update_role(rewind['username'], rewind.get('password'), []) + for f in ('pg_ls_dir(text, boolean, boolean)', 'pg_stat_file(text, boolean)', + 'pg_read_binary_file(text)', 'pg_read_binary_file(text, bigint, bigint, boolean)'): + sql = """DO $$ +BEGIN + SET local synchronous_commit = 'local'; + GRANT EXECUTE ON function pg_catalog.{0} TO "{1}"; +END;$$""".format(f, rewind['username']) + postgresql.query(sql) + + for name, value in (config.get('users') or {}).items(): + if all(name != a.get('username') for a in (superuser, replication, rewind)): + self.create_or_update_role(name, value.get('password'), value.get('options', [])) + + # We were doing a custom bootstrap instead of running initdb, therefore we opened trust + # access from certain addresses to be able to reach cluster and change password + if self._running_custom_bootstrap: + self._running_custom_bootstrap = False + # If we don't have custom configuration for pg_hba.conf we need to restore original file + if not postgresql.config.get('pg_hba'): + if os.path.exists(postgresql.config.pg_hba_conf): + os.unlink(postgresql.config.pg_hba_conf) + postgresql.config.restore_configuration_files() + postgresql.config.write_postgresql_conf() + postgresql.config.replace_pg_ident() + + # at this point there should be no recovery.conf + postgresql.config.remove_recovery_conf() + + if postgresql.config.hba_file: + postgresql.restart() + else: + 
postgresql.config.replace_pg_hba() + if postgresql.pending_restart: + postgresql.restart() + else: + postgresql.reload() + time.sleep(1) # give a time to postgres to "reload" configuration files + postgresql.connection().close() # close connection to reconnect with a new password + except Exception: + logger.exception('post_bootstrap') + task.complete(False) + return task.result diff --git a/patroni-for-openGauss/postgresql/callback_executor.py b/patroni-for-openGauss/postgresql/callback_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..40ebf5a10717698e45f230b2eeefca58e3ceb089 --- /dev/null +++ b/patroni-for-openGauss/postgresql/callback_executor.py @@ -0,0 +1,36 @@ +import logging + +from patroni.postgresql.cancellable import CancellableExecutor +from threading import Condition, Thread + +logger = logging.getLogger(__name__) + + +class CallbackExecutor(CancellableExecutor, Thread): + + def __init__(self): + CancellableExecutor.__init__(self) + Thread.__init__(self) + self.daemon = True + self._cmd = None + self._condition = Condition() + self.start() + + def call(self, cmd): + self._kill_process() + with self._condition: + self._cmd = cmd + self._condition.notify() + + def run(self): + while True: + with self._condition: + if self._cmd is None: + self._condition.wait() + cmd, self._cmd = self._cmd, None + + with self._lock: + if not self._start_process(cmd, close_fds=True): + continue + self._process.wait() + self._kill_children() diff --git a/patroni-for-openGauss/postgresql/cancellable.py b/patroni-for-openGauss/postgresql/cancellable.py new file mode 100644 index 0000000000000000000000000000000000000000..8e46fbfb2af7040b26a058dc9af757e75dcad9e1 --- /dev/null +++ b/patroni-for-openGauss/postgresql/cancellable.py @@ -0,0 +1,133 @@ +import logging +import psutil +import subprocess + +from patroni.exceptions import PostgresException +from patroni.utils import polling_loop +from threading import Lock + +logger = logging.getLogger(__name__) + + +class CancellableExecutor(object): + + """ + There must be only one such process so that AsyncExecutor can easily cancel it. 
+ """ + + def __init__(self): + self._process = None + self._process_cmd = None + self._process_children = [] + self._lock = Lock() + + def _start_process(self, cmd, *args, **kwargs): + """This method must be executed only when the `_lock` is acquired""" + + try: + self._process_children = [] + self._process_cmd = cmd + self._process = psutil.Popen(cmd, *args, **kwargs) + except Exception: + return logger.exception('Failed to execute %s', cmd) + return True + + def _kill_process(self): + with self._lock: + if self._process is not None and self._process.is_running() and not self._process_children: + try: + self._process.suspend() # Suspend the process before getting list of childrens + except psutil.Error as e: + logger.info('Failed to suspend the process: %s', e.msg) + + try: + self._process_children = self._process.children(recursive=True) + except psutil.Error: + pass + + try: + self._process.kill() + logger.warning('Killed %s because it was still running', self._process_cmd) + except psutil.NoSuchProcess: + pass + except psutil.AccessDenied as e: + logger.warning('Failed to kill the process: %s', e.msg) + + def _kill_children(self): + waitlist = [] + with self._lock: + for child in self._process_children: + try: + child.kill() + except psutil.NoSuchProcess: + continue + except psutil.AccessDenied as e: + logger.info('Failed to kill child process: %s', e.msg) + waitlist.append(child) + psutil.wait_procs(waitlist) + + +class CancellableSubprocess(CancellableExecutor): + + def __init__(self): + super(CancellableSubprocess, self).__init__() + self._is_cancelled = False + + def call(self, *args, **kwargs): + for s in ('stdin', 'stdout', 'stderr'): + kwargs.pop(s, None) + + communicate = kwargs.pop('communicate', None) + if isinstance(communicate, dict): + input_data = communicate.get('input') + if input_data: + if input_data[-1] != '\n': + input_data += '\n' + input_data = input_data.encode('utf-8') + kwargs['stdin'] = subprocess.PIPE + kwargs['stdout'] = subprocess.PIPE + kwargs['stderr'] = subprocess.PIPE + + try: + with self._lock: + if self._is_cancelled: + raise PostgresException('cancelled') + + self._is_cancelled = False + started = self._start_process(*args, **kwargs) + + if started: + if isinstance(communicate, dict): + communicate['stdout'], communicate['stderr'] = self._process.communicate(input_data) + return self._process.wait() + finally: + with self._lock: + self._process = None + self._kill_children() + + def reset_is_cancelled(self): + with self._lock: + self._is_cancelled = False + + @property + def is_cancelled(self): + with self._lock: + return self._is_cancelled + + def cancel(self, kill=False): + with self._lock: + self._is_cancelled = True + if self._process is None or not self._process.is_running(): + return + + logger.info('Terminating %s', self._process_cmd) + self._process.terminate() + + for _ in polling_loop(10): + with self._lock: + if self._process is None or not self._process.is_running(): + return + if kill: + break + + self._kill_process() diff --git a/patroni-for-openGauss/postgresql/config.py b/patroni-for-openGauss/postgresql/config.py new file mode 100644 index 0000000000000000000000000000000000000000..0f21b288b88de0bfa9a8a33aa47e549d17912058 --- /dev/null +++ b/patroni-for-openGauss/postgresql/config.py @@ -0,0 +1,1126 @@ +import logging +import os +import re +import shutil +import socket +import stat +import time + +from six.moves.urllib_parse import urlparse, parse_qsl, unquote + +from .validator import CaseInsensitiveDict, recovery_parameters,\ + 
transform_postgresql_parameter_value, transform_recovery_parameter_value +from ..dcs import slot_name_from_member_name, RemoteMember +from ..exceptions import PatroniFatalException +from ..utils import compare_values, parse_bool, parse_int, split_host_port, uri, \ + validate_directory, is_subpath + +logger = logging.getLogger(__name__) + +SYNC_STANDBY_NAME_RE = re.compile(r'^[A-Za-z_][A-Za-z_0-9\$]*$') +PARAMETER_RE = re.compile(r'([a-z_]+)\s*=\s*') + + +def quote_ident(value): + """Very simplified version of quote_ident""" + return value if SYNC_STANDBY_NAME_RE.match(value) else '"' + value + '"' + + +def conninfo_uri_parse(dsn): + ret = {} + r = urlparse(dsn) + if r.username: + ret['user'] = r.username + if r.password: + ret['password'] = r.password + if r.path[1:]: + ret['dbname'] = r.path[1:] + hosts = [] + ports = [] + for netloc in r.netloc.split('@')[-1].split(','): + host = port = None + if '[' in netloc and ']' in netloc: + host = netloc.split(']')[0][1:] + tmp = netloc.split(':', 1) + if host is None: + host = tmp[0] + if len(tmp) == 2: + host, port = tmp + if host is not None: + hosts.append(host) + if port is not None: + ports.append(port) + if hosts: + ret['host'] = ','.join(hosts) + if ports: + ret['port'] = ','.join(ports) + ret = {name: unquote(value) for name, value in ret.items()} + ret.update({name: value for name, value in parse_qsl(r.query)}) + if ret.get('ssl') == 'true': + del ret['ssl'] + ret['sslmode'] = 'require' + return ret + + +def read_param_value(value): + length = len(value) + ret = '' + is_quoted = value[0] == "'" + i = int(is_quoted) + while i < length: + if is_quoted: + if value[i] == "'": + return ret, i + 1 + elif value[i].isspace(): + break + if value[i] == '\\': + i += 1 + if i >= length: + break + ret += value[i] + i += 1 + return (None, None) if is_quoted else (ret, i) + + +def conninfo_parse(dsn): + ret = {} + length = len(dsn) + i = 0 + while i < length: + if dsn[i].isspace(): + i += 1 + continue + + param_match = PARAMETER_RE.match(dsn[i:]) + if not param_match: + return + + param = param_match.group(1) + i += param_match.end() + + if i >= length: + return + + value, end = read_param_value(dsn[i:]) + if value is None: + return + i += end + ret[param] = value + return ret + + +def parse_dsn(value): + """ + Very simple equivalent of `psycopg2.extensions.parse_dsn` introduced in 2.7.0. + We are not using psycopg2 function in order to remain compatible with 2.5.4+. + There is one minor difference though, this function removes `dbname` from the result + and sets the `sslmode`, 'gssencmode', and `channel_binding` to `prefer` if it is not present in + the connection string. This is necessary to simplify comparison of the old and the new values. 
+ + >>> r = parse_dsn('postgresql://u%2Fse:pass@:%2f123,[%2Fhost2]/db%2Fsdf?application_name=mya%2Fpp&ssl=true') + >>> r == {'application_name': 'mya/pp', 'host': ',/host2', 'sslmode': 'require',\ + 'password': 'pass', 'port': '/123', 'user': 'u/se', 'gssencmode': 'prefer', 'channel_binding': 'prefer'} + True + >>> r = parse_dsn(" host = 'host' dbname = db\\\\ name requiressl=1 ") + >>> r == {'host': 'host', 'sslmode': 'require', 'gssencmode': 'prefer', 'channel_binding': 'prefer'} + True + >>> parse_dsn('requiressl = 0\\\\') == {'sslmode': 'prefer', 'gssencmode': 'prefer', 'channel_binding': 'prefer'} + True + >>> parse_dsn("host=a foo = '") is None + True + >>> parse_dsn("host=a foo = ") is None + True + >>> parse_dsn("1") is None + True + """ + if value.startswith('postgres://') or value.startswith('postgresql://'): + ret = conninfo_uri_parse(value) + else: + ret = conninfo_parse(value) + + if ret: + if 'sslmode' not in ret: # allow sslmode to take precedence over requiressl + requiressl = ret.pop('requiressl', None) + if requiressl == '1': + ret['sslmode'] = 'require' + elif requiressl is not None: + ret['sslmode'] = 'prefer' + ret.setdefault('sslmode', 'prefer') + if 'dbname' in ret: + del ret['dbname'] + ret.setdefault('gssencmode', 'prefer') + ret.setdefault('channel_binding', 'prefer') + return ret + + +def strip_comment(value): + i = value.find('#') + if i > -1: + value = value[:i].strip() + return value + + +def read_recovery_param_value(value): + """ + >>> read_recovery_param_value('') is None + True + >>> read_recovery_param_value("'") is None + True + >>> read_recovery_param_value("''a") is None + True + >>> read_recovery_param_value('a b') is None + True + >>> read_recovery_param_value("'''") is None + True + >>> read_recovery_param_value("'\\\\") is None + True + >>> read_recovery_param_value("'a' s#") is None + True + >>> read_recovery_param_value("'\\\\'''' #a") + "''" + >>> read_recovery_param_value('asd') + 'asd' + """ + value = value.strip() + length = len(value) + if length == 0: + return None + elif value[0] == "'": + if length == 1: + return None + ret = '' + i = 1 + while i < length: + if value[i] == '\\': + i += 1 + if i >= length: + return None + elif value[i] == "'": + i += 1 + if i >= length: + break + if value[i] in ('#', ' '): + if strip_comment(value[i:]): + return None + break + if value[i] != "'": + return None + ret += value[i] + i += 1 + else: + return None + return ret + else: + value = strip_comment(value) + if not value or ' ' in value or '\\' in value: + return None + return value + + +def mtime(filename): + try: + return os.stat(filename).st_mtime + except OSError: + return None + + +class ConfigWriter(object): + + def __init__(self, filename): + self._filename = filename + self._fd = None + + def __enter__(self): + self._fd = open(self._filename, 'w') + self.writeline('# Do not edit this file manually!\n# It will be overwritten by Patroni!') + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._fd: + self._fd.close() + + def writeline(self, line): + self._fd.write(line) + self._fd.write('\n') + + def writelines(self, lines): + for line in lines: + self.writeline(line) + + @staticmethod + def escape(value): # Escape (by doubling) any single quotes or backslashes in given string + return re.sub(r'([\'\\])', r'\1\1', str(value)) + + def write_param(self, param, value): + self.writeline("{0} = '{1}'".format(param, self.escape(value))) + + +class ConfigHandler(object): + + # List of parameters which must be always passed to 
postmaster as command line options + # to make it not possible to change them with 'ALTER SYSTEM'. + # Some of these parameters have sane default value assigned and Patroni doesn't allow + # to decrease this value. E.g. 'wal_level' can't be lower then 'hot_standby' and so on. + # These parameters could be changed only globally, i.e. via DCS. + # P.S. 'listen_addresses' and 'port' are added here just for convenience, to mark them + # as a parameters which should always be passed through command line. + # + # Format: + # key - parameter name + # value - tuple(default_value, check_function, min_version) + # default_value -- some sane default value + # check_function -- if the new value is not correct must return `!False` + # min_version -- major version of PostgreSQL when parameter was introduced + CMDLINE_OPTIONS = CaseInsensitiveDict({ + 'listen_addresses': (None, lambda _: False, 90100), + 'port': (None, lambda _: False, 90100), + 'cluster_name': (None, lambda _: False, 90500), + 'wal_level': ('hot_standby', lambda v: v.lower() in ('hot_standby', 'replica', 'logical'), 90100), + 'hot_standby': ('on', lambda _: False, 90100), + 'max_connections': (100, lambda v: int(v) >= 25, 90100), + 'max_wal_senders': (10, lambda v: int(v) >= 3, 90100), + 'wal_keep_segments': (8, lambda v: int(v) >= 1, 90100), + 'wal_keep_size': ('128MB', lambda v: parse_int(v, 'MB') >= 16, 130000), + 'max_prepared_transactions': (0, lambda v: int(v) >= 0, 90100), + 'max_locks_per_transaction': (64, lambda v: int(v) >= 32, 90100), + 'track_commit_timestamp': ('off', lambda v: parse_bool(v) is not None, 90500), + 'max_replication_slots': (10, lambda v: int(v) >= 4, 90100), + 'max_worker_processes': (8, lambda v: int(v) >= 2, 90400), + 'wal_log_hints': ('on', lambda _: False, 90100) + }) + + _RECOVERY_PARAMETERS = set(recovery_parameters.keys()) + + def __init__(self, postgresql, config): + self._postgresql = postgresql + self._config_dir = os.path.abspath(config.get('config_dir') or postgresql.data_dir) + config_base_name = config.get('config_base_name', 'postgresql') + self._postgresql_conf = os.path.join(self._config_dir, config_base_name + '.conf') + self._postgresql_conf_mtime = None + self._postgresql_base_conf_name = config_base_name + '.base.conf' + self._postgresql_base_conf = os.path.join(self._config_dir, self._postgresql_base_conf_name) + self._pg_hba_conf = os.path.join(self._config_dir, 'pg_hba.conf') + self._pg_ident_conf = os.path.join(self._config_dir, 'pg_ident.conf') + self._recovery_conf = os.path.join(postgresql.data_dir, 'recovery.conf') + self._recovery_conf_mtime = None + self._recovery_signal = os.path.join(postgresql.data_dir, 'recovery.signal') + self._standby_signal = os.path.join(postgresql.data_dir, 'standby.signal') + self._auto_conf = os.path.join(postgresql.data_dir, 'postgresql.auto.conf') + self._auto_conf_mtime = None + self._pgpass = os.path.abspath(config.get('pgpass') or os.path.join(os.path.expanduser('~'), 'pgpass')) + if os.path.exists(self._pgpass) and not os.path.isfile(self._pgpass): + raise PatroniFatalException("'{0}' exists and it's not a file, check your `postgresql.pgpass` configuration" + .format(self._pgpass)) + self._passfile = None + self._passfile_mtime = None + self._synchronous_standby_names = None + self._postmaster_ctime = None + self._current_recovery_params = None + self._config = {} + self._recovery_params = {} + self.reload_config(config) + + def setup_server_parameters(self): + self._server_parameters = self.get_server_parameters(self._config) + 
self._adjust_recovery_parameters() + + def try_to_create_dir(self, d, msg): + d = os.path.join(self._postgresql._data_dir, d) + if (not is_subpath(self._postgresql._data_dir, d) or not self._postgresql.data_directory_empty()): + validate_directory(d, msg) + + def check_directories(self): + if "unix_socket_directories" in self._server_parameters: + for d in self._server_parameters["unix_socket_directories"].split(","): + self.try_to_create_dir(d.strip(), "'{}' is defined in unix_socket_directories, {}") + if "stats_temp_directory" in self._server_parameters: + self.try_to_create_dir(self._server_parameters["stats_temp_directory"], + "'{}' is defined in stats_temp_directory, {}") + self.try_to_create_dir(os.path.dirname(self._pgpass), + "'{}' is defined in `postgresql.pgpass`, {}") + + @property + def _configuration_to_save(self): + configuration = [os.path.basename(self._postgresql_conf)] + if 'custom_conf' not in self._config: + configuration.append(os.path.basename(self._postgresql_base_conf_name)) + if not self.hba_file: + configuration.append('pg_hba.conf') + if not self.ident_file: + configuration.append('pg_ident.conf') + return configuration + + def save_configuration_files(self, check_custom_bootstrap=False): + """ + copy postgresql.conf to postgresql.conf.backup to be able to retrive configuration files + - originally stored as symlinks, those are normally skipped by pg_basebackup + - in case of WAL-E basebackup (see http://comments.gmane.org/gmane.comp.db.postgresql.wal-e/239) + """ + if not (check_custom_bootstrap and self._postgresql.bootstrap.running_custom_bootstrap): + try: + for f in self._configuration_to_save: + config_file = os.path.join(self._config_dir, f) + backup_file = os.path.join(self._postgresql.data_dir, f + '.backup') + if os.path.isfile(config_file): + shutil.copy(config_file, backup_file) + except IOError: + logger.exception('unable to create backup copies of configuration files') + return True + + def restore_configuration_files(self): + """ restore a previously saved postgresql.conf """ + try: + for f in self._configuration_to_save: + config_file = os.path.join(self._config_dir, f) + backup_file = os.path.join(self._postgresql.data_dir, f + '.backup') + if not os.path.isfile(config_file): + if os.path.isfile(backup_file): + shutil.copy(backup_file, config_file) + # Previously we didn't backup pg_ident.conf, if file is missing just create empty + elif f == 'pg_ident.conf': + open(config_file, 'w').close() + except IOError: + logger.exception('unable to restore configuration files from backup') + + def write_postgresql_conf(self, configuration=None): + # rename the original configuration if it is necessary + if 'custom_conf' not in self._config and not os.path.exists(self._postgresql_base_conf): + os.rename(self._postgresql_conf, self._postgresql_base_conf) + else: + return + + # In case we are using custom bootstrap from spilo image with PITR it fails if it contains increasing + # values like Max_connections. We disable hot_standby so it will accept increasing values. 
+ if self._postgresql.bootstrap.running_custom_bootstrap: + configuration['hot_standby'] = 'off' + + with ConfigWriter(self._postgresql_conf) as f: + include = self._config.get('custom_conf') or self._postgresql_base_conf_name + f.writeline("include '{0}'\n".format(ConfigWriter.escape(include))) + for name, value in sorted((configuration or self._server_parameters).items()): + value = transform_postgresql_parameter_value(self._postgresql.major_version, name, value) + if (not self._postgresql.bootstrap.running_custom_bootstrap or name != 'hba_file') \ + and name not in self._RECOVERY_PARAMETERS and value is not None: + f.write_param(name, value) + # when we are doing custom bootstrap we assume that we don't know superuser password + # and in order to be able to change it, we are opening trust access from a certain address + # therefore we need to make sure that hba_file is not overriden + # after changing superuser password we will "revert" all these "changes" + if self._postgresql.bootstrap.running_custom_bootstrap or 'hba_file' not in self._server_parameters: + f.write_param('hba_file', self._pg_hba_conf) + if 'ident_file' not in self._server_parameters: + f.write_param('ident_file', self._pg_ident_conf) + + if self._postgresql.major_version >= 120000: + if self._recovery_params: + f.writeline('\n# recovery.conf') + self._write_recovery_params(f, self._recovery_params) + + if not self._postgresql.bootstrap.keep_existing_recovery_conf: + self._sanitize_auto_conf() + + def append_pg_hba(self, config): + if not self.hba_file and not self._config.get('pg_hba'): + with open(self._pg_hba_conf, 'a') as f: + f.write('\n{}\n'.format('\n'.join(config))) + return True + + def replace_pg_hba(self): + """ + Replace pg_hba.conf content in the PGDATA if hba_file is not defined in the + `postgresql.parameters` and pg_hba is defined in `postgresql` configuration section. + + :returns: True if pg_hba.conf was rewritten. + """ + + # when we are doing custom bootstrap we assume that we don't know superuser password + # and in order to be able to change it, we are opening trust access from a certain address + if self._postgresql.bootstrap.running_custom_bootstrap: + addresses = {} if os.name == 'nt' else {'': 'local'} # windows doesn't yet support unix-domain sockets + if 'host' in self.local_replication_address and not self.local_replication_address['host'].startswith('/'): + addresses.update({sa[0] + '/32': 'host' for _, _, _, _, sa in socket.getaddrinfo( + self.local_replication_address['host'], self.local_replication_address['port'], + 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)}) + + with ConfigWriter(self._pg_hba_conf) as f: + for address, t in addresses.items(): + f.writeline(( + '{0}\treplication\t{1}\t{3}\ttrust\n' + '{0}\tall\t{2}\t{3}\ttrust' + ).format(t, self.replication['username'], self._superuser.get('username') or 'all', address)) + elif not self.hba_file and self._config.get('pg_hba'): + with ConfigWriter(self._pg_hba_conf) as f: + f.writelines(self._config['pg_hba']) + return True + + def replace_pg_ident(self): + """ + Replace pg_ident.conf content in the PGDATA if ident_file is not defined in the + `postgresql.parameters` and pg_ident is defined in the `postgresql` section. + + :returns: True if pg_ident.conf was rewritten. 
+ """ + + if not self.ident_file and self._config.get('pg_ident'): + with ConfigWriter(self._pg_ident_conf) as f: + f.writelines(self._config['pg_ident']) + return True + + def primary_conninfo_params(self, member): + if not (member and member.conn_url) or member.name == self._postgresql.name: + return None + ret = member.conn_kwargs(self.replication) + ret['application_name'] = self._postgresql.name + ret.setdefault('sslmode', 'prefer') + if self._postgresql.major_version >= 120000: + ret.setdefault('gssencmode', 'prefer') + if self._postgresql.major_version >= 130000: + ret.setdefault('channel_binding', 'prefer') + if self._krbsrvname: + ret['krbsrvname'] = self._krbsrvname + if 'database' in ret: + del ret['database'] + return ret + + def format_dsn(self, params, include_dbname=False): + # A list of keywords that can be found in a conninfo string. Follows what is acceptable by libpq + keywords = ('dbname', 'user', 'passfile' if params.get('passfile') else 'password', 'host', 'port', + 'sslmode', 'sslcompression', 'sslcert', 'sslkey', 'sslpassword', 'sslrootcert', 'sslcrl', + 'application_name', 'krbsrvname', 'gssencmode', 'channel_binding') + if include_dbname: + params = params.copy() + params['dbname'] = params.get('database') or self._postgresql.database + # we are abusing information about the necessity of dbname + # dsn should contain passfile or password only if there is no dbname in it (it is used in recovery.conf) + skip = {'passfile', 'password'} + else: + skip = {'dbname'} + + def escape(value): + return re.sub(r'([\'\\ ])', r'\\\1', str(value)) + + return ' '.join('{0}={1}'.format(kw, escape(params[kw])) for kw in keywords + if kw not in skip and params.get(kw) is not None) + + def _write_recovery_params(self, fd, recovery_params): + if self._postgresql.major_version >= 90500: + pause_at_recovery_target = parse_bool(recovery_params.pop('pause_at_recovery_target', None)) + if pause_at_recovery_target is not None: + recovery_params.setdefault('recovery_target_action', 'pause' if pause_at_recovery_target else 'promote') + else: + if str(recovery_params.pop('recovery_target_action', None)).lower() == 'promote': + recovery_params.setdefault('pause_at_recovery_target', 'false') + for name, value in sorted(recovery_params.items()): + if name == 'primary_conninfo': + if 'password' in value and self._postgresql.major_version >= 100000: + self.write_pgpass(value) + value['passfile'] = self._passfile = self._pgpass + self._passfile_mtime = mtime(self._pgpass) + value = self.format_dsn(value) + else: + value = transform_recovery_parameter_value(self._postgresql.major_version, name, value) + if value is None: + continue + fd.write_param(name, value) + + def build_recovery_params(self, member): + recovery_params = CaseInsensitiveDict({p: v for p, v in self.get('recovery_conf', {}).items() + if not p.lower().startswith('recovery_target') and + p.lower() not in ('primary_conninfo', 'primary_slot_name')}) + recovery_params.update({'standby_mode': 'on', 'recovery_target_timeline': 'latest'}) + if self._postgresql.major_version >= 120000: + # on pg12 we want to protect from following params being set in one of included files + # not doing so might result in a standby being paused, promoted or shutted down. 
+ recovery_params.update({'recovery_target': '', 'recovery_target_name': '', 'recovery_target_time': '', + 'recovery_target_xid': '', 'recovery_target_lsn': ''}) + + is_remote_master = isinstance(member, RemoteMember) + primary_conninfo = self.primary_conninfo_params(member) + if primary_conninfo: + use_slots = self.get('use_slots', True) and self._postgresql.major_version >= 90400 + if use_slots and not (is_remote_master and member.no_replication_slot): + primary_slot_name = member.primary_slot_name if is_remote_master else self._postgresql.name + recovery_params['primary_slot_name'] = slot_name_from_member_name(primary_slot_name) + recovery_params['primary_conninfo'] = primary_conninfo + + # standby_cluster config might have different parameters, we want to override them + standby_cluster_params = ['restore_command', 'archive_cleanup_command']\ + + (['recovery_min_apply_delay'] if is_remote_master else []) + recovery_params.update({p: member.data.get(p) for p in standby_cluster_params if member and member.data.get(p)}) + return recovery_params + + def recovery_conf_exists(self): + if self._postgresql.major_version >= 120000: + return os.path.exists(self._standby_signal) or os.path.exists(self._recovery_signal) + return os.path.exists(self._recovery_conf) + + @property + def _triggerfile_good_name(self): + return 'trigger_file' if self._postgresql.major_version < 120000 else 'promote_trigger_file' + + @property + def _triggerfile_wrong_name(self): + return 'trigger_file' if self._postgresql.major_version >= 120000 else 'promote_trigger_file' + + @property + def _recovery_parameters_to_compare(self): + skip_params = {'pause_at_recovery_target', 'recovery_target_inclusive', + 'recovery_target_action', 'standby_mode', self._triggerfile_wrong_name} + return self._RECOVERY_PARAMETERS - skip_params + + def _read_recovery_params(self): + pg_conf_mtime = mtime(self._postgresql_conf) + auto_conf_mtime = mtime(self._auto_conf) + passfile_mtime = mtime(self._passfile) if self._passfile else False + postmaster_ctime = self._postgresql.is_running() + if postmaster_ctime: + postmaster_ctime = postmaster_ctime.create_time() + + if self._postgresql_conf_mtime == pg_conf_mtime and self._auto_conf_mtime == auto_conf_mtime \ + and self._passfile_mtime == passfile_mtime and self._postmaster_ctime == postmaster_ctime: + return None, False + + try: + values = self._get_pg_settings(self._recovery_parameters_to_compare).values() + values = {p[0]: [p[1], p[4] == 'postmaster', p[5]] for p in values} + self._postgresql_conf_mtime = pg_conf_mtime + self._auto_conf_mtime = auto_conf_mtime + self._postmaster_ctime = postmaster_ctime + except Exception: + values = None + return values, True + + def _read_recovery_params_pre_v12(self): + recovery_conf_mtime = mtime(self._recovery_conf) + passfile_mtime = mtime(self._passfile) if self._passfile else False + if recovery_conf_mtime == self._recovery_conf_mtime and passfile_mtime == self._passfile_mtime: + return None, False + + values = {} + with open(self._recovery_conf, 'r') as f: + for line in f: + line = line.strip() + if not line or line.startswith('#'): + continue + value = None + match = PARAMETER_RE.match(line) + if match: + value = read_recovery_param_value(line[match.end():]) + if value is None: + return None, True + values[match.group(1)] = [value, True] + self._recovery_conf_mtime = recovery_conf_mtime + values.setdefault('recovery_min_apply_delay', ['0', True]) + values['recovery_min_apply_delay'][0] = parse_int(values['recovery_min_apply_delay'][0], 'ms') + 
values.update({param: ['', True] for param in self._recovery_parameters_to_compare if param not in values}) + return values, True + + def _check_passfile(self, passfile, wanted_primary_conninfo): + # If there is a passfile in the primary_conninfo try to figure out that + # the passfile contains the line allowing connection to the given node. + # We assume that the passfile was created by Patroni and therefore doing + # the full match and not covering cases when host, port or user are set to '*' + passfile_mtime = mtime(passfile) + if passfile_mtime: + try: + with open(passfile) as f: + wanted_line = self._pgpass_line(wanted_primary_conninfo).strip() + for raw_line in f: + if raw_line.strip() == wanted_line: + self._passfile = passfile + self._passfile_mtime = passfile_mtime + return True + except Exception: + logger.info('Failed to read %s', passfile) + return False + + def _check_primary_conninfo(self, primary_conninfo, wanted_primary_conninfo): + # first we will cover corner cases, when we are replicating from somewhere while shouldn't + # or there is no primary_conninfo but we should replicate from some specific node. + if not wanted_primary_conninfo: + return not primary_conninfo + elif not primary_conninfo: + return False + + wal_receiver_primary_conninfo = self._postgresql.primary_conninfo() + if wal_receiver_primary_conninfo: + wal_receiver_primary_conninfo = parse_dsn(wal_receiver_primary_conninfo) + # when wal receiver is alive use primary_conninfo from pg_stat_wal_receiver for comparison + if wal_receiver_primary_conninfo: + primary_conninfo = wal_receiver_primary_conninfo + # There could be no password in the primary_conninfo or it is masked. + # Just copy the "desired" value in order to make comparison succeed. + if 'password' in wanted_primary_conninfo: + primary_conninfo['password'] = wanted_primary_conninfo['password'] + + if 'passfile' in primary_conninfo and 'password' not in primary_conninfo \ + and 'password' in wanted_primary_conninfo: + if self._check_passfile(primary_conninfo['passfile'], wanted_primary_conninfo): + primary_conninfo['password'] = wanted_primary_conninfo['password'] + else: + return False + + return all(primary_conninfo.get(p) == str(v) for p, v in wanted_primary_conninfo.items() if v is not None) + + def check_recovery_conf(self, member): + """Returns a tuple. The first boolean element indicates that recovery params don't match + and the second is set to `True` if the restart is required in order to apply new values""" + + # TODO: recovery.conf could be stale, would be nice to detect that. + if self._postgresql.major_version >= 120000: + if not os.path.exists(self._standby_signal): + return True, True + + _read_recovery_params = self._read_recovery_params + else: + if not self.recovery_conf_exists(): + return True, True + + _read_recovery_params = self._read_recovery_params_pre_v12 + + params, updated = _read_recovery_params() + # updated indicates that mtime of postgresql.conf, postgresql.auto.conf, or recovery.conf + # was changed and params were read either from the config or from the database connection. + if updated: + if params is None: # exception or unparsable config + return True, True + + # We will cache parsed value until the next config change. + self._current_recovery_params = params + primary_conninfo = params['primary_conninfo'] + if primary_conninfo[0]: + primary_conninfo[0] = parse_dsn(params['primary_conninfo'][0]) + # If we failed to parse non-empty connection string this indicates that config if broken. 
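+                # Clarifying note: parse_dsn() returns a dict on success, so a falsy result here means the
+                # non-empty primary_conninfo string could not be parsed and the configuration is treated
+                # as broken (restart required).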
+ if not primary_conninfo[0]: + return True, True + else: # empty string, primary_conninfo is not in the config + primary_conninfo[0] = {} + + # when wal receiver is alive take primary_slot_name from pg_stat_wal_receiver + wal_receiver_primary_slot_name = self._postgresql.primary_slot_name() + if not wal_receiver_primary_slot_name and self._postgresql.primary_conninfo(): + wal_receiver_primary_slot_name = '' + if wal_receiver_primary_slot_name is not None: + self._current_recovery_params['primary_slot_name'][0] = wal_receiver_primary_slot_name + + required = {'restart': 0, 'reload': 0} + + def record_missmatch(mtype): + required['restart' if mtype else 'reload'] += 1 + + wanted_recovery_params = self.build_recovery_params(member) + for param, value in self._current_recovery_params.items(): + # Skip certain parameters defined in the included postgres config files + # if we know that they are not specified in the patroni configuration. + if len(value) > 2 and value[2] not in (self._postgresql_conf, self._auto_conf) and \ + param in ('archive_cleanup_command', 'promote_trigger_file', 'recovery_end_command', + 'recovery_min_apply_delay', 'restore_command') and param not in wanted_recovery_params: + continue + if param == 'recovery_min_apply_delay': + if not compare_values('integer', 'ms', value[0], wanted_recovery_params.get(param, 0)): + record_missmatch(value[1]) + elif param == 'standby_mode': + if not compare_values('bool', None, value[0], wanted_recovery_params.get(param, 'on')): + record_missmatch(value[1]) + elif param == 'primary_conninfo': + if not self._check_primary_conninfo(value[0], wanted_recovery_params.get('primary_conninfo', {})): + record_missmatch(value[1]) + elif (param != 'primary_slot_name' or wanted_recovery_params.get('primary_conninfo')) \ + and str(value[0]) != str(wanted_recovery_params.get(param, '')): + record_missmatch(value[1]) + return required['restart'] + required['reload'] > 0, required['restart'] > 0 + + def check_db_state(self, is_leader=False): + """ + check the state of database, + returns the next action + 'normal': the database is running and its state is Normal + 'rebuild': the database is running but it is not Noamrl, so it needs to be rebuilded + 'restart': the database is not running, so it needs to be restarted + """ + status, output = self._postgresql.gs_query() + if status != 0: + return 'restart' + local_role = re.findall('local_role +: (.+)', output)[0] + db_state = re.findall('db_state +: (.+)', output)[0] + detail_information = re.findall('detail_information +: (.+)', output)[0] + if is_leader and local_role != 'Primary': + return 'restart' + if local_role == 'Standby' and 'WAL segment removed' in detail_information: + return 'rebuild' + return 'normal' + + @staticmethod + def _remove_file_if_exists(name): + if os.path.isfile(name) or os.path.islink(name): + os.unlink(name) + + @staticmethod + def _pgpass_line(record): + if 'password' in record: + def escape(value): + return re.sub(r'([:\\])', r'\\\1', str(value)) + + record = {n: escape(record.get(n) or '*') for n in ('host', 'port', 'user', 'password')} + return '{host}:{port}:*:{user}:{password}'.format(**record) + + def write_pgpass(self, record): + line = self._pgpass_line(record) + if not line: + return os.environ.copy() + + with open(self._pgpass, 'w') as f: + os.chmod(self._pgpass, stat.S_IWRITE | stat.S_IREAD) + f.write(line) + + env = os.environ.copy() + env['PGPASSFILE'] = self._pgpass + return env + + def write_recovery_conf(self, recovery_params): + if 
self._postgresql.major_version >= 120000: + if parse_bool(recovery_params.pop('standby_mode', None)): + open(self._standby_signal, 'w').close() + else: + self._remove_file_if_exists(self._standby_signal) + open(self._recovery_signal, 'w').close() + self._recovery_params = recovery_params + else: + with ConfigWriter(self._recovery_conf) as f: + os.chmod(self._recovery_conf, stat.S_IWRITE | stat.S_IREAD) + self._write_recovery_params(f, recovery_params) + + def remove_recovery_conf(self): + for name in (self._recovery_conf, self._standby_signal, self._recovery_signal): + self._remove_file_if_exists(name) + self._recovery_params = {} + + def _sanitize_auto_conf(self): + overwrite = False + lines = [] + + if os.path.exists(self._auto_conf): + try: + with open(self._auto_conf) as f: + for raw_line in f: + line = raw_line.strip() + match = PARAMETER_RE.match(line) + if match and match.group(1).lower() in self._RECOVERY_PARAMETERS: + overwrite = True + else: + lines.append(raw_line) + except Exception: + logger.info('Failed to read %s', self._auto_conf) + + if overwrite: + try: + with open(self._auto_conf, 'w') as f: + for raw_line in lines: + f.write(raw_line) + except Exception: + logger.exception('Failed to remove some unwanted parameters from %s', self._auto_conf) + + def _adjust_recovery_parameters(self): + # It is not strictly necessary, but we can make patroni configs crossi-compatible with all postgres versions. + recovery_conf = {n: v for n, v in self._server_parameters.items() if n.lower() in self._RECOVERY_PARAMETERS} + if recovery_conf: + self._config['recovery_conf'] = recovery_conf + + if self.get('recovery_conf'): + value = self._config['recovery_conf'].pop(self._triggerfile_wrong_name, None) + if self._triggerfile_good_name not in self._config['recovery_conf'] and value: + self._config['recovery_conf'][self._triggerfile_good_name] = value + + def get_server_parameters(self, config): + parameters = config['parameters'].copy() + listen_addresses, port = split_host_port(config['listen'], 5432) + parameters.update(cluster_name=self._postgresql.scope, listen_addresses=listen_addresses, port=str(port)) + if config.get('synchronous_mode', False): + if self._synchronous_standby_names is None: + if config.get('synchronous_mode_strict', False): + parameters['synchronous_standby_names'] = '*' + else: + parameters.pop('synchronous_standby_names', None) + else: + parameters['synchronous_standby_names'] = self._synchronous_standby_names + + # Handle hot_standby <-> replica rename + if parameters.get('wal_level') == ('hot_standby' if self._postgresql.major_version >= 90600 else 'replica'): + parameters['wal_level'] = 'replica' if self._postgresql.major_version >= 90600 else 'hot_standby' + + # Try to recalcualte wal_keep_segments <-> wal_keep_size assuming that typical wal_segment_size is 16MB. + # The real segment size could be estimated from pg_control, but we don't really care, because the only goal of + # this exercise is improving cross version compatibility and user must set the correct parameter in the config. 
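+        # Worked example (illustrative): with the defaults above, wal_keep_segments = 8 maps to
+        # wal_keep_size = '128MB' (8 * 16MB) on PostgreSQL 13+, and wal_keep_size = '128MB' maps back to
+        # wal_keep_segments = int((128 + 8) / 16) = 8 on older major versions.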
+ if self._postgresql.major_version >= 130000: + wal_keep_segments = parameters.pop('wal_keep_segments', self.CMDLINE_OPTIONS['wal_keep_segments'][0]) + parameters.setdefault('wal_keep_size', str(wal_keep_segments * 16) + 'MB') + elif self._postgresql.major_version: + wal_keep_size = parse_int(parameters.pop('wal_keep_size', self.CMDLINE_OPTIONS['wal_keep_size'][0]), 'MB') + parameters.setdefault('wal_keep_segments', int((wal_keep_size + 8) / 16)) + ret = CaseInsensitiveDict({k: v for k, v in parameters.items() if not self._postgresql.major_version or + self._postgresql.major_version >= self.CMDLINE_OPTIONS.get(k, (0, 1, 90100))[2]}) + ret.update({k: os.path.join(self._config_dir, ret[k]) for k in ('hba_file', 'ident_file') if k in ret}) + return ret + + @staticmethod + def _get_unix_local_address(unix_socket_directories): + for d in unix_socket_directories.split(','): + d = d.strip() + if d.startswith('/'): # Only absolute path can be used to connect via unix-socket + return d + return '' + + def _get_tcp_local_address(self): + listen_addresses = self._server_parameters['listen_addresses'].split(',') + + for la in listen_addresses: + if la.strip().lower() in ('*', '0.0.0.0', '127.0.0.1', 'localhost'): # we are listening on '*' or localhost + return 'localhost' # connection via localhost is preferred + return listen_addresses[0].strip() # can't use localhost, take first address from listen_addresses + + @property + def local_connect_kwargs(self): + ret = self._local_address.copy() + # add all of the other connection settings that are available + ret.update(self._superuser) + # if the "username" parameter is present, it actually needs to be "user" + # for connecting to PostgreSQL + if 'username' in self._superuser: + ret['user'] = self._superuser['username'] + del ret['username'] + # ensure certain Patroni configurations are available + ret.update({'database': self._postgresql.database, + 'fallback_application_name': 'Patroni', + 'connect_timeout': 3, + 'options': '-c statement_timeout=2000'}) + return ret + + def resolve_connection_addresses(self): + port = self._server_parameters['port'] + tcp_local_address = self._get_tcp_local_address() + + local_address = {'port': port} + if self._config.get('use_unix_socket'): + unix_socket_directories = self._server_parameters.get('unix_socket_directories') + if unix_socket_directories is not None: + # fallback to tcp if unix_socket_directories is set, but there are no sutable values + local_address['host'] = self._get_unix_local_address(unix_socket_directories) or tcp_local_address + + # if unix_socket_directories is not specified, but use_unix_socket is set to true - do our best + # to use default value, i.e. 
don't specify a host neither in connection url nor arguments + else: + local_address['host'] = tcp_local_address + + self._local_address = local_address + self.local_replication_address = {'host': tcp_local_address, 'port': port} + + netloc = self._config.get('connect_address') or tcp_local_address + ':' + port + self._postgresql.connection_string = uri('postgres', netloc, self._postgresql.database) + + self._postgresql.set_connection_kwargs(self.local_connect_kwargs) + + def _get_pg_settings(self, names): + return {r[0]: r for r in self._postgresql.query(('SELECT name, setting, unit, vartype, context, sourcefile' + + ' FROM pg_catalog.pg_settings ' + + ' WHERE pg_catalog.lower(name) = ANY(%s)'), + [n.lower() for n in names])} + + @staticmethod + def _handle_wal_buffers(old_values, changes): + wal_block_size = parse_int(old_values['wal_block_size'][1]) + wal_segment_size = old_values['wal_segment_size'] + wal_segment_unit = parse_int(wal_segment_size[2], 'B') if wal_segment_size[2][0].isdigit() else 1 + wal_segment_size = parse_int(wal_segment_size[1]) * wal_segment_unit / wal_block_size + default_wal_buffers = min(max(parse_int(old_values['shared_buffers'][1]) / 32, 8), wal_segment_size) + + wal_buffers = old_values['wal_buffers'] + new_value = str(changes['wal_buffers'] or -1) + + new_value = default_wal_buffers if new_value == '-1' else parse_int(new_value, wal_buffers[2]) + old_value = default_wal_buffers if wal_buffers[1] == '-1' else parse_int(*wal_buffers[1:3]) + + if new_value == old_value: + del changes['wal_buffers'] + + def reload_config(self, config, sighup=False): + self._superuser = config['authentication'].get('superuser', {}) + server_parameters = self.get_server_parameters(config) + + conf_changed = hba_changed = ident_changed = local_connection_address_changed = pending_restart = False + if self._postgresql.state == 'running': + changes = CaseInsensitiveDict({p: v for p, v in server_parameters.items() + if p.lower() not in self._RECOVERY_PARAMETERS}) + changes.update({p: None for p in self._server_parameters.keys() + if not (p in changes or p.lower() in self._RECOVERY_PARAMETERS)}) + if changes: + if 'wal_buffers' in changes: # we need to calculate the default value of wal_buffers + undef = [p for p in ('shared_buffers', 'wal_segment_size', 'wal_block_size') if p not in changes] + changes.update({p: None for p in undef}) + # XXX: query can raise an exception + old_values = self._get_pg_settings(changes.keys()) + if 'wal_buffers' in changes: + self._handle_wal_buffers(old_values, changes) + for p in undef: + del changes[p] + + for r in old_values.values(): + if r[4] != 'internal' and r[0] in changes: + new_value = changes.pop(r[0]) + if new_value is None or not compare_values(r[3], r[2], r[1], new_value): + conf_changed = True + if r[4] == 'postmaster': + pending_restart = True + logger.info('Changed %s from %s to %s (restart might be required)', + r[0], r[1], new_value) + if config.get('use_unix_socket') and r[0] == 'unix_socket_directories'\ + or r[0] in ('listen_addresses', 'port'): + local_connection_address_changed = True + else: + logger.info('Changed %s from %s to %s', r[0], r[1], new_value) + for param, value in changes.items(): + if '.' 
in param: + # Check that user-defined-paramters have changed (parameters with period in name) + if value is None or param not in self._server_parameters \ + or str(value) != str(self._server_parameters[param]): + logger.info('Changed %s from %s to %s', param, self._server_parameters.get(param), value) + conf_changed = True + elif param in server_parameters: + logger.warning('Removing invalid parameter `%s` from postgresql.parameters', param) + server_parameters.pop(param) + + if (not server_parameters.get('hba_file') or server_parameters['hba_file'] == self._pg_hba_conf) \ + and config.get('pg_hba'): + hba_changed = self._config.get('pg_hba', []) != config['pg_hba'] + + if (not server_parameters.get('ident_file') or server_parameters['ident_file'] == self._pg_hba_conf) \ + and config.get('pg_ident'): + ident_changed = self._config.get('pg_ident', []) != config['pg_ident'] + + self._config = config + self._postgresql.set_pending_restart(pending_restart) + self._server_parameters = server_parameters + self._adjust_recovery_parameters() + self._krbsrvname = config.get('krbsrvname') + + # for not so obvious connection attempts that may happen outside of pyscopg2 + if self._krbsrvname: + os.environ['PGKRBSRVNAME'] = self._krbsrvname + + if not local_connection_address_changed: + self.resolve_connection_addresses() + + if conf_changed: + self.write_postgresql_conf() + + if hba_changed: + self.replace_pg_hba() + + if ident_changed: + self.replace_pg_ident() + + if sighup or conf_changed or hba_changed or ident_changed: + logger.info('Reloading PostgreSQL configuration.') + self._postgresql.reload() + if self._postgresql.major_version >= 90500: + time.sleep(1) + try: + pending_restart = self._postgresql.query('SELECT COUNT(*) FROM pg_catalog.pg_settings' + ' WHERE pending_restart').fetchone()[0] > 0 + self._postgresql.set_pending_restart(pending_restart) + except Exception as e: + logger.warning('Exception %r when running query', e) + else: + logger.info('No PostgreSQL configuration items changed, nothing to reload.') + + def set_synchronous_standby(self, sync_members): + """Sets a node to be synchronous standby and if changed does a reload for PostgreSQL.""" + if sync_members and sync_members != ['*']: + sync_members = [quote_ident(x) for x in sync_members] + if self._postgresql.major_version >= 90600 and len(sync_members) > 1: + sync_param = '{0} ({1})'.format(len(sync_members), ','.join(sync_members)) + else: + sync_param = next(iter(sync_members), None) + if sync_param != self._synchronous_standby_names: + if sync_param is None: + self._server_parameters.pop('synchronous_standby_names', None) + else: + self._server_parameters['synchronous_standby_names'] = sync_param + self._synchronous_standby_names = sync_param + if self._postgresql.state == 'running': + self.write_postgresql_conf() + self._postgresql.reload() + + @property + def effective_configuration(self): + """It might happen that the current value of one (or more) below parameters stored in + the controldata is higher than the value stored in the global cluster configuration. + + Example: max_connections in global configuration is 100, but in controldata + `Current max_connections setting: 200`. If we try to start postgres with + max_connections=100, it will immediately exit. 
+ As a workaround we will start it with the values from controldata and set `pending_restart` + to true as an indicator that current values of parameters are not matching expectations.""" + + if self._postgresql.role == 'master': + return self._server_parameters + + options_mapping = { + 'max_connections': 'max_connections setting', + 'max_prepared_transactions': 'max_prepared_xacts setting', + 'max_locks_per_transaction': 'max_locks_per_xact setting' + } + + if self._postgresql.major_version >= 90400: + options_mapping['max_worker_processes'] = 'max_worker_processes setting' + + if self._postgresql.major_version >= 120000: + options_mapping['max_wal_senders'] = 'max_wal_senders setting' + + data = self._postgresql.controldata() + effective_configuration = self._server_parameters.copy() + + for name, cname in options_mapping.items(): + value = parse_int(effective_configuration[name]) + if cname not in data: + logger.warning('%s is missing from pg_controldata output', cname) + continue + + cvalue = parse_int(data[cname]) + if cvalue > value: + effective_configuration[name] = cvalue + self._postgresql.set_pending_restart(True) + return effective_configuration + + @property + def replication(self): + return self._config['authentication']['replication'] + + @property + def superuser(self): + return self._superuser + + @property + def rewind_credentials(self): + return self._config['authentication'].get('rewind', self._superuser) \ + if self._postgresql.major_version >= 110000 else self._superuser + + @property + def ident_file(self): + ident_file = self._server_parameters.get('ident_file') + return None if ident_file == self._pg_ident_conf else ident_file + + @property + def hba_file(self): + hba_file = self._server_parameters.get('hba_file') + return None if hba_file == self._pg_hba_conf else hba_file + + @property + def pg_hba_conf(self): + return self._pg_hba_conf + + @property + def postgresql_conf(self): + return self._postgresql_conf + + def get(self, key, default=None): + return self._config.get(key, default) diff --git a/patroni-for-openGauss/postgresql/connection.py b/patroni-for-openGauss/postgresql/connection.py new file mode 100644 index 0000000000000000000000000000000000000000..933ee89c16bdbe5586bfce7f8ae712399e78cade --- /dev/null +++ b/patroni-for-openGauss/postgresql/connection.py @@ -0,0 +1,46 @@ +import logging +import psycopg2 + +from contextlib import contextmanager +from threading import Lock + +logger = logging.getLogger(__name__) + + +class Connection(object): + + def __init__(self): + self._lock = Lock() + self._connection = None + self._cursor_holder = None + + def set_conn_kwargs(self, conn_kwargs): + self._conn_kwargs = conn_kwargs + + def get(self): + with self._lock: + if not self._connection or self._connection.closed != 0: + self._connection = psycopg2.connect(**self._conn_kwargs) + self._connection.autocommit = True + self.server_version = self._connection.server_version + return self._connection + + def cursor(self): + if not self._cursor_holder or self._cursor_holder.closed or self._cursor_holder.connection.closed != 0: + logger.info("establishing a new patroni connection to the postgres cluster") + self._cursor_holder = self.get().cursor() + return self._cursor_holder + + def close(self): + if self._connection and self._connection.closed == 0: + self._connection.close() + logger.info("closed patroni connection to the postgresql cluster") + self._cursor_holder = self._connection = None + + +@contextmanager +def get_connection_cursor(**kwargs): + with 
psycopg2.connect(**kwargs) as conn: + conn.autocommit = True + with conn.cursor() as cur: + yield cur diff --git a/patroni-for-openGauss/postgresql/misc.py b/patroni-for-openGauss/postgresql/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..f5caffbd1d1d51ec5764e4a681ccd99f4beb8e2a --- /dev/null +++ b/patroni-for-openGauss/postgresql/misc.py @@ -0,0 +1,70 @@ +import logging + +from patroni.exceptions import PostgresException + +logger = logging.getLogger(__name__) + + +def postgres_version_to_int(pg_version): + """Convert the server_version to integer + + >>> postgres_version_to_int('9.5.3') + 90503 + >>> postgres_version_to_int('9.3.13') + 90313 + >>> postgres_version_to_int('10.1') + 100001 + >>> postgres_version_to_int('10') # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + PostgresException: 'Invalid PostgreSQL version format: X.Y or X.Y.Z is accepted: 10' + >>> postgres_version_to_int('9.6') # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + PostgresException: 'Invalid PostgreSQL version format: X.Y or X.Y.Z is accepted: 9.6' + >>> postgres_version_to_int('a.b.c') # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + PostgresException: 'Invalid PostgreSQL version: a.b.c' + """ + + try: + components = list(map(int, pg_version.split('.'))) + except ValueError: + raise PostgresException('Invalid PostgreSQL version: {0}'.format(pg_version)) + + if len(components) < 2 or len(components) == 2 and components[0] < 10 or len(components) > 3: + raise PostgresException('Invalid PostgreSQL version format: X.Y or X.Y.Z is accepted: {0}'.format(pg_version)) + + if len(components) == 2: + # new style verion numbers, i.e. 10.1 becomes 100001 + components.insert(1, 0) + + return int(''.join('{0:02d}'.format(c) for c in components)) + + +def postgres_major_version_to_int(pg_version): + """ + >>> postgres_major_version_to_int('10') + 100000 + >>> postgres_major_version_to_int('9.6') + 90600 + """ + return postgres_version_to_int(pg_version + '.0') + + +def parse_lsn(lsn): + t = lsn.split('/') + return int(t[0], 16) * 0x100000000 + int(t[1], 16) + + +def parse_history(data): + for line in data.split('\n'): + values = line.strip().split('\t') + if len(values) == 3: + try: + values[0] = int(values[0]) + values[1] = parse_lsn(values[1]) + yield values + except (IndexError, ValueError): + logger.exception('Exception when parsing timeline history line "%s"', values) diff --git a/patroni-for-openGauss/postgresql/postmaster.py b/patroni-for-openGauss/postgresql/postmaster.py new file mode 100644 index 0000000000000000000000000000000000000000..d880d1564a985490e34034ed8f4ad09901025de9 --- /dev/null +++ b/patroni-for-openGauss/postgresql/postmaster.py @@ -0,0 +1,246 @@ +import logging +import multiprocessing +import os +import psutil +import re +import signal +import subprocess +import sys + +from patroni import PATRONI_ENV_PREFIX, KUBERNETES_ENV_PREFIX + +# avoid spawning the resource tracker process +if sys.version_info >= (3, 8): # pragma: no cover + import multiprocessing.resource_tracker + multiprocessing.resource_tracker.getfd = lambda: 0 +elif sys.version_info >= (3, 4): # pragma: no cover + import multiprocessing.semaphore_tracker + multiprocessing.semaphore_tracker.getfd = lambda: 0 + +logger = logging.getLogger(__name__) + +STOP_SIGNALS = { + 'smart': 'TERM', + 'fast': 'INT', + 'immediate': 'QUIT', +} + + +def pg_ctl_start(conn, cmdline, env): + if os.name != 'nt': + os.setsid() + try: + 
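+        # launch the postmaster as a child of this helper process (detached into its
+        # own session on POSIX via os.setsid above) and send the resulting pid back
+        # to the parent Patroni process through the pipe; None signals a failed start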
postmaster = subprocess.Popen(cmdline, close_fds=True, env=env) + conn.send(postmaster.pid) + except Exception: + logger.exception('Failed to execute %s', cmdline) + conn.send(None) + conn.close() + + +class PostmasterProcess(psutil.Process): + + def __init__(self, pid): + self.is_single_user = False + if pid < 0: + pid = -pid + self.is_single_user = True + super(PostmasterProcess, self).__init__(pid) + + @staticmethod + def _read_postmaster_pidfile(data_dir): + """Reads and parses postmaster.pid from the data directory + + :returns dictionary of values if successful, empty dictionary otherwise + """ + pid_line_names = ['pid', 'data_dir', 'start_time', 'port', 'socket_dir', 'listen_addr', 'shmem_key'] + try: + with open(os.path.join(data_dir, 'postmaster.pid')) as f: + return {name: line.rstrip('\n') for name, line in zip(pid_line_names, f)} + except IOError: + return {} + + def _is_postmaster_process(self): + try: + start_time = int(self._postmaster_pid.get('start_time', 0)) + if start_time and abs(self.create_time() - start_time) > 3: + logger.info('Process %s is not postmaster, too much difference between PID file start time %s and ' + 'process start time %s', self.pid, self.create_time(), start_time) + return False + except ValueError: + logger.warning('Garbage start time value in pid file: %r', self._postmaster_pid.get('start_time')) + + # Extra safety check. The process can't be ourselves, our parent or our direct child. + if self.pid == os.getpid() or self.pid == os.getppid() or self.ppid() == os.getpid(): + logger.info('Patroni (pid=%s, ppid=%s), "fake postmaster" (pid=%s, ppid=%s)', + os.getpid(), os.getppid(), self.pid, self.ppid()) + return False + + return True + + @classmethod + def _from_pidfile(cls, data_dir): + postmaster_pid = PostmasterProcess._read_postmaster_pidfile(data_dir) + try: + pid = int(postmaster_pid.get('pid', 0)) + if pid: + proc = cls(pid) + proc._postmaster_pid = postmaster_pid + return proc + except ValueError: + pass + + @staticmethod + def from_pidfile(data_dir): + try: + proc = PostmasterProcess._from_pidfile(data_dir) + return proc if proc and proc._is_postmaster_process() else None + except psutil.NoSuchProcess: + return None + + @classmethod + def from_pid(cls, pid): + try: + return cls(pid) + except psutil.NoSuchProcess: + return None + + def signal_kill(self): + """to suspend and kill postmaster and all children + + :returns True if postmaster and children are killed, False if error + """ + try: + self.suspend() + except psutil.NoSuchProcess: + return True + except psutil.Error as e: + logger.warning('Failed to suspend postmaster: %s', e) + + try: + children = self.children(recursive=True) + except psutil.NoSuchProcess: + return True + except psutil.Error as e: + logger.warning('Failed to get a list of postmaster children: %s', e) + children = [] + + try: + self.kill() + except psutil.NoSuchProcess: + return True + except psutil.Error as e: + logger.warning('Could not kill postmaster: %s', e) + return False + + for child in children: + try: + child.kill() + except psutil.Error: + pass + psutil.wait_procs(children + [self]) + return True + + def signal_stop(self, mode, pg_ctl='gs_ctl'): + """Signal postmaster process to stop + + :returns None if signaled, True if process is already gone, False if error + """ + if self.is_single_user: + logger.warning("Cannot stop server; single-user server is running (PID: {0})".format(self.pid)) + return False + if os.name != 'posix': + return self.pg_ctl_kill(mode, pg_ctl) + try: + self.send_signal(getattr(signal, 
'SIG' + STOP_SIGNALS[mode])) + except psutil.NoSuchProcess: + return True + except psutil.AccessDenied as e: + logger.warning("Could not send stop signal to PostgreSQL (error: {0})".format(e)) + return False + + return None + + def pg_ctl_kill(self, mode, pg_ctl): + try: + status = subprocess.call([pg_ctl, "kill", STOP_SIGNALS[mode], str(self.pid)]) + except OSError: + return False + if status == 0: + return None + else: + return not self.is_running() + + def wait_for_user_backends_to_close(self): + # These regexps are cross checked against versions PostgreSQL 9.1 .. 11 + aux_proc_re = re.compile("(?:postgres:)( .*:)? (?:(?:archiver|startup|autovacuum launcher|autovacuum worker|" + "checkpointer|logger|stats collector|wal receiver|wal writer|writer)(?: process )?|" + "walreceiver|wal sender process|walsender|walwriter|background writer|" + "logical replication launcher|logical replication worker for|bgworker:) ") + + try: + children = self.children() + except psutil.Error: + return logger.debug('Failed to get list of postmaster children') + + user_backends = [] + user_backends_cmdlines = [] + for child in children: + try: + cmdline = child.cmdline() + if cmdline and not aux_proc_re.match(cmdline[0]): + user_backends.append(child) + user_backends_cmdlines.append(cmdline[0]) + except psutil.NoSuchProcess: + pass + if user_backends: + logger.debug('Waiting for user backends %s to close', ', '.join(user_backends_cmdlines)) + psutil.wait_procs(user_backends) + logger.debug("Backends closed") + + @staticmethod + def start(pgcommand, data_dir, conf, options): + # Unfortunately `pg_ctl start` does not return postmaster pid to us. Without this information + # it is hard to know the current state of postgres startup, so we had to reimplement pg_ctl start + # in python. It will start postgres, wait for port to be open and wait until postgres will start + # accepting connections. + # Important!!! We can't just start postgres using subprocess.Popen, because in this case it + # will be our child for the rest of our live and we will have to take care of it (`waitpid`). + # So we will use the same approach as pg_ctl uses: start a new process, which will start postgres. + # This process will write postmaster pid to stdout and exit immediately. Now it's responsibility + # of init process to take care about postmaster. + # In order to make everything portable we can't use fork&exec approach here, so we will call + # ourselves and pass list of arguments which must be used to start postgres. + # On Windows, in order to run a side-by-side assembly the specified env must include a valid SYSTEMROOT. + env = {p: os.environ[p] for p in os.environ if not p.startswith( + PATRONI_ENV_PREFIX) and not p.startswith(KUBERNETES_ENV_PREFIX)} + try: + proc = PostmasterProcess._from_pidfile(data_dir) + if proc and not proc._is_postmaster_process(): + # Upon start postmaster process performs various safety checks if there is a postmaster.pid + # file in the data directory. Although Patroni already detected that the running process + # corresponding to the postmaster.pid is not a postmaster, the new postmaster might fail + # to start, because it thinks that postmaster.pid is already locked. + # Important!!! Unlink of postmaster.pid isn't an option, because it has a lot of nasty race conditions. + # Luckily there is a workaround to this problem, we can pass the pid from postmaster.pid + # in the `PG_GRANDPARENT_PID` environment variable and postmaster will ignore it. 
+ logger.info("Telling pg_ctl that it is safe to ignore postmaster.pid for process %s", proc.pid) + env['PG_GRANDPARENT_PID'] = str(proc.pid) + except psutil.NoSuchProcess: + pass + cmdline = [pgcommand, '-D', data_dir, '--config-file={}'.format(conf)] + options + logger.debug("Starting postgres: %s", " ".join(cmdline)) + ctx = multiprocessing.get_context('spawn') if sys.version_info >= (3, 4) else multiprocessing + parent_conn, child_conn = ctx.Pipe(False) + proc = ctx.Process(target=pg_ctl_start, args=(child_conn, cmdline, env)) + proc.start() + pid = parent_conn.recv() + proc.join() + if pid is None: + return + logger.info('postmaster pid=%s', pid) + + # TODO: In an extremely unlikely case, the process could have exited and the pid reassigned. The start + # initiation time is not accurate enough to compare to create time as start time would also likely + # be relatively close. We need the subprocess extract pid+start_time in a race free manner. + return PostmasterProcess.from_pid(pid) diff --git a/patroni-for-openGauss/postgresql/rewind.py b/patroni-for-openGauss/postgresql/rewind.py new file mode 100644 index 0000000000000000000000000000000000000000..e8bb4cb1c60e3bbd255cfb3c34cdff8f039b8a2e --- /dev/null +++ b/patroni-for-openGauss/postgresql/rewind.py @@ -0,0 +1,442 @@ +import logging +import os +import shlex +import six +import subprocess + +from threading import Lock, Thread + +from .connection import get_connection_cursor +from .misc import parse_history, parse_lsn +from ..async_executor import CriticalTask +from ..dcs import Leader + +logger = logging.getLogger(__name__) + +REWIND_STATUS = type('Enum', (), {'INITIAL': 0, 'CHECKPOINT': 1, 'CHECK': 2, 'NEED': 3, + 'NOT_NEED': 4, 'SUCCESS': 5, 'FAILED': 6}) + + +def format_lsn(lsn, full=False): + template = '{0:X}/{1:08X}' if full else '{0:X}/{1:X}' + return template.format(lsn >> 32, lsn & 0xFFFFFFFF) + + +class Rewind(object): + + def __init__(self, postgresql): + self._postgresql = postgresql + self._checkpoint_task_lock = Lock() + self.reset_state() + + @staticmethod + def configuration_allows_rewind(data): + return data.get('wal_log_hints setting', 'off') == 'on' or data.get('Data page checksum version', '0') != '0' + + @property + def can_rewind(self): + """ check if pg_rewind executable is there and that pg_controldata indicates + we have either wal_log_hints or checksums turned on + """ + # low-hanging fruit: check if pg_rewind configuration is there + if not self._postgresql.config.get('use_pg_rewind'): + return False + + cmd = [self._postgresql.pgcommand('pg_rewind'), '--help'] + try: + ret = subprocess.call(cmd, stdout=open(os.devnull, 'w'), stderr=subprocess.STDOUT) + if ret != 0: # pg_rewind is not there, close up the shop and go home + return False + except OSError: + return False + return self.configuration_allows_rewind(self._postgresql.controldata()) + + @property + def can_rewind_or_reinitialize_allowed(self): + return self._postgresql.config.get('remove_data_directory_on_diverged_timelines') or self.can_rewind + + def trigger_check_diverged_lsn(self): + if self.can_rewind_or_reinitialize_allowed and self._state != REWIND_STATUS.NEED: + self._state = REWIND_STATUS.CHECK + + @staticmethod + def check_leader_is_not_in_recovery(conn_kwargs): + try: + with get_connection_cursor(connect_timeout=3, options='-c statement_timeout=2000', **conn_kwargs) as cur: + cur.execute('SELECT pg_catalog.pg_is_in_recovery()') + if not cur.fetchone()[0]: + return True + logger.info('Leader is still in_recovery and therefore can\'t be 
used for rewind') + except Exception: + return logger.exception('Exception when working with leader') + + def _get_checkpoint_end(self, timeline, lsn): + """The checkpoint record size in WAL depends on postgres major version and platform (memory alignment). + Hence, the only reliable way to figure out where it ends, read the record from file with the help of pg_waldump + and parse the output. We are trying to read two records, and expect that it wil fail to read the second one: + `pg_waldump: fatal: error in WAL record at 0/182E220: invalid record length at 0/182E298: wanted 24, got 0` + The error message contains information about LSN of the next record, which is exactly where checkpoint ends.""" + + cmd = self._postgresql.pgcommand('pg_{0}dump'.format(self._postgresql.wal_name)) + lsn8 = format_lsn(lsn, True) + lsn = format_lsn(lsn) + env = os.environ.copy() + env.update(LANG='C', LC_ALL='C', PGDATA=self._postgresql.data_dir) + try: + waldump = subprocess.Popen([cmd, '-t', str(timeline), '-s', lsn, '-n', '2'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) + out, err = waldump.communicate() + waldump.wait() + except Exception as e: + logger.error('Failed to execute `%s -t %s -s %s -n 2`: %r', cmd, timeline, lsn, e) + else: + out = out.decode('utf-8').rstrip().split('\n') + err = err.decode('utf-8').rstrip().split('\n') + pattern = 'error in WAL record at {0}: invalid record length at '.format(lsn) + + if len(out) == 1 and len(err) == 1 and ', lsn: {0}, prev '.format(lsn8) in out[0] and pattern in err[0]: + i = err[0].find(pattern) + len(pattern) + j = err[0].find(": wanted ", i) + if j > -1: + try: + return parse_lsn(err[0][i:j]) + except Exception as e: + logger.error('Failed to parse lsn %s: %r', err[0][i:j], e) + logger.error('Failed to parse `%s -t %s -s %s -n 2` output', cmd, timeline, lsn) + logger.error(' stdout=%s', '\n'.join(out)) + logger.error(' stderr=%s', '\n'.join(err)) + + return 0 + + def _get_local_timeline_lsn_from_controldata(self): + in_recovery = timeline = lsn = None + data = self._postgresql.controldata() + try: + if data.get('Database cluster state') == 'shut down in recovery': + in_recovery = True + lsn = data.get('Minimum recovery ending location') + timeline = int(data.get("Min recovery ending loc's timeline")) + if lsn == '0/0' or timeline == 0: # it was a master when it crashed + data['Database cluster state'] = 'shut down' + if data.get('Database cluster state') == 'shut down': + in_recovery = False + lsn = data.get('Latest checkpoint location') + timeline = int(data.get("Latest checkpoint's TimeLineID")) + except (TypeError, ValueError): + logger.exception('Failed to get local timeline and lsn from pg_controldata output') + + if lsn is not None: + try: + lsn = parse_lsn(lsn) + except (IndexError, ValueError) as e: + logger.error('Exception when parsing lsn %s: %r', lsn, e) + lsn = None + + return in_recovery, timeline, lsn + + def _get_local_timeline_lsn(self): + if self._postgresql.is_running(): # if postgres is running - get timeline from replication connection + in_recovery = True + timeline = self._postgresql.received_timeline() or self._postgresql.get_replica_timeline() + lsn = self._postgresql.replayed_location() + else: # otherwise analyze pg_controldata output + in_recovery, timeline, lsn = self._get_local_timeline_lsn_from_controldata() + + log_lsn = format_lsn(lsn) if isinstance(lsn, six.integer_types) else lsn + logger.info('Local timeline=%s lsn=%s', timeline, log_lsn) + return in_recovery, timeline, lsn + + @staticmethod + def 
_log_master_history(history, i): + start = max(0, i - 3) + end = None if i + 4 >= len(history) else i + 2 + history_show = [] + + def format_history_line(line): + return '{0}\t{1}\t{2}'.format(line[0], format_lsn(line[1]), line[2]) + + for line in history[start:end]: + history_show.append(format_history_line(line)) + + if line != history[-1]: + history_show.append('...') + history_show.append(format_history_line(history[-1])) + + logger.info('master: history=%s', '\n'.join(history_show)) + + def _conn_kwargs(self, member, auth): + ret = member.conn_kwargs(auth) + if not ret.get('database'): + ret['database'] = self._postgresql.database + return ret + + def _check_timeline_and_lsn(self, leader): + in_recovery, local_timeline, local_lsn = self._get_local_timeline_lsn() + if local_timeline is None or local_lsn is None: + return + + if isinstance(leader, Leader): + if leader.member.data.get('role') != 'master': + return + # standby cluster + elif not self.check_leader_is_not_in_recovery(self._conn_kwargs(leader, self._postgresql.config.replication)): + return + + history = need_rewind = None + try: + with self._postgresql.get_replication_connection_cursor(**leader.conn_kwargs()) as cur: + cur.execute('IDENTIFY_SYSTEM') + master_timeline = cur.fetchone()[1] + logger.info('master_timeline=%s', master_timeline) + if local_timeline > master_timeline: # Not always supported by pg_rewind + need_rewind = True + elif local_timeline == master_timeline: + need_rewind = False + elif master_timeline > 1: + cur.execute('TIMELINE_HISTORY %s', (master_timeline,)) + history = bytes(cur.fetchone()[1]).decode('utf-8') + logger.debug('master: history=%s', history) + except Exception: + return logger.exception('Exception when working with master via replication connection') + + if history is not None: + history = list(parse_history(history)) + for i, (parent_timeline, switchpoint, _) in enumerate(history): + if parent_timeline == local_timeline: + # We don't need to rewind when: + # 1. for replica: replayed location is not ahead of switchpoint + # 2. for the former primary: end of checkpoint record is the same as switchpoint + if in_recovery: + need_rewind = local_lsn > switchpoint + elif local_lsn >= switchpoint: + need_rewind = True + else: + need_rewind = switchpoint != self._get_checkpoint_end(local_timeline, local_lsn) + break + elif parent_timeline > local_timeline: + break + self._log_master_history(history, i) + + self._state = need_rewind and REWIND_STATUS.NEED or REWIND_STATUS.NOT_NEED + + def rewind_or_reinitialize_needed_and_possible(self, leader): + if leader and leader.name != self._postgresql.name and leader.conn_url and self._state == REWIND_STATUS.CHECK: + self._check_timeline_and_lsn(leader) + return leader and leader.conn_url and self._state == REWIND_STATUS.NEED + + def __checkpoint(self, task, wakeup): + try: + result = self._postgresql.checkpoint() + except Exception as e: + result = 'Exception: ' + str(e) + with task: + task.complete(not bool(result)) + if task.result: + wakeup() + + def ensure_checkpoint_after_promote(self, wakeup): + """After promote issue a CHECKPOINT from a new thread and asynchronously check the result. 
+ In case if CHECKPOINT failed, just check that timeline in pg_control was updated.""" + + if self._state == REWIND_STATUS.INITIAL and self._postgresql.is_leader(): + with self._checkpoint_task_lock: + if self._checkpoint_task: + with self._checkpoint_task: + if self._checkpoint_task.result: + self._state = REWIND_STATUS.CHECKPOINT + if self._checkpoint_task.result is not False: + return + else: + self._checkpoint_task = CriticalTask() + return Thread(target=self.__checkpoint, args=(self._checkpoint_task, wakeup)).start() + + if self._postgresql.get_master_timeline() == self._postgresql.pg_control_timeline(): + self._state = REWIND_STATUS.CHECKPOINT + + def checkpoint_after_promote(self): + return self._state == REWIND_STATUS.CHECKPOINT + + def _fetch_missing_wal(self, restore_command, wal_filename): + cmd = '' + length = len(restore_command) + i = 0 + while i < length: + if restore_command[i] == '%' and i + 1 < length: + i += 1 + if restore_command[i] == 'p': + cmd += os.path.join(self._postgresql.wal_dir, wal_filename) + elif restore_command[i] == 'f': + cmd += wal_filename + elif restore_command[i] == 'r': + cmd += '000000010000000000000001' + elif restore_command[i] == '%': + cmd += '%' + else: + cmd += '%' + i -= 1 + else: + cmd += restore_command[i] + i += 1 + + logger.info('Trying to fetch the missing wal: %s', cmd) + return self._postgresql.cancellable.call(shlex.split(cmd)) == 0 + + def _find_missing_wal(self, data): + # could not open file "$PGDATA/pg_wal/0000000A00006AA100000068": No such file or directory + pattern = 'could not open file "' + for line in data.decode('utf-8').split('\n'): + b = line.find(pattern) + if b > -1: + b += len(pattern) + e = line.find('": ', b) + if e > -1 and '/' in line[b:e]: + waldir, wal_filename = line[b:e].rsplit('/', 1) + if waldir.endswith('/pg_' + self._postgresql.wal_name) and len(wal_filename) == 24: + return wal_filename + + def pg_rewind(self, r): + # prepare pg_rewind connection + env = self._postgresql.config.write_pgpass(r) + env.update(LANG='C', LC_ALL='C', PGOPTIONS='-c statement_timeout=0') + dsn = self._postgresql.config.format_dsn(r, True) + logger.info('running pg_rewind from %s', dsn) + + restore_command = self._postgresql.config.get('recovery_conf', {}).get('restore_command') \ + if self._postgresql.major_version < 120000 else self._postgresql.get_guc_value('restore_command') + + cmd = [self._postgresql.pgcommand('pg_rewind')] + if self._postgresql.major_version >= 130000 and restore_command: + cmd.append('--restore-target-wal') + cmd.extend(['-D', self._postgresql.data_dir, '--source-server', dsn]) + + while True: + results = {} + ret = self._postgresql.cancellable.call(cmd, env=env, communicate=results) + + logger.info('pg_rewind exit code=%s', ret) + if ret is None: + return False + + logger.info(' stdout=%s', results['stdout'].decode('utf-8')) + logger.info(' stderr=%s', results['stderr'].decode('utf-8')) + if ret == 0: + return True + + if not restore_command or self._postgresql.major_version >= 130000: + return False + + missing_wal = self._find_missing_wal(results['stderr']) or self._find_missing_wal(results['stdout']) + if not missing_wal: + return False + + if not self._fetch_missing_wal(restore_command, missing_wal): + logger.info('Failed to fetch WAL segment %s required for pg_rewind', missing_wal) + return False + + def execute(self, leader): + if self._postgresql.is_running() and not self._postgresql.stop(checkpoint=False): + return logger.warning('Can not run pg_rewind because postgres is still running') + + # 
prepare pg_rewind connection
+        r = self._conn_kwargs(leader, self._postgresql.config.rewind_credentials)
+
+        # 1. make sure that we are really trying to rewind from the master
+        # 2. make sure that pg_control contains the new timeline by:
+        #   running a checkpoint or
+        #   waiting until Patroni on the master exposes checkpoint_after_promote=True
+        checkpoint_status = leader.checkpoint_after_promote if isinstance(leader, Leader) else None
+        if checkpoint_status is None:  # master still runs the old Patroni
+            leader_status = self._postgresql.checkpoint(self._conn_kwargs(leader, self._postgresql.config.superuser))
+            if leader_status:
+                return logger.warning('Can not use %s for rewind: %s', leader.name, leader_status)
+        elif not checkpoint_status:
+            return logger.info('Waiting for checkpoint on %s before rewind', leader.name)
+        elif not self.check_leader_is_not_in_recovery(r):
+            return
+
+        if self.pg_rewind(r):
+            self._state = REWIND_STATUS.SUCCESS
+        elif not self.check_leader_is_not_in_recovery(r):
+            logger.warning('Failed to rewind because master %s became unreachable', leader.name)
+        else:
+            logger.error('Failed to rewind from healthy master: %s', leader.name)
+
+            for name in ('remove_data_directory_on_rewind_failure', 'remove_data_directory_on_diverged_timelines'):
+                if self._postgresql.config.get(name):
+                    logger.warning('%s is set. removing...', name)
+                    self._postgresql.remove_data_directory()
+                    self._state = REWIND_STATUS.INITIAL
+                    break
+            else:
+                self._state = REWIND_STATUS.FAILED
+        return False
+
+    def reset_state(self):
+        self._state = REWIND_STATUS.INITIAL
+        with self._checkpoint_task_lock:
+            self._checkpoint_task = None
+
+    @property
+    def is_needed(self):
+        return self._state in (REWIND_STATUS.CHECK, REWIND_STATUS.NEED)
+
+    @property
+    def executed(self):
+        return self._state > REWIND_STATUS.NOT_NEED
+
+    @property
+    def failed(self):
+        return self._state == REWIND_STATUS.FAILED
+
+    def read_postmaster_opts(self):
+        """returns a dict of option names/values from postmaster.opts; empty dict if the read failed or no file"""
+        result = {}
+        try:
+            with open(os.path.join(self._postgresql.data_dir, 'postmaster.opts')) as f:
+                data = f.read()
+                for opt in data.split('" "'):
+                    if '=' in opt and opt.startswith('--'):
+                        name, val = opt.split('=', 1)
+                        result[name.strip('-')] = val.rstrip('"\n')
+        except IOError:
+            logger.exception('Error when reading postmaster.opts')
+        return result
+
+    def single_user_mode(self, communicate=None, options=None):
+        """run a given command in a single-user mode. 
If the command is empty - then just start and stop""" + cmd = [self._postgresql.pgcommand('gaussdb'), '--single', '-D', self._postgresql.data_dir] + for opt, val in sorted((options or {}).items()): + cmd.extend(['-c', '{0}={1}'.format(opt, val)]) + # need a database name to connect + cmd.append('template1') + return self._postgresql.cancellable.call(cmd, communicate=communicate) + + def cleanup_archive_status(self): + status_dir = os.path.join(self._postgresql.wal_dir, 'archive_status') + try: + for f in os.listdir(status_dir): + path = os.path.join(status_dir, f) + try: + if os.path.islink(path): + os.unlink(path) + elif os.path.isfile(path): + os.remove(path) + except OSError: + logger.exception('Unable to remove %s', path) + except OSError: + logger.exception('Unable to list %s', status_dir) + + def ensure_clean_shutdown(self): + self.cleanup_archive_status() + + # Start in a single user mode and stop to produce a clean shutdown + opts = self.read_postmaster_opts() + opts.update({'archive_mode': 'on', 'archive_command': 'false'}) + self._postgresql.config.remove_recovery_conf() + output = {} + ret = self.single_user_mode(communicate=output, options=opts) + if ret != 0: + logger.error('Crash recovery finished with code=%s', ret) + logger.info(' stdout=%s', output['stdout'].decode('utf-8')) + logger.info(' stderr=%s', output['stderr'].decode('utf-8')) + return ret == 0 or None diff --git a/patroni-for-openGauss/postgresql/slots.py b/patroni-for-openGauss/postgresql/slots.py new file mode 100644 index 0000000000000000000000000000000000000000..3738a25da51efbe73fef1e7c3ef7b4e4ff3872a5 --- /dev/null +++ b/patroni-for-openGauss/postgresql/slots.py @@ -0,0 +1,111 @@ +import logging + +from patroni.postgresql.connection import get_connection_cursor +from collections import defaultdict + +logger = logging.getLogger(__name__) + + +def compare_slots(s1, s2): + return s1['type'] == s2['type'] and (s1['type'] == 'physical' or + s1['database'] == s2['database'] and s1['plugin'] == s2['plugin']) + + +class SlotsHandler(object): + + def __init__(self, postgresql): + self._postgresql = postgresql + self._replication_slots = {} # already existing replication slots + self.schedule() + + def _query(self, sql, *params): + return self._postgresql.query(sql, *params, retry=False) + + def load_replication_slots(self): + if self._postgresql.major_version >= 90400 and self._schedule_load_slots: + replication_slots = {} + cursor = self._query('SELECT slot_name, slot_type, plugin, database FROM pg_catalog.pg_replication_slots') + for r in cursor: + value = {'type': r[1]} + if r[1] == 'logical': + value.update({'plugin': r[2], 'database': r[3]}) + replication_slots[r[0]] = value + self._replication_slots = replication_slots + self._schedule_load_slots = False + + def ignore_replication_slot(self, cluster, name): + slot = self._replication_slots[name] + for matcher in cluster.config.ignore_slots_matchers: + if ((matcher.get("name") is None or matcher["name"] == name) + and all(not matcher.get(a) or matcher[a] == slot.get(a) for a in ('database', 'plugin', 'type'))): + return True + return False + + def drop_replication_slot(self, name): + cursor = self._query(('SELECT pg_catalog.pg_drop_replication_slot(%s) WHERE EXISTS (SELECT 1 ' + + 'FROM pg_catalog.pg_replication_slots WHERE slot_name = %s AND NOT active)'), name, name) + # In normal situation rowcount should be 1, otherwise either slot doesn't exists or it is still active + return cursor.rowcount == 1 + + def sync_replication_slots(self, cluster): + if 
self._postgresql.major_version >= 90400 and cluster.config: + try: + self.load_replication_slots() + + slots = cluster.get_replication_slots(self._postgresql.name, self._postgresql.role) + + # drop old replication slots which are not presented in desired slots + for name in set(self._replication_slots) - set(slots): + if not self.ignore_replication_slot(cluster, name) and not self.drop_replication_slot(name): + logger.error("Failed to drop replication slot '%s'", name) + self._schedule_load_slots = True + + immediately_reserve = ', true' if self._postgresql.major_version >= 90600 else '' + + logical_slots = defaultdict(dict) + for name, value in slots.items(): + if name in self._replication_slots and not compare_slots(value, self._replication_slots[name]): + logger.info("Trying to drop replication slot '%s' because value is changing from %s to %s", + name, self._replication_slots[name], value) + if not self.drop_replication_slot(name): + logger.error("Failed to drop replication slot '%s'", name) + self._schedule_load_slots = True + continue + self._replication_slots.pop(name) + if name not in self._replication_slots: + if value['type'] == 'physical': + try: + self._query(("SELECT pg_catalog.pg_create_physical_replication_slot(%s{0})" + + " WHERE NOT EXISTS (SELECT 1 FROM pg_catalog.pg_replication_slots" + + " WHERE slot_type = 'physical' AND slot_name = %s)").format( + immediately_reserve), name, name) + except Exception: + logger.exception("Failed to create physical replication slot '%s'", name) + self._schedule_load_slots = True + elif value['type'] == 'logical' and name not in self._replication_slots: + logical_slots[value['database']][name] = value + + # create new logical slots + for database, values in logical_slots.items(): + conn_kwargs = self._postgresql.config.local_connect_kwargs + conn_kwargs['database'] = database + with get_connection_cursor(**conn_kwargs) as cur: + for name, value in values.items(): + try: + cur.execute("SELECT pg_catalog.pg_create_logical_replication_slot(%s, %s)" + + " WHERE NOT EXISTS (SELECT 1 FROM pg_catalog.pg_replication_slots" + + " WHERE slot_type = 'logical' AND slot_name = %s)", + (name, value['plugin'], name)) + except Exception: + logger.exception("Failed to create logical replication slot '%s' plugin='%s'", + name, value['plugin']) + self._schedule_load_slots = True + self._replication_slots = slots + except Exception: + logger.exception('Exception when changing replication slots') + self._schedule_load_slots = True + + def schedule(self, value=None): + if value is None: + value = self._postgresql.major_version >= 90400 + self._schedule_load_slots = value diff --git a/patroni-for-openGauss/postgresql/validator.py b/patroni-for-openGauss/postgresql/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..c0cad99722fe1d6446f5d70d8e682bf44dbc557f --- /dev/null +++ b/patroni-for-openGauss/postgresql/validator.py @@ -0,0 +1,493 @@ +import abc +import logging +import six + +from collections import namedtuple +from urllib3.response import HTTPHeaderDict + +from ..utils import parse_bool, parse_int, parse_real + +logger = logging.getLogger(__name__) + + +class CaseInsensitiveDict(HTTPHeaderDict): + + def add(self, key, val): + self[key] = val + + def __getitem__(self, key): + return self._container[key.lower()][1] + + def __repr__(self): + return str(dict(self.items())) + + def copy(self): + return CaseInsensitiveDict(self._container.values()) + + +class Bool(namedtuple('Bool', 'version_from,version_till')): + + @staticmethod + 
def transform(name, value): + if parse_bool(value) is not None: + return value + logger.warning('Removing bool parameter=%s from the config due to the invalid value=%s', name, value) + + +@six.add_metaclass(abc.ABCMeta) +class Number(namedtuple('Number', 'version_from,version_till,min_val,max_val,unit')): + + @staticmethod + @abc.abstractmethod + def parse(value, unit): + """parse value""" + + def transform(self, name, value): + num_value = self.parse(value, self.unit) + if num_value is not None: + if num_value < self.min_val: + logger.warning('Value=%s of parameter=%s is too low, increasing to %s%s', + value, name, self.min_val, self.unit or '') + return self.min_val + if num_value > self.max_val: + logger.warning('Value=%s of parameter=%s is too big, decreasing to %s%s', + value, name, self.max_val, self.unit or '') + return self.max_val + return value + logger.warning('Removing %s parameter=%s from the config due to the invalid value=%s', + self.__class__.__name__.lower(), name, value) + + +class Integer(Number): + + @staticmethod + def parse(value, unit): + return parse_int(value, unit) + + +class Real(Number): + + @staticmethod + def parse(value, unit): + return parse_real(value, unit) + + +class Enum(namedtuple('Enum', 'version_from,version_till,possible_values')): + + def transform(self, name, value): + if str(value).lower() in self.possible_values: + return value + logger.warning('Removing enum parameter=%s from the config due to the invalid value=%s', name, value) + + +class EnumBool(Enum): + + def transform(self, name, value): + if parse_bool(value) is not None: + return value + return super(EnumBool, self).transform(name, value) + + +class String(namedtuple('String', 'version_from,version_till')): + + @staticmethod + def transform(name, value): + return value + + +# Format: +# key - parameter name +# value - tuple or multiple tuples if something was changing in GUC across postgres versions +parameters = CaseInsensitiveDict({ + 'allow_system_table_mods': Bool(90200, None), + 'application_name': String(90200, None), + 'archive_command': String(90200, None), + 'archive_mode': ( + Bool(90200, 90500), + EnumBool(90500, None, ('always',)) + ), + 'archive_timeout': Integer(90200, None, 0, 1073741823, 's'), + 'array_nulls': Bool(90200, None), + 'authentication_timeout': Integer(90200, None, 1, 600, 's'), + 'autovacuum': Bool(90200, None), + 'autovacuum_analyze_scale_factor': Real(90200, None, 0, 100, None), + 'autovacuum_analyze_threshold': Integer(90200, None, 0, 2147483647, None), + 'autovacuum_freeze_max_age': Integer(90200, None, 100000, 2000000000, None), + 'autovacuum_max_workers': Integer(90600, None, 1, 262143, None), + 'autovacuum_multixact_freeze_max_age': Integer(90200, None, 10000, 2000000000, None), + 'autovacuum_naptime': Integer(90200, None, 1, 2147483, 's'), + 'autovacuum_vacuum_cost_delay': ( + Integer(90200, 120000, -1, 100, 'ms'), + Real(120000, None, -1, 100, 'ms') + ), + 'autovacuum_vacuum_cost_limit': Integer(90200, None, -1, 10000, None), + 'autovacuum_vacuum_insert_scale_factor': Real(130000, None, 0, 100, None), + 'autovacuum_vacuum_insert_threshold': Integer(130000, None, -1, 2147483647, None), + 'autovacuum_vacuum_scale_factor': Real(90200, None, 0, 100, None), + 'autovacuum_vacuum_threshold': Integer(90200, None, 0, 2147483647, None), + 'autovacuum_work_mem': Integer(90400, None, -1, 2147483647, 'kB'), + 'backend_flush_after': Integer(90600, None, 0, 256, '8kB'), + 'backslash_quote': EnumBool(90200, None, ('safe_encoding',)), + 'backtrace_functions': 
String(130000, None), + 'bgwriter_delay': Integer(90200, None, 10, 10000, 'ms'), + 'bgwriter_flush_after': Integer(90600, None, 0, 256, '8kB'), + 'bgwriter_lru_maxpages': ( + Integer(90200, 100000, 0, 1000, None), + Integer(100000, None, 0, 1073741823, None) + ), + 'bgwriter_lru_multiplier': Real(90200, None, 0, 10, None), + 'bonjour': Bool(90200, None), + 'bonjour_name': String(90200, None), + 'bytea_output': Enum(90200, None, ('escape', 'hex')), + 'check_function_bodies': Bool(90200, None), + 'checkpoint_completion_target': Real(90200, None, 0, 1, None), + 'checkpoint_flush_after': Integer(90600, None, 0, 256, '8kB'), + 'checkpoint_segments': Integer(90200, 90500, 1, 2147483647, None), + 'checkpoint_timeout': ( + Integer(90200, 90600, 30, 3600, 's'), + Integer(90600, None, 30, 86400, 's') + ), + 'checkpoint_warning': Integer(90200, None, 0, 2147483647, 's'), + 'client_encoding': String(90200, None), + 'client_min_messages': Enum(90200, None, ('debug5', 'debug4', 'debug3', 'debug2', + 'debug1', 'log', 'notice', 'warning', 'error')), + 'cluster_name': String(90500, None), + 'commit_delay': Integer(90200, None, 0, 100000, None), + 'commit_siblings': Integer(90200, None, 0, 1000, None), + 'config_file': String(90200, None), + 'constraint_exclusion': EnumBool(90200, None, ('partition',)), + 'cpu_index_tuple_cost': Real(90200, None, 0, 1.79769e+308, None), + 'cpu_operator_cost': Real(90200, None, 0, 1.79769e+308, None), + 'cpu_tuple_cost': Real(90200, None, 0, 1.79769e+308, None), + 'cursor_tuple_fraction': Real(90200, None, 0, 1, None), + 'data_directory': String(90200, None), + 'data_sync_retry': Bool(90200, None), + 'DateStyle': String(90200, None), + 'db_user_namespace': Bool(90200, None), + 'deadlock_timeout': Integer(90200, None, 1, 2147483647, 'ms'), + 'debug_pretty_print': Bool(90200, None), + 'debug_print_parse': Bool(90200, None), + 'debug_print_plan': Bool(90200, None), + 'debug_print_rewritten': Bool(90200, None), + 'default_statistics_target': Integer(90200, None, 1, 10000, None), + 'default_table_access_method': String(120000, None), + 'default_tablespace': String(90200, None), + 'default_text_search_config': String(90200, None), + 'default_transaction_deferrable': Bool(90200, None), + 'default_transaction_isolation': Enum(90200, None, ('serializable', 'repeatable read', + 'read committed', 'read uncommitted')), + 'default_transaction_read_only': Bool(90200, None), + 'default_with_oids': Bool(90200, 120000), + 'dynamic_library_path': String(90200, None), + 'dynamic_shared_memory_type': ( + Enum(90400, 120000, ('posix', 'sysv', 'mmap', 'none')), + Enum(120000, None, ('posix', 'sysv', 'mmap')) + ), + 'effective_cache_size': Integer(90200, None, 1, 2147483647, '8kB'), + 'effective_io_concurrency': Integer(90200, None, 0, 1000, None), + 'enable_bitmapscan': Bool(90200, None), + 'enable_gathermerge': Bool(100000, None), + 'enable_hashagg': Bool(90200, None), + 'enable_hashjoin': Bool(90200, None), + 'enable_incremental_sort': Bool(130000, None), + 'enable_indexonlyscan': Bool(90200, None), + 'enable_indexscan': Bool(90200, None), + 'enable_material': Bool(90200, None), + 'enable_mergejoin': Bool(90200, None), + 'enable_nestloop': Bool(90200, None), + 'enable_parallel_append': Bool(110000, None), + 'enable_parallel_hash': Bool(110000, None), + 'enable_partition_pruning': Bool(110000, None), + 'enable_partitionwise_aggregate': Bool(110000, None), + 'enable_partitionwise_join': Bool(110000, None), + 'enable_seqscan': Bool(90200, None), + 'enable_sort': Bool(90200, None), + 
'enable_tidscan': Bool(90200, None), + 'escape_string_warning': Bool(90200, None), + 'event_source': String(90200, None), + 'exit_on_error': Bool(90200, None), + 'external_pid_file': String(90200, None), + 'extra_float_digits': Integer(90200, None, -15, 3, None), + 'force_parallel_mode': EnumBool(90600, None, ('regress',)), + 'from_collapse_limit': Integer(90200, None, 1, 2147483647, None), + 'fsync': Bool(90200, None), + 'full_page_writes': Bool(90200, None), + 'geqo': Bool(90200, None), + 'geqo_effort': Integer(90200, None, 1, 10, None), + 'geqo_generations': Integer(90200, None, 0, 2147483647, None), + 'geqo_pool_size': Integer(90200, None, 0, 2147483647, None), + 'geqo_seed': Real(90200, None, 0, 1, None), + 'geqo_selection_bias': Real(90200, None, 1.5, 2, None), + 'geqo_threshold': Integer(90200, None, 2, 2147483647, None), + 'gin_fuzzy_search_limit': Integer(90200, None, 0, 2147483647, None), + 'gin_pending_list_limit': Integer(90500, None, 64, 2147483647, 'kB'), + 'hash_mem_multiplier': Real(130000, None, 1, 1000, None), + 'hba_file': String(90200, None), + 'hot_standby': Bool(90200, None), + 'hot_standby_feedback': Bool(90200, None), + 'huge_pages': EnumBool(90400, None, ('try',)), + 'ident_file': String(90200, None), + 'idle_in_transaction_session_timeout': Integer(90600, None, 0, 2147483647, 'ms'), + 'ignore_checksum_failure': Bool(90200, None), + 'ignore_invalid_pages': Bool(130000, None), + 'ignore_system_indexes': Bool(90200, None), + 'IntervalStyle': Enum(90200, None, ('postgres', 'postgres_verbose', 'sql_standard', 'iso_8601')), + 'jit': Bool(110000, None), + 'jit_above_cost': Real(110000, None, -1, 1.79769e+308, None), + 'jit_debugging_support': Bool(110000, None), + 'jit_dump_bitcode': Bool(110000, None), + 'jit_expressions': Bool(110000, None), + 'jit_inline_above_cost': Real(110000, None, -1, 1.79769e+308, None), + 'jit_optimize_above_cost': Real(110000, None, -1, 1.79769e+308, None), + 'jit_profiling_support': Bool(110000, None), + 'jit_provider': String(110000, None), + 'jit_tuple_deforming': Bool(110000, None), + 'join_collapse_limit': Integer(90200, None, 1, 2147483647, None), + 'krb_caseins_users': Bool(90200, None), + 'krb_server_keyfile': String(90200, None), + 'krb_srvname': String(90200, 90400), + 'lc_messages': String(90200, None), + 'lc_monetary': String(90200, None), + 'lc_numeric': String(90200, None), + 'lc_time': String(90200, None), + 'listen_addresses': String(90200, None), + 'local_preload_libraries': String(90200, None), + 'lock_timeout': Integer(90200, None, 0, 2147483647, 'ms'), + 'lo_compat_privileges': Bool(90200, None), + 'log_autovacuum_min_duration': Integer(90200, None, -1, 2147483647, 'ms'), + 'log_checkpoints': Bool(90200, None), + 'log_connections': Bool(90200, None), + 'log_destination': String(90200, None), + 'log_directory': String(90200, None), + 'log_disconnections': Bool(90200, None), + 'log_duration': Bool(90200, None), + 'log_error_verbosity': Enum(90200, None, ('terse', 'default', 'verbose')), + 'log_executor_stats': Bool(90200, None), + 'log_file_mode': Integer(90200, None, 0, 511, None), + 'log_filename': String(90200, None), + 'logging_collector': Bool(90200, None), + 'log_hostname': Bool(90200, None), + 'logical_decoding_work_mem': Integer(130000, None, 64, 2147483647, 'kB'), + 'log_line_prefix': String(90200, None), + 'log_lock_waits': Bool(90200, None), + 'log_min_duration_sample': Integer(130000, None, -1, 2147483647, 'ms'), + 'log_min_duration_statement': Integer(90200, None, -1, 2147483647, 'ms'), + 
'log_min_error_statement': Enum(90200, None, ('debug5', 'debug4', 'debug3', 'debug2', 'debug1', 'info', + 'notice', 'warning', 'error', 'log', 'fatal', 'panic')), + 'log_min_messages': Enum(90200, None, ('debug5', 'debug4', 'debug3', 'debug2', 'debug1', 'info', + 'notice', 'warning', 'error', 'log', 'fatal', 'panic')), + 'log_parameter_max_length': Integer(130000, None, -1, 1073741823, 'B'), + 'log_parameter_max_length_on_error': Integer(130000, None, -1, 1073741823, 'B'), + 'log_parser_stats': Bool(90200, None), + 'log_planner_stats': Bool(90200, None), + 'log_replication_commands': Bool(90500, None), + 'log_rotation_age': Integer(90200, None, 0, 35791394, 'min'), + 'log_rotation_size': Integer(90200, None, 0, 2097151, 'kB'), + 'log_statement': Enum(90200, None, ('none', 'ddl', 'mod', 'all')), + 'log_statement_sample_rate': Real(130000, None, 0, 1, None), + 'log_statement_stats': Bool(90200, None), + 'log_temp_files': Integer(90200, None, -1, 2147483647, 'kB'), + 'log_timezone': String(90200, None), + 'log_transaction_sample_rate': Real(120000, None, 0, 1, None), + 'log_truncate_on_rotation': Bool(90200, None), + 'maintenance_io_concurrency': Integer(130000, None, 0, 1000, None), + 'maintenance_work_mem': Integer(90200, None, 1024, 2147483647, 'kB'), + 'max_connections': Integer(90600, None, 1, 262143, None), + 'max_files_per_process': ( + Integer(90200, 130000, 25, 2147483647, None), + Integer(130000, None, 64, 2147483647, None) + ), + 'max_locks_per_transaction': Integer(90200, None, 10, 2147483647, None), + 'max_logical_replication_workers': Integer(100000, None, 0, 262143, None), + 'max_parallel_maintenance_workers': Integer(110000, None, 0, 1024, None), + 'max_parallel_workers': Integer(100000, None, 0, 1024, None), + 'max_parallel_workers_per_gather': Integer(90600, None, 0, 1024, None), + 'max_pred_locks_per_page': Integer(100000, None, 0, 2147483647, None), + 'max_pred_locks_per_relation': Integer(100000, None, -2147483648, 2147483647, None), + 'max_pred_locks_per_transaction': Integer(90200, None, 10, 2147483647, None), + 'max_prepared_transactions': Integer(90200, None, 0, 536870911, None), + 'max_replication_slots': Integer(90200, None, 8, 100, None), + 'max_slot_wal_keep_size': Integer(130000, None, -1, 2147483647, 'MB'), + 'max_stack_depth': Integer(90200, None, 100, 2147483647, 'kB'), + 'max_standby_archive_delay': Integer(90200, None, -1, 2147483647, 'ms'), + 'max_standby_streaming_delay': Integer(90200, None, -1, 2147483647, 'ms'), + 'max_sync_workers_per_subscription': Integer(100000, None, 0, 262143, None), + 'max_wal_senders': Integer(90200, None, 0, 262143, None), + 'max_wal_size': ( + Integer(90500, 100000, 2, 2147483647, '16MB'), + Integer(100000, None, 2, 2147483647, 'MB') + ), + 'max_worker_processes': ( + Integer(90400, 90600, 1, 8388607, None), + Integer(90600, None, 0, 262143, None) + ), + 'min_parallel_index_scan_size': Integer(100000, None, 0, 715827882, '8kB'), + 'min_parallel_relation_size': Integer(90600, 100000, 0, 715827882, '8kB'), + 'min_parallel_table_scan_size': Integer(100000, None, 0, 715827882, '8kB'), + 'min_wal_size': ( + Integer(90500, 100000, 2, 2147483647, '16MB'), + Integer(100000, None, 2, 2147483647, 'MB') + ), + 'old_snapshot_threshold': Integer(90600, None, -1, 86400, 'min'), + 'operator_precedence_warning': Bool(90500, None), + 'parallel_leader_participation': Bool(110000, None), + 'parallel_setup_cost': Real(90600, None, 0, 1.79769e+308, None), + 'parallel_tuple_cost': Real(90600, None, 0, 1.79769e+308, None), + 'password_encryption': ( 
+ Bool(90200, 100000), + Enum(100000, None, ('md5', 'scram-sha-256')) + ), + 'plan_cache_mode': Enum(120000, None, ('auto', 'force_generic_plan', 'force_custom_plan')), + 'port': Integer(90200, None, 1, 65535, None), + 'post_auth_delay': Integer(90200, None, 0, 2147, 's'), + 'pre_auth_delay': Integer(90200, None, 0, 60, 's'), + 'quote_all_identifiers': Bool(90200, None), + 'random_page_cost': Real(90200, None, 0, 1.79769e+308, None), + 'replacement_sort_tuples': Integer(90600, 110000, 0, 2147483647, None), + 'restart_after_crash': Bool(90200, None), + 'row_security': Bool(90500, None), + 'search_path': String(90200, None), + 'seq_page_cost': Real(90200, None, 0, 1.79769e+308, None), + 'session_preload_libraries': String(90400, None), + 'session_replication_role': Enum(90200, None, ('origin', 'replica', 'local')), + 'shared_buffers': Integer(90200, None, 16, 1073741823, '8kB'), + 'shared_memory_type': Enum(120000, None, ('sysv', 'mmap')), + 'shared_preload_libraries': String(90200, None), + 'sql_inheritance': Bool(90200, 100000), + 'ssl': Bool(90200, None), + 'ssl_ca_file': String(90200, None), + 'ssl_cert_file': String(90200, None), + 'ssl_ciphers': String(90200, None), + 'ssl_crl_file': String(90200, None), + 'ssl_dh_params_file': String(100000, None), + 'ssl_ecdh_curve': String(90400, None), + 'ssl_key_file': String(90200, None), + 'ssl_max_protocol_version': Enum(120000, None, ('', 'tlsv1', 'tlsv1.1', 'tlsv1.2', 'tlsv1.3')), + 'ssl_min_protocol_version': Enum(120000, None, ('tlsv1', 'tlsv1.1', 'tlsv1.2', 'tlsv1.3')), + 'ssl_passphrase_command': String(110000, None), + 'ssl_passphrase_command_supports_reload': Bool(110000, None), + 'ssl_prefer_server_ciphers': Bool(90400, None), + 'ssl_renegotiation_limit': Integer(90200, 90500, 0, 2147483647, 'kB'), + 'standard_conforming_strings': Bool(90200, None), + 'statement_timeout': Integer(90200, None, 0, 2147483647, 'ms'), + 'stats_temp_directory': String(90200, None), + 'superuser_reserved_connections': Integer(90600, None, 0, 262143, None), + 'synchronize_seqscans': Bool(90200, None), + 'synchronous_commit': EnumBool(90200, None, ('on', 'off', 'local', 'remote_receive', 'remote_apply')), + 'synchronous_standby_names': String(90200, None), + 'syslog_facility': Enum(90200, None, ('local0', 'local1', 'local2', 'local3', + 'local4', 'local5', 'local6', 'local7')), + 'syslog_ident': String(90200, None), + 'syslog_sequence_numbers': Bool(90600, None), + 'syslog_split_messages': Bool(90600, None), + 'tcp_keepalives_count': Integer(90200, None, 0, 2147483647, None), + 'tcp_keepalives_idle': Integer(90200, None, 0, 2147483647, 's'), + 'tcp_keepalives_interval': Integer(90200, None, 0, 2147483647, 's'), + 'tcp_user_timeout': Integer(120000, None, 0, 2147483647, 'ms'), + 'temp_buffers': Integer(90200, None, 100, 1073741823, '8kB'), + 'temp_file_limit': Integer(90200, None, -1, 2147483647, 'kB'), + 'temp_tablespaces': String(90200, None), + 'TimeZone': String(90200, None), + 'timezone_abbreviations': String(90200, None), + 'trace_notify': Bool(90200, None), + 'trace_recovery_messages': Enum(90200, None, ('debug5', 'debug4', 'debug3', 'debug2', + 'debug1', 'log', 'notice', 'warning', 'error')), + 'trace_sort': Bool(90200, None), + 'track_activities': Bool(90200, None), + 'track_activity_query_size': ( + Integer(90200, 110000, 100, 102400, None), + Integer(110000, 130000, 100, 102400, 'B'), + Integer(130000, None, 100, 1048576, 'B') + ), + 'track_commit_timestamp': Bool(90500, None), + 'track_counts': Bool(90200, None), + 'track_functions': Enum(90200, 
None, ('none', 'pl', 'all')), + 'track_io_timing': Bool(90200, None), + 'transaction_deferrable': Bool(90200, None), + 'transaction_isolation': Enum(90200, None, ('serializable', 'repeatable read', + 'read committed', 'read uncommitted')), + 'transaction_read_only': Bool(90200, None), + 'transform_null_equals': Bool(90200, None), + 'unix_socket_directories': String(90200, None), + 'unix_socket_group': String(90200, None), + 'unix_socket_permissions': Integer(90200, None, 0, 511, None), + 'update_process_title': Bool(90200, None), + 'vacuum_cleanup_index_scale_factor': Real(110000, None, 0, 1e+10, None), + 'vacuum_cost_delay': ( + Integer(90200, 120000, 0, 100, 'ms'), + Real(120000, None, 0, 100, 'ms') + ), + 'vacuum_cost_limit': Integer(90200, None, 1, 10000, None), + 'vacuum_cost_page_dirty': Integer(90200, None, 0, 10000, None), + 'vacuum_cost_page_hit': Integer(90200, None, 0, 10000, None), + 'vacuum_cost_page_miss': Integer(90200, None, 0, 10000, None), + 'vacuum_defer_cleanup_age': Integer(90200, None, 0, 1000000, None), + 'vacuum_freeze_min_age': Integer(90200, None, 0, 1000000000, None), + 'vacuum_freeze_table_age': Integer(90200, None, 0, 2000000000, None), + 'vacuum_multixact_freeze_min_age': Integer(90200, None, 0, 1000000000, None), + 'vacuum_multixact_freeze_table_age': Integer(90200, None, 0, 2000000000, None), + 'wal_buffers': Integer(90200, None, -1, 262143, '8kB'), + 'wal_compression': Bool(90500, None), + 'wal_consistency_checking': String(100000, None), + 'wal_init_zero': Bool(120000, None), + 'wal_keep_segments': Integer(90200, 130000, 0, 2147483647, None), + 'wal_keep_size': Integer(130000, None, 0, 2147483647, 'MB'), + 'wal_level': Enum(90200, None, ('minimal', 'archive', 'hot_standby', 'logical')), + 'wal_log_hints': Bool(90200, None), + 'wal_receiver_create_temp_slot': Bool(130000, None), + 'wal_receiver_status_interval': Integer(90200, None, 0, 2147483, 's'), + 'wal_receiver_timeout': Integer(90200, None, 0, 2147483647, 'ms'), + 'wal_recycle': Bool(120000, None), + 'wal_retrieve_retry_interval': Integer(90500, None, 1, 2147483647, 'ms'), + 'wal_sender_timeout': Integer(90200, None, 0, 2147483647, 'ms'), + 'wal_skip_threshold': Integer(130000, None, 0, 2147483647, 'kB'), + 'wal_sync_method': Enum(90200, None, ('fsync', 'fdatasync', 'open_sync', 'open_datasync')), + 'wal_writer_delay': Integer(90200, None, 1, 10000, 'ms'), + 'wal_writer_flush_after': Integer(90600, None, 0, 2147483647, '8kB'), + 'work_mem': Integer(90200, None, 64, 2147483647, 'kB'), + 'xmlbinary': Enum(90200, None, ('base64', 'hex')), + 'xmloption': Enum(90200, None, ('content', 'document')), + 'zero_damaged_pages': Bool(90200, None) +}) + + +recovery_parameters = CaseInsensitiveDict({ + 'archive_cleanup_command': String(90200, None), + 'pause_at_recovery_target': Bool(90200, 90500), + 'primary_conninfo': String(90200, None), + 'primary_slot_name': String(90400, None), + 'promote_trigger_file': String(120000, None), + 'recovery_end_command': String(90200, None), + 'recovery_min_apply_delay': Integer(90200, None, 0, 2147483647, 'ms'), + 'recovery_target': Enum(90400, None, ('immediate', '')), + 'recovery_target_action': Enum(90500, None, ('pause', 'promote', 'shutdown')), + 'recovery_target_inclusive': Bool(90200, None), + 'recovery_target_lsn': String(100000, None), + 'recovery_target_name': String(90400, None), + 'recovery_target_time': String(90200, None), + 'recovery_target_timeline': String(90200, None), + 'recovery_target_xid': String(90200, None), + 'restore_command': String(90200, None), + 
'standby_mode': Bool(90200, 120000), + 'trigger_file': String(90200, 120000) +}) + + +def _transform_parameter_value(validators, version, name, value): + validators = validators.get(name) + if validators: + for validator in (validators if isinstance(validators[0], tuple) else [validators]): + if version >= validator.version_from and\ + (validator.version_till is None or version < validator.version_till): + return validator.transform(name, value) + logger.warning('Removing unexpected parameter=%s value=%s from the config', name, value) + + +def transform_postgresql_parameter_value(version, name, value): + if '.' in name: + return value + return _transform_parameter_value(parameters, version, name, value) + + +def transform_recovery_parameter_value(version, name, value): + return _transform_parameter_value(recovery_parameters, version, name, value) diff --git a/patroni-for-openGauss/raft_controller.py b/patroni-for-openGauss/raft_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..5a9ae704e050c23d83fe6c154e9217ea48aa5533 --- /dev/null +++ b/patroni-for-openGauss/raft_controller.py @@ -0,0 +1,34 @@ +import logging +import os + +from patroni.daemon import AbstractPatroniDaemon, abstract_main +from patroni.dcs.raft import KVStoreTTL +from pysyncobj import SyncObjConf + +logger = logging.getLogger(__name__) + + +class RaftController(AbstractPatroniDaemon): + + def __init__(self, config): + super(RaftController, self).__init__(config) + + raft_config = self.config.get('raft') + self_addr = raft_config['self_addr'] + template = os.path.join(raft_config.get('data_dir', ''), self_addr) + self._syncobj_config = SyncObjConf(autoTick=False, appendEntriesUseBatch=False, dynamicMembershipChange=True, + journalFile=template + '.journal', fullDumpFile=template + '.dump') + self._raft = KVStoreTTL(self_addr, raft_config.get('partner_addrs', []), self._syncobj_config) + + def _run_cycle(self): + try: + self._raft.doTick(self._syncobj_config.autoTickPeriod) + except Exception: + logger.exception('doTick') + + def _shutdown(self): + self._raft.destroy() + + +def main(): + abstract_main(RaftController) diff --git a/patroni-for-openGauss/request.py b/patroni-for-openGauss/request.py new file mode 100644 index 0000000000000000000000000000000000000000..0463831c9c9d7a3c02c84c9f3e9672ccfcd4068f --- /dev/null +++ b/patroni-for-openGauss/request.py @@ -0,0 +1,58 @@ +import json +import urllib3 +import six + +from six.moves.urllib_parse import urlparse, urlunparse + +from .utils import USER_AGENT + + +class PatroniRequest(object): + + def __init__(self, config, insecure=False): + cert_reqs = 'CERT_NONE' if insecure or config.get('ctl', {}).get('insecure', False) else 'CERT_REQUIRED' + self._pool = urllib3.PoolManager(num_pools=10, maxsize=10, cert_reqs=cert_reqs) + self.reload_config(config) + + @staticmethod + def _get_cfg_value(config, name): + return config.get('ctl', {}).get(name) or config.get('restapi', {}).get(name) + + def _apply_pool_param(self, param, value): + if value: + self._pool.connection_pool_kw[param] = value + else: + self._pool.connection_pool_kw.pop(param, None) + + def _apply_ssl_file_param(self, config, name): + value = self._get_cfg_value(config, name + 'file') + self._apply_pool_param(name + '_file', value) + return value + + def reload_config(self, config): + self._pool.headers = urllib3.make_headers(basic_auth=self._get_cfg_value(config, 'auth'), user_agent=USER_AGENT) + + if self._apply_ssl_file_param(config, 'cert'): + self._apply_ssl_file_param(config, 'key') 
+ else: + self._pool.connection_pool_kw.pop('key_file', None) + + cacert = config.get('ctl', {}).get('cacert') or config.get('restapi', {}).get('cafile') + self._apply_pool_param('ca_certs', cacert) + + def request(self, method, url, body=None, **kwargs): + if body is not None and not isinstance(body, six.string_types): + body = json.dumps(body) + return self._pool.request(method.upper(), url, body=body, **kwargs) + + def __call__(self, member, method='GET', endpoint=None, data=None, **kwargs): + url = member.api_url + if endpoint: + scheme, netloc, _, _, _, _ = urlparse(url) + url = urlunparse((scheme, netloc, endpoint, '', '', '')) + return self.request(method, url, data, **kwargs) + + +def get(url, verify=True, **kwargs): + http = PatroniRequest({}, not verify) + return http.request('GET', url, **kwargs) diff --git a/patroni-for-openGauss/scripts/__init__.py b/patroni-for-openGauss/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/patroni-for-openGauss/scripts/aws.py b/patroni-for-openGauss/scripts/aws.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7e0412929235966712cb620bb35201259b4746 --- /dev/null +++ b/patroni-for-openGauss/scripts/aws.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python + +import json +import logging +import sys +import boto.ec2 + +from patroni.utils import Retry, RetryFailedError +from patroni.request import get as requests_get + +logger = logging.getLogger(__name__) + + +class AWSConnection(object): + + def __init__(self, cluster_name): + self.available = False + self.cluster_name = cluster_name if cluster_name is not None else 'unknown' + self._retry = Retry(deadline=300, max_delay=30, max_tries=-1, retry_exceptions=(boto.exception.StandardError,)) + try: + # get the instance id + r = requests_get('http://169.254.169.254/latest/dynamic/instance-identity/document', timeout=2.1) + except Exception: + logger.error('cannot query AWS meta-data') + return + + if r.status < 400: + try: + content = json.loads(r.data.decode('utf-8')) + self.instance_id = content['instanceId'] + self.region = content['region'] + except Exception: + logger.exception('unable to fetch instance id and region from AWS meta-data') + return + self.available = True + + def retry(self, *args, **kwargs): + return self._retry.copy()(*args, **kwargs) + + def aws_available(self): + return self.available + + def _tag_ebs(self, conn, role): + """ set tags, carrying the cluster name, instance role and instance id for the EBS storage """ + tags = {'Name': 'spilo_' + self.cluster_name, 'Role': role, 'Instance': self.instance_id} + volumes = conn.get_all_volumes(filters={'attachment.instance-id': self.instance_id}) + conn.create_tags([v.id for v in volumes], tags) + + def _tag_ec2(self, conn, role): + """ tag the current EC2 instance with a cluster role """ + tags = {'Role': role} + conn.create_tags([self.instance_id], tags) + + def on_role_change(self, new_role): + if not self.available: + return False + try: + conn = self.retry(boto.ec2.connect_to_region, self.region) + self.retry(self._tag_ec2, conn, new_role) + self.retry(self._tag_ebs, conn, new_role) + except RetryFailedError: + logger.warning("Unable to communicate to AWS " + "when setting tags for the EC2 instance {0} " + "and attached EBS volumes".format(self.instance_id)) + return False + return True + + +def main(): + logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO) + if len(sys.argv) == 4 and 
sys.argv[1] in ('on_start', 'on_stop', 'on_role_change'): + AWSConnection(cluster_name=sys.argv[3]).on_role_change(sys.argv[2]) + else: + sys.exit("Usage: {0} action role name".format(sys.argv[0])) + + +if __name__ == '__main__': + main() diff --git a/patroni-for-openGauss/scripts/wale_restore.py b/patroni-for-openGauss/scripts/wale_restore.py new file mode 100644 index 0000000000000000000000000000000000000000..b2df8888be7ef2b1ebb884735e498bb2d3d1e669 --- /dev/null +++ b/patroni-for-openGauss/scripts/wale_restore.py @@ -0,0 +1,371 @@ +#!/usr/bin/env python + +# sample script to clone new replicas using WAL-E restore +# falls back to pg_basebackup if WAL-E restore fails, or if +# WAL-E backup is too far behind +# note that pg_basebackup still expects to use restore from +# WAL-E for transaction logs + +# theoretically should work with SWIFT, but not tested on it + +# arguments are: +# - cluster scope +# - cluster role +# - master connection string +# - number of retries +# - envdir for the WALE env +# - WALE_BACKUP_THRESHOLD_MEGABYTES if WAL amount is above that - use pg_basebackup +# - WALE_BACKUP_THRESHOLD_PERCENTAGE if WAL size exceeds a certain percentage of the + +# this script depends on an envdir defining the S3 bucket (or SWIFT dir),and login +# credentials per WALE Documentation. + +# currently also requires that you configure the restore_command to use wal_e, example: +# recovery_conf: +# restore_command: envdir /etc/wal-e.d/env wal-e wal-fetch "%f" "%p" -p 1 +import argparse +import csv +import logging +import os +import psycopg2 +import subprocess +import sys +import time + +from collections import namedtuple + +logger = logging.getLogger(__name__) + +RETRY_SLEEP_INTERVAL = 1 +si_prefixes = ['K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y'] + + +# Meaningful names to the exit codes used by WALERestore +ExitCode = type('Enum', (), { + 'SUCCESS': 0, #: Succeeded + 'RETRY_LATER': 1, #: External issue, retry later + 'FAIL': 2 #: Don't try again unless configuration changes +}) + + +# We need to know the current PG version in order to figure out the correct WAL directory name +def get_major_version(data_dir): + version_file = os.path.join(data_dir, 'PG_VERSION') + if os.path.isfile(version_file): # version file exists + try: + with open(version_file) as f: + return float(f.read()) + except Exception: + logger.exception('Failed to read PG_VERSION from %s', data_dir) + return 0.0 + + +def repr_size(n_bytes): + """ + >>> repr_size(1000) + '1000 Bytes' + >>> repr_size(8257332324597) + '7.5 TiB' + """ + if n_bytes < 1024: + return '{0} Bytes'.format(n_bytes) + i = -1 + while n_bytes > 1023: + n_bytes /= 1024.0 + i += 1 + return '{0} {1}iB'.format(round(n_bytes, 1), si_prefixes[i]) + + +def size_as_bytes(size_, prefix): + """ + >>> size_as_bytes(7.5, 'T') + 8246337208320 + """ + prefix = prefix.upper() + + assert prefix in si_prefixes + + exponent = si_prefixes.index(prefix) + 1 + + return int(size_ * (1024.0 ** exponent)) + + +WALEConfig = namedtuple( + 'WALEConfig', + [ + 'env_dir', + 'threshold_mb', + 'threshold_pct', + 'cmd', + ] +) + + +class WALERestore(object): + def __init__(self, scope, datadir, connstring, env_dir, threshold_mb, + threshold_pct, use_iam, no_master, retries): + self.scope = scope + self.master_connection = connstring + self.data_dir = datadir + self.no_master = no_master + + wale_cmd = [ + 'envdir', + env_dir, + 'wal-e', + ] + + if use_iam == 1: + wale_cmd += ['--aws-instance-profile'] + + self.wal_e = WALEConfig( + env_dir=env_dir, + threshold_mb=threshold_mb, + 
threshold_pct=threshold_pct, + cmd=wale_cmd, + ) + + self.init_error = (not os.path.exists(self.wal_e.env_dir)) + self.retries = retries + + def run(self): + """ + Creates a new replica using WAL-E + + Returns + ------- + ExitCode + 0 = Success + 1 = Error, try again + 2 = Error, don't try again + + """ + if self.init_error: + logger.error('init error: %r did not exist at initialization time', + self.wal_e.env_dir) + return ExitCode.FAIL + + try: + should_use_s3 = self.should_use_s3_to_create_replica() + if should_use_s3 is None: # Need to retry + return ExitCode.RETRY_LATER + elif should_use_s3: + return self.create_replica_with_s3() + elif not should_use_s3: + return ExitCode.FAIL + except Exception: + logger.exception("Unhandled exception when running WAL-E restore") + return ExitCode.FAIL + + def should_use_s3_to_create_replica(self): + """ determine whether it makes sense to use S3 and not pg_basebackup """ + + threshold_megabytes = self.wal_e.threshold_mb + threshold_percent = self.wal_e.threshold_pct + + try: + cmd = self.wal_e.cmd + ['backup-list', '--detail', 'LATEST'] + + logger.debug('calling %r', cmd) + wale_output = subprocess.check_output(cmd) + + reader = csv.DictReader(wale_output.decode('utf-8').splitlines(), + dialect='excel-tab') + rows = list(reader) + if not len(rows): + logger.warning('wal-e did not find any backups') + return False + + # This check might not add much, it was performed in the previous + # version of this code. since the old version rolled CSV parsing the + # check may have been part of the CSV parsing. + if len(rows) > 1: + logger.warning( + 'wal-e returned more than one row of backups: %r', + rows) + return False + + backup_info = rows[0] + except subprocess.CalledProcessError: + logger.exception("could not query wal-e latest backup") + return None + + try: + backup_size = int(backup_info['expanded_size_bytes']) + backup_start_segment = backup_info['wal_segment_backup_start'] + backup_start_offset = backup_info['wal_segment_offset_backup_start'] + except KeyError: + logger.exception("unable to get some of WALE backup parameters") + return None + + # WAL filename is XXXXXXXXYYYYYYYY000000ZZ, where X - timeline, Y - LSN logical log file, + # ZZ - 2 high digits of LSN offset. The rest of the offset is the provided decimal offset, + # that we have to convert to hex and 'prepend' to the high offset digits. 
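+        # Illustration with hypothetical values: for backup_start_segment
+        # '000000010000000200000003' and backup_start_offset '1048576', the logical
+        # log file is '00000002', the high offset byte is 0x03, and the offset being
+        # reconstructed is (0x03 << 24) + 1048576 = 0x3100000, i.e. an LSN of
+        # '00000002/3100000' (assuming hex() yields a Python 2 style long literal
+        # with a trailing 'L', which the [2:-1] slice below expects).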
+ + lsn_segment = backup_start_segment[8:16] + # first 2 characters of the result are 0x and the last one is L + lsn_offset = hex((int(backup_start_segment[16:32], 16) << 24) + int(backup_start_offset))[2:-1] + + # construct the LSN from the segment and offset + backup_start_lsn = '{0}/{1}'.format(lsn_segment, lsn_offset) + + diff_in_bytes = backup_size + attempts_no = 0 + while True: + if self.master_connection: + try: + # get the difference in bytes between the current WAL location and the backup start offset + with psycopg2.connect(self.master_connection) as con: + if con.server_version >= 100000: + wal_name = 'wal' + lsn_name = 'lsn' + else: + wal_name = 'xlog' + lsn_name = 'location' + con.autocommit = True + with con.cursor() as cur: + cur.execute(("SELECT CASE WHEN pg_catalog.pg_is_in_recovery()" + " THEN GREATEST(pg_catalog.pg_{0}_{1}_diff(COALESCE(" + "pg_last_{0}_receive_{1}(), '0/0'), %s)::bigint, " + "pg_catalog.pg_{0}_{1}_diff(split_part(left(pg_catalog.pg_last_{0}_replay_{1}()::text,-1),',',2), %s)::bigint)" + " ELSE pg_catalog.pg_{0}_{1}_diff(pg_catalog.pg_current_{0}_{1}(), %s)::bigint" + " END").format(wal_name, lsn_name), + (backup_start_lsn, backup_start_lsn, backup_start_lsn)) + + diff_in_bytes = int(cur.fetchone()[0]) + except psycopg2.Error: + logger.exception('could not determine difference with the master location') + if attempts_no < self.retries: # retry in case of a temporarily connection issue + attempts_no = attempts_no + 1 + time.sleep(RETRY_SLEEP_INTERVAL) + continue + else: + if not self.no_master: + return False # do no more retries on the outer level + logger.info("continue with base backup from S3 since master is not available") + diff_in_bytes = 0 + break + else: + # always try to use WAL-E if master connection string is not available + diff_in_bytes = 0 + break + + # if the size of the accumulated WAL segments is more than a certan percentage of the backup size + # or exceeds the pre-determined size - pg_basebackup is chosen instead. 
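+        # Hypothetical example using the defaults from main() (threshold_megabytes=10240,
+        # threshold_backup_size_percentage=30): for a 100 GiB base backup, 5 GiB of
+        # accumulated WAL passes both checks (5 GiB < 10 GiB and 5 GiB < 30 GiB), so the
+        # WAL-E backup is used; 20 GiB of accumulated WAL fails the absolute check
+        # (20 GiB > 10 GiB), so the caller falls back to pg_basebackup.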
+ is_size_thresh_ok = diff_in_bytes < int(threshold_megabytes) * 1048576 + threshold_pct_bytes = backup_size * threshold_percent / 100.0 + is_percentage_thresh_ok = float(diff_in_bytes) < int(threshold_pct_bytes) + are_thresholds_ok = is_size_thresh_ok and is_percentage_thresh_ok + + class Size(object): + def __init__(self, n_bytes, prefix=None): + self.n_bytes = n_bytes + self.prefix = prefix + + def __repr__(self): + if self.prefix is not None: + n_bytes = size_as_bytes(self.n_bytes, self.prefix) + else: + n_bytes = self.n_bytes + return repr_size(n_bytes) + + class HumanContext(object): + def __init__(self, items): + self.items = items + + def __repr__(self): + return ', '.join('{}={!r}'.format(key, value) + for key, value in self.items) + + human_context = repr(HumanContext([ + ('threshold_size', Size(threshold_megabytes, 'M')), + ('threshold_percent', threshold_percent), + ('threshold_percent_size', Size(threshold_pct_bytes)), + ('backup_size', Size(backup_size)), + ('backup_diff', Size(diff_in_bytes)), + ('is_size_thresh_ok', is_size_thresh_ok), + ('is_percentage_thresh_ok', is_percentage_thresh_ok), + ])) + + if not are_thresholds_ok: + logger.info('wal-e backup size diff is over threshold, falling back ' + 'to other means of restore: %s', human_context) + else: + logger.info('Thresholds are OK, using wal-e basebackup: %s', human_context) + return are_thresholds_ok + + def fix_subdirectory_path_if_broken(self, dirname): + # in case it is a symlink pointing to a non-existing location, remove it and create the actual directory + path = os.path.join(self.data_dir, dirname) + if not os.path.exists(path): + if os.path.islink(path): # broken xlog symlink, to remove + try: + os.remove(path) + except OSError: + logger.exception("could not remove broken %s symlink pointing to %s", + dirname, os.readlink(path)) + return False + try: + os.mkdir(path) + except OSError: + logger.exception("coud not create missing %s directory path", dirname) + return False + return True + + def create_replica_with_s3(self): + # if we're set up, restore the replica using fetch latest + try: + cmd = self.wal_e.cmd + ['backup-fetch', + '{}'.format(self.data_dir), + 'LATEST'] + logger.debug('calling: %r', cmd) + exit_code = subprocess.call(cmd) + except Exception as e: + logger.error('Error when fetching backup with WAL-E: {0}'.format(e)) + return ExitCode.RETRY_LATER + + if (exit_code == 0 and not + self.fix_subdirectory_path_if_broken('pg_xlog' if get_major_version(self.data_dir) < 10 else 'pg_wal')): + return ExitCode.FAIL + return exit_code + + +def main(): + logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO) + parser = argparse.ArgumentParser(description='Script to image replicas using WAL-E') + parser.add_argument('--scope', required=True) + parser.add_argument('--role', required=False) + parser.add_argument('--datadir', required=True) + parser.add_argument('--connstring', required=True) + parser.add_argument('--retries', type=int, default=1) + parser.add_argument('--envdir', required=True) + parser.add_argument('--threshold_megabytes', type=int, default=10240) + parser.add_argument('--threshold_backup_size_percentage', type=int, default=30) + parser.add_argument('--use_iam', type=int, default=0) + parser.add_argument('--no_master', type=int, default=0) + args = parser.parse_args() + + exit_code = None + assert args.retries >= 0 + + # Retry cloning in a loop. 
We do separate retries for the master + # connection attempt inside should_use_s3_to_create_replica, + # because we need to differentiate between the last attempt and + # the rest and make a decision when the last attempt fails on + # whether to use WAL-E or not depending on the no_master flag. + for _ in range(0, args.retries + 1): + restore = WALERestore(scope=args.scope, datadir=args.datadir, connstring=args.connstring, + env_dir=args.envdir, threshold_mb=args.threshold_megabytes, + threshold_pct=args.threshold_backup_size_percentage, use_iam=args.use_iam, + no_master=args.no_master, retries=args.retries) + exit_code = restore.run() + if not exit_code == ExitCode.RETRY_LATER: # only WAL-E failures lead to the retry + logger.debug('exit_code is %r, not retrying', exit_code) + break + time.sleep(RETRY_SLEEP_INTERVAL) + + return exit_code + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/patroni-for-openGauss/utils.py b/patroni-for-openGauss/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0bbddb4d217809b6222fb0db2dca41f35703dd5e --- /dev/null +++ b/patroni-for-openGauss/utils.py @@ -0,0 +1,534 @@ +import errno +import json.decoder as json_decoder +import logging +import os +import platform +import random +import re +import socket +import sys +import tempfile +import time + +from dateutil import tz + +from .exceptions import PatroniException +from .version import __version__ + +tzutc = tz.tzutc() + +logger = logging.getLogger(__name__) + +USER_AGENT = 'Patroni/{0} Python/{1} {2}'.format(__version__, platform.python_version(), platform.system()) +OCT_RE = re.compile(r'^[-+]?0[0-7]*') +DEC_RE = re.compile(r'^[-+]?(0|[1-9][0-9]*)') +HEX_RE = re.compile(r'^[-+]?0x[0-9a-fA-F]+') +DBL_RE = re.compile(r'^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?') + + +def deep_compare(obj1, obj2): + """ + >>> deep_compare({'1': None}, {}) + False + >>> deep_compare({'1': {}}, {'1': None}) + False + >>> deep_compare({'1': [1]}, {'1': [2]}) + False + >>> deep_compare({'1': 2}, {'1': '2'}) + True + >>> deep_compare({'1': {'2': [3, 4]}}, {'1': {'2': [3, 4]}}) + True + """ + + if set(list(obj1.keys())) != set(list(obj2.keys())): # Objects have different sets of keys + return False + + for key, value in obj1.items(): + if isinstance(value, dict): + if not (isinstance(obj2[key], dict) and deep_compare(value, obj2[key])): + return False + elif str(value) != str(obj2[key]): + return False + return True + + +def patch_config(config, data): + """recursively 'patch' `config` with `data` + :returns: `!True` if the `config` was changed""" + is_changed = False + for name, value in data.items(): + if value is None: + if config.pop(name, None) is not None: + is_changed = True + elif name in config: + if isinstance(value, dict): + if isinstance(config[name], dict): + if patch_config(config[name], value): + is_changed = True + else: + config[name] = value + is_changed = True + elif str(config[name]) != str(value): + config[name] = value + is_changed = True + else: + config[name] = value + is_changed = True + return is_changed + + +def parse_bool(value): + """ + >>> parse_bool(1) + True + >>> parse_bool('off') + False + >>> parse_bool('foo') + """ + value = str(value).lower() + if value in ('on', 'true', 'yes', '1'): + return True + if value in ('off', 'false', 'no', '0'): + return False + + +def strtol(value, strict=True): + """As most as possible close equivalent of strtol(3) function (with base=0), + used by postgres to parse parameter values. 
+ >>> strtol(0) == (0, '') + True + >>> strtol(1) == (1, '') + True + >>> strtol(9) == (9, '') + True + >>> strtol(' +0x400MB') == (1024, 'MB') + True + >>> strtol(' -070d') == (-56, 'd') + True + >>> strtol(' d ') == (None, 'd') + True + >>> strtol(' 1 d ') == (1, ' d') + True + >>> strtol('9s', False) == (9, 's') + True + >>> strtol(' s ', False) == (1, 's') + True + """ + value = str(value).strip() + for regex, base in ((HEX_RE, 16), (OCT_RE, 8), (DEC_RE, 10)): + match = regex.match(value) + if match: + end = match.end() + return int(value[:end], base), value[end:] + return (None if strict else 1), value + + +def strtod(value): + """As most as possible close equivalent of strtod(3) function used by postgres to parse parameter values. + >>> strtod(' A ') == (None, 'A') + True + """ + value = str(value).strip() + match = DBL_RE.match(value) + if match: + end = match.end() + return float(value[:end]), value[end:] + return None, value + + +def rint(value): + """ + >>> rint(0.5) == 0 + True + >>> rint(0.501) == 1 + True + >>> rint(1.5) == 2 + True + """ + + ret = round(value) + return 2.0 * round(value / 2.0) if abs(ret - value) == 0.5 else ret + + +def convert_to_base_unit(value, unit, base_unit): + convert = { + 'B': {'B': 1, 'kB': 1024, 'MB': 1024 * 1024, 'GB': 1024 * 1024 * 1024, 'TB': 1024 * 1024 * 1024 * 1024}, + 'kB': {'B': 1.0 / 1024, 'kB': 1, 'MB': 1024, 'GB': 1024 * 1024, 'TB': 1024 * 1024 * 1024}, + 'MB': {'B': 1.0 / (1024 * 1024), 'kB': 1.0 / 1024, 'MB': 1, 'GB': 1024, 'TB': 1024 * 1024}, + 'ms': {'us': 1.0 / 1000, 'ms': 1, 's': 1000, 'min': 1000 * 60, 'h': 1000 * 60 * 60, 'd': 1000 * 60 * 60 * 24}, + 's': {'us': 1.0 / (1000 * 1000), 'ms': 1.0 / 1000, 's': 1, 'min': 60, 'h': 60 * 60, 'd': 60 * 60 * 24}, + 'min': {'us': 1.0 / (1000 * 1000 * 60), 'ms': 1.0 / (1000 * 60), 's': 1.0 / 60, 'min': 1, 'h': 60, 'd': 60 * 24} + } + + round_order = { + 'TB': 'GB', 'GB': 'MB', 'MB': 'kB', 'kB': 'B', + 'd': 'h', 'h': 'min', 'min': 's', 's': 'ms', 'ms': 'us' + } + + if base_unit and base_unit not in convert: + base_value, base_unit = strtol(base_unit, False) + else: + base_value = 1 + + if base_unit in convert and unit in convert[base_unit]: + value *= convert[base_unit][unit] / float(base_value) + + if unit in round_order: + multiplier = convert[base_unit][round_order[unit]] + value = rint(value / float(multiplier)) * multiplier + + return value + + +def parse_int(value, base_unit=None): + """ + >>> parse_int('1') == 1 + True + >>> parse_int(' 0x400 MB ', '16384kB') == 64 + True + >>> parse_int('1MB', 'kB') == 1024 + True + >>> parse_int('1000 ms', 's') == 1 + True + >>> parse_int('1TB', 'GB') is None + True + >>> parse_int(0) == 0 + True + >>> parse_int('6GB', '16MB') == 384 + True + >>> parse_int('4097.4kB', 'kB') == 4097 + True + >>> parse_int('4097.5kB', 'kB') == 4098 + True + """ + + val, unit = strtol(value) + if val is None and unit.startswith('.') or unit and unit[0] in ('.', 'e', 'E'): + val, unit = strtod(value) + + if val is not None: + unit = unit.strip() + if not unit: + return int(rint(val)) + + val = convert_to_base_unit(val, unit, base_unit) + if val is not None: + return int(rint(val)) + + +def parse_real(value, base_unit=None): + """ + >>> parse_real(' +0.0005 ') == 0.0005 + True + >>> parse_real('0.0005ms', 'ms') == 0.0 + True + >>> parse_real('0.00051ms', 'ms') == 0.001 + True + """ + val, unit = strtod(value) + + if val is not None: + unit = unit.strip() + if not unit: + return val + + return convert_to_base_unit(val, unit, base_unit) + + +def compare_values(vartype, 
unit, old_value, new_value): + """ + >>> compare_values('enum', None, 'remote_write', 'REMOTE_WRITE') + True + >>> compare_values('real', None, '1e-06', 0.000001) + True + """ + + converters = { + 'bool': lambda v1, v2: parse_bool(v1), + 'integer': parse_int, + 'real': parse_real, + 'enum': lambda v1, v2: str(v1).lower(), + 'string': lambda v1, v2: str(v1) + } + + convert = converters.get(vartype) or converters['string'] + old_value = convert(old_value, None) + new_value = convert(new_value, unit) + + return old_value is not None and new_value is not None and old_value == new_value + + +def _sleep(interval): + time.sleep(interval) + + +class RetryFailedError(PatroniException): + + """Raised when retrying an operation ultimately failed, after retrying the maximum number of attempts.""" + + +class Retry(object): + + """Helper for retrying a method in the face of retry-able exceptions""" + + def __init__(self, max_tries=1, delay=0.1, backoff=2, max_jitter=0.8, max_delay=3600, + sleep_func=_sleep, deadline=None, retry_exceptions=PatroniException): + """Create a :class:`Retry` instance for retrying function calls + + :param max_tries: How many times to retry the command. -1 means infinite tries. + :param delay: Initial delay between retry attempts. + :param backoff: Backoff multiplier between retry attempts. Defaults to 2 for exponential backoff. + :param max_jitter: Additional max jitter period to wait between retry attempts to avoid slamming the server. + :param max_delay: Maximum delay in seconds, regardless of other backoff settings. Defaults to one hour. + :param retry_exceptions: single exception or tuple""" + + self.max_tries = max_tries + self.delay = delay + self.backoff = backoff + self.max_jitter = int(max_jitter * 100) + self.max_delay = float(max_delay) + self._attempts = 0 + self._cur_delay = delay + self.deadline = deadline + self._cur_stoptime = None + self.sleep_func = sleep_func + self.retry_exceptions = retry_exceptions + + def reset(self): + """Reset the attempt counter""" + self._attempts = 0 + self._cur_delay = self.delay + self._cur_stoptime = None + + def copy(self): + """Return a clone of this retry manager""" + return Retry(max_tries=self.max_tries, delay=self.delay, backoff=self.backoff, + max_jitter=self.max_jitter / 100.0, max_delay=self.max_delay, sleep_func=self.sleep_func, + deadline=self.deadline, retry_exceptions=self.retry_exceptions) + + @property + def sleeptime(self): + return self._cur_delay + (random.randint(0, self.max_jitter) / 100.0) + + def update_delay(self): + self._cur_delay = min(self._cur_delay * self.backoff, self.max_delay) + + @property + def stoptime(self): + return self._cur_stoptime + + def __call__(self, func, *args, **kwargs): + """Call a function with arguments until it completes without throwing a `retry_exceptions` + + :param func: Function to call + :param args: Positional arguments to call the function with + :params kwargs: Keyword arguments to call the function with + + The function will be called until it doesn't throw one of the retryable exceptions""" + self.reset() + + while True: + try: + if self.deadline is not None and self._cur_stoptime is None: + self._cur_stoptime = time.time() + self.deadline + return func(*args, **kwargs) + except self.retry_exceptions as e: + # Note: max_tries == -1 means infinite tries. 
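+                # For illustration: _attempts starts at 0 and is compared before being
+                # incremented, so max_tries=1 allows the initial call plus one retry
+                # (two attempts in total) before RetryFailedError is raised, while
+                # max_tries=-1 never matches here and only the deadline check below
+                # can stop the loop.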
+ if self._attempts == self.max_tries: + logger.warning('Retry got exception: %s', e) + raise RetryFailedError("Too many retry attempts") + self._attempts += 1 + sleeptime = hasattr(e, 'sleeptime') and e.sleeptime or self.sleeptime + + if self._cur_stoptime is not None and time.time() + sleeptime >= self._cur_stoptime: + logger.warning('Retry got exception: %s', e) + raise RetryFailedError("Exceeded retry deadline") + logger.debug('Retry got exception: %s', e) + self.sleep_func(sleeptime) + self.update_delay() + + +def polling_loop(timeout, interval=1): + """Returns an iterator that returns values until timeout has passed. Timeout is measured from start of iteration.""" + start_time = time.time() + iteration = 0 + end_time = start_time + timeout + while time.time() < end_time: + yield iteration + iteration += 1 + time.sleep(interval) + + +def split_host_port(value, default_port): + t = value.rsplit(':', 1) + if ':' in t[0]: + t[0] = t[0].strip('[]') + t.append(default_port) + return t[0], int(t[1]) + + +def uri(proto, netloc, path='', user=None): + host, port = netloc if isinstance(netloc, (list, tuple)) else split_host_port(netloc, 0) + if host and ':' in host and host[0] != '[' and host[-1] != ']': + host = '[{0}]'.format(host) + port = ':{0}'.format(port) if port else '' + path = '/{0}'.format(path) if path and not path.startswith('/') else path + user = '{0}@'.format(user) if user else '' + return '{0}://{1}{2}{3}{4}'.format(proto, user, host, port, path) + + +def iter_response_objects(response): + prev = '' + decoder = json_decoder.JSONDecoder() + for chunk in response.read_chunked(decode_content=False): + if isinstance(chunk, bytes): + chunk = chunk.decode('utf-8') + chunk = prev + chunk + + length = len(chunk) + idx = json_decoder.WHITESPACE.match(chunk, 0).end() + while idx < length: + try: + message, idx = decoder.raw_decode(chunk, idx) + except ValueError: # malformed or incomplete JSON, unlikely to happen + break + else: + yield message + idx = json_decoder.WHITESPACE.match(chunk, idx).end() + prev = chunk[idx:] + + +def is_standby_cluster(config): + # Check whether or not provided configuration describes a standby cluster + return isinstance(config, dict) and (config.get('host') or config.get('port') or config.get('restore_command')) + + +def cluster_as_json(cluster): + leader_name = cluster.leader.name if cluster.leader else None + xlog_location_cluster = cluster.last_leader_operation or 0 + + ret = {'members': []} + for m in cluster.members: + if m.name == leader_name: + config = cluster.config.data if cluster.config and cluster.config.modify_index else {} + role = 'standby_leader' if is_standby_cluster(config.get('standby_cluster')) else 'leader' + elif m.name in cluster.sync.members: + role = 'sync_standby' + else: + role = 'replica' + + member = {'name': m.name, 'role': role, 'state': m.data.get('state', ''), 'api_url': m.api_url} + conn_kwargs = m.conn_kwargs() + if conn_kwargs.get('host'): + member['host'] = conn_kwargs['host'] + if conn_kwargs.get('port'): + member['port'] = int(conn_kwargs['port']) + optional_attributes = ('timeline', 'pending_restart', 'scheduled_restart', 'tags') + member.update({n: m.data[n] for n in optional_attributes if n in m.data}) + + if m.name != leader_name: + xlog_location = m.data.get('xlog_location') + if xlog_location is None: + member['lag'] = 'unknown' + elif xlog_location_cluster >= xlog_location: + member['lag'] = xlog_location_cluster - xlog_location + else: + member['lag'] = 0 + + ret['members'].append(member) + + # sort members by 
name for consistency + ret['members'].sort(key=lambda m: m['name']) + if cluster.is_paused(): + ret['pause'] = True + if cluster.failover and cluster.failover.scheduled_at: + ret['scheduled_switchover'] = {'at': cluster.failover.scheduled_at.isoformat()} + if cluster.failover.leader: + ret['scheduled_switchover']['from'] = cluster.failover.leader + if cluster.failover.candidate: + ret['scheduled_switchover']['to'] = cluster.failover.candidate + return ret + + +def is_subpath(d1, d2): + real_d1 = os.path.realpath(d1) + os.path.sep + real_d2 = os.path.realpath(os.path.join(real_d1, d2)) + return os.path.commonprefix([real_d1, real_d2 + os.path.sep]) == real_d1 + + +def validate_directory(d, msg="{} {}"): + if not os.path.exists(d): + try: + os.makedirs(d) + except OSError as e: + logger.error(e) + if e.errno != errno.EEXIST: + raise PatroniException(msg.format(d, "couldn't create the directory")) + elif os.path.isdir(d): + try: + fd, tmpfile = tempfile.mkstemp(dir=d) + os.close(fd) + os.remove(tmpfile) + except OSError: + raise PatroniException(msg.format(d, "the directory is not writable")) + else: + raise PatroniException(msg.format(d, "is not a directory")) + + +def data_directory_is_empty(data_dir): + if not os.path.exists(data_dir): + return True + return all(os.name != 'nt' and (n.startswith('.') or n == 'lost+found') for n in os.listdir(data_dir)) + + +def keepalive_intvl(timeout, idle, cnt=3): + return max(1, int(float(timeout - idle) / cnt)) + + +def keepalive_socket_options(timeout, idle, cnt=3): + yield (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + + if sys.platform.startswith('linux'): + yield (socket.SOL_TCP, 18, int(timeout * 1000)) # TCP_USER_TIMEOUT + TCP_KEEPIDLE = getattr(socket, 'TCP_KEEPIDLE', None) + TCP_KEEPINTVL = getattr(socket, 'TCP_KEEPINTVL', None) + TCP_KEEPCNT = getattr(socket, 'TCP_KEEPCNT', None) + elif sys.platform.startswith('darwin'): + TCP_KEEPIDLE = 0x10 # (named "TCP_KEEPALIVE" in C) + TCP_KEEPINTVL = 0x101 + TCP_KEEPCNT = 0x102 + else: + return + + intvl = keepalive_intvl(timeout, idle, cnt) + yield (socket.IPPROTO_TCP, TCP_KEEPIDLE, idle) + yield (socket.IPPROTO_TCP, TCP_KEEPINTVL, intvl) + yield (socket.IPPROTO_TCP, TCP_KEEPCNT, cnt) + + +def enable_keepalive(sock, timeout, idle, cnt=3): + SIO_KEEPALIVE_VALS = getattr(socket, 'SIO_KEEPALIVE_VALS', None) + if SIO_KEEPALIVE_VALS is not None: # Windows + intvl = keepalive_intvl(timeout, idle, cnt) + return sock.ioctl(SIO_KEEPALIVE_VALS, (1, idle * 1000, intvl * 1000)) + + for opt in keepalive_socket_options(timeout, idle, cnt): + sock.setsockopt(*opt) + + +def find_executable(executable, path=None): + _, ext = os.path.splitext(executable) + + if (sys.platform == 'win32') and (ext != '.exe'): + executable = executable + '.exe' + + if os.path.isfile(executable): + return executable + + if path is None: + path = os.environ.get('PATH', os.defpath) + + for p in path.split(os.pathsep): + f = os.path.join(p, executable) + if os.path.isfile(f): + return f diff --git a/patroni-for-openGauss/validator.py b/patroni-for-openGauss/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..f73ff4ac2c403ff353679ca934b8ea0af3897b90 --- /dev/null +++ b/patroni-for-openGauss/validator.py @@ -0,0 +1,393 @@ +#!/usr/bin/env python3 +import os +import socket +import re +import subprocess + +from six import string_types + +from .utils import find_executable, split_host_port, data_directory_is_empty +from .dcs import dcs_modules +from .exceptions import ConfigParseError + + +def 
data_directory_empty(data_dir): + if os.path.isfile(os.path.join(data_dir, "global", "pg_control")): + return False + return data_directory_is_empty(data_dir) + + +def validate_connect_address(address): + try: + host, _ = split_host_port(address, 1) + except (AttributeError, TypeError, ValueError): + raise ConfigParseError("contains a wrong value") + if host in ["127.0.0.1", "0.0.0.0", "*", "::1", "localhost"]: + raise ConfigParseError('must not contain "127.0.0.1", "0.0.0.0", "*", "::1", "localhost"') + return True + + +def validate_host_port(host_port, listen=False, multiple_hosts=False): + try: + hosts, port = split_host_port(host_port, None) + except (ValueError, TypeError): + raise ConfigParseError("contains a wrong value") + else: + if multiple_hosts: + hosts = hosts.split(",") + else: + hosts = [hosts] + for host in hosts: + proto = socket.getaddrinfo(host, "", 0, socket.SOCK_STREAM, 0, socket.AI_PASSIVE) + s = socket.socket(proto[0][0], socket.SOCK_STREAM) + try: + if s.connect_ex((host, port)) == 0: + if listen: + raise ConfigParseError("Port {} is already in use.".format(port)) + elif not listen: + raise ConfigParseError("{} is not reachable".format(host_port)) + except socket.gaierror as e: + raise ConfigParseError(e) + finally: + s.close() + return True + + +def validate_host_port_list(value): + assert all([validate_host_port(v) for v in value]), "didn't pass the validation" + return True + + +def comma_separated_host_port(string): + return validate_host_port_list([s.strip() for s in string.split(",")]) + + +def validate_host_port_listen(host_port): + return validate_host_port(host_port, listen=True) + + +def validate_host_port_listen_multiple_hosts(host_port): + return validate_host_port(host_port, listen=True, multiple_hosts=True) + + +def is_ipv4_address(ip): + try: + socket.inet_aton(ip) + except Exception: + raise ConfigParseError("Is not a valid ipv4 address") + return True + + +def is_ipv6_address(ip): + try: + socket.inet_pton(socket.AF_INET6, ip) + except Exception: + raise ConfigParseError("Is not a valid ipv6 address") + return True + + +def get_major_version(bin_dir=None): + if not bin_dir: + binary = 'gaussdb' + else: + binary = os.path.join(bin_dir, 'gaussdb') + version = subprocess.check_output([binary, '--version']).decode() + version = re.match(r'^[^\s]+ [^\s]+ (\d+)(\.(\d+))?', version) + return '.'.join([version.group(1), version.group(3)]) if int(version.group(1)) < 10 else version.group(1) + + +def validate_data_dir(data_dir): + if not data_dir: + raise ConfigParseError("is an empty string") + elif os.path.exists(data_dir) and not os.path.isdir(data_dir): + raise ConfigParseError("is not a directory") + elif not data_directory_empty(data_dir): + if not os.path.exists(os.path.join(data_dir, "PG_VERSION")): + raise ConfigParseError("doesn't look like a valid data directory") + else: + with open(os.path.join(data_dir, "PG_VERSION"), "r") as version: + pgversion = version.read().strip() + waldir = ("pg_wal" if float(pgversion) >= 10 else "pg_xlog") + if not os.path.isdir(os.path.join(data_dir, waldir)): + raise ConfigParseError("data dir for the cluster is not empty, but doesn't contain" + " \"{}\" directory".format(waldir)) + bin_dir = schema.data.get("postgresql", {}).get("bin_dir", None) + major_version = get_major_version(bin_dir) + if pgversion != major_version: + raise ConfigParseError("data_dir directory postgresql version ({}) doesn't match with " + "'postgres --version' output ({})".format(pgversion, major_version)) + return True + + +class 
Result(object): + def __init__(self, status, error="didn't pass validation", level=0, path="", data=""): + self.status = status + self.path = path + self.data = data + self.level = level + self._error = error + if not self.status: + self.error = error + else: + self.error = None + + def __repr__(self): + return self.path + (" " + str(self.data) + " " + self._error if self.error else "") + + +class Case(object): + def __init__(self, schema): + self._schema = schema + + +class Or(object): + def __init__(self, *args): + self.args = args + + +class Optional(object): + def __init__(self, name): + self.name = name + + +class Directory(object): + def __init__(self, contains=None, contains_executable=None): + self.contains = contains + self.contains_executable = contains_executable + + def validate(self, name): + if not name: + yield Result(False, "is an empty string") + elif not os.path.exists(name): + yield Result(False, "Directory '{}' does not exist.".format(name)) + elif not os.path.isdir(name): + yield Result(False, "'{}' is not a directory.".format(name)) + else: + if self.contains: + for path in self.contains: + if not os.path.exists(os.path.join(name, path)): + yield Result(False, "'{}' does not contain '{}'".format(name, path)) + if self.contains_executable: + for program in self.contains_executable: + if not find_executable(program, name): + yield Result(False, "'{}' does not contain '{}'".format(name, program)) + + +class Schema(object): + def __init__(self, validator): + self.validator = validator + + def __call__(self, data): + for i in self.validate(data): + if not i.status: + print(i) + + def validate(self, data): + self.data = data + if isinstance(self.validator, string_types): + yield Result(isinstance(self.data, string_types), "is not a string", level=1, data=self.data) + elif issubclass(type(self.validator), type): + validator = self.validator + if self.validator == str: + validator = string_types + yield Result(isinstance(self.data, validator), + "is not {}".format(_get_type_name(self.validator)), level=1, data=self.data) + elif callable(self.validator): + if hasattr(self.validator, "expected_type"): + if not isinstance(data, self.validator.expected_type): + yield Result(False, "is not {}" + .format(_get_type_name(self.validator.expected_type)), level=1, data=self.data) + return + try: + self.validator(data) + yield Result(True, data=self.data) + except Exception as e: + yield Result(False, "didn't pass validation: {}".format(e), data=self.data) + elif isinstance(self.validator, dict): + if not len(self.validator): + yield Result(isinstance(self.data, dict), "is not a dictionary", level=1, data=self.data) + elif isinstance(self.validator, list): + if not isinstance(self.data, list): + yield Result(isinstance(self.data, list), "is not a list", level=1, data=self.data) + return + for i in self.iter(): + yield i + + def iter(self): + if isinstance(self.validator, dict): + if not isinstance(self.data, dict): + yield Result(False, "is not a dictionary.", level=1) + else: + for i in self.iter_dict(): + yield i + elif isinstance(self.validator, list): + if len(self.data) == 0: + yield Result(False, "is an empty list", data=self.data) + if len(self.validator) > 0: + for key, value in enumerate(self.data): + for v in Schema(self.validator[0]).validate(value): + yield Result(v.status, v.error, + path=(str(key) + ("." 
+ v.path if v.path else "")), level=v.level, data=value) + elif isinstance(self.validator, Directory): + for v in self.validator.validate(self.data): + yield v + elif isinstance(self.validator, Or): + for i in self.iter_or(): + yield i + + def iter_dict(self): + for key in self.validator.keys(): + for d in self._data_key(key): + if d not in self.data and not isinstance(key, Optional): + yield Result(False, "is not defined.", path=d) + elif d not in self.data and isinstance(key, Optional): + continue + else: + validator = self.validator[key] + if isinstance(key, Or) and isinstance(self.validator[key], Case): + validator = self.validator[key]._schema[d] + for v in Schema(validator).validate(self.data[d]): + yield Result(v.status, v.error, + path=(d + ("." + v.path if v.path else "")), level=v.level, data=v.data) + + def iter_or(self): + results = [] + for a in self.validator.args: + r = [] + for v in Schema(a).validate(self.data): + r.append(v) + if any([x.status for x in r]) and not all([x.status for x in r]): + results += filter(lambda x: not x.status, r) + else: + results += r + if not any([x.status for x in results]): + max_level = 3 + for v in sorted(results, key=lambda x: x.level): + if v.level > max_level: + break + max_level = v.level + yield Result(v.status, v.error, path=v.path, level=v.level, data=v.data) + + def _data_key(self, key): + if isinstance(self.data, dict) and isinstance(key, str): + yield key + elif isinstance(key, Optional): + yield key.name + elif isinstance(key, Or): + if any([i in self.data for i in key.args]): + for i in key.args: + if i in self.data: + yield i + else: + for i in key.args: + yield i + + +def _get_type_name(python_type): + return {str: 'a string', int: 'and integer', float: 'a number', bool: 'a boolean', + list: 'an array', dict: 'a dictionary', string_types: "a string"}.get( + python_type, getattr(python_type, __name__, "unknown type")) + + +def assert_(condition, message="Wrong value"): + assert condition, message + + +userattributes = {"username": "", Optional("password"): ""} +available_dcs = [m.split(".")[-1] for m in dcs_modules()] +validate_host_port_list.expected_type = list +comma_separated_host_port.expected_type = string_types +validate_connect_address.expected_type = string_types +validate_host_port_listen.expected_type = string_types +validate_host_port_listen_multiple_hosts.expected_type = string_types +validate_data_dir.expected_type = string_types +validate_etcd = { + Or("host", "hosts", "srv", "url", "proxy"): Case({ + "host": validate_host_port, + "hosts": Or(comma_separated_host_port, [validate_host_port]), + "srv": str, + "url": str, + "proxy": str}) +} + +schema = Schema({ + "name": str, + "scope": str, + "restapi": { + "listen": validate_host_port_listen, + "connect_address": validate_connect_address + }, + Optional("bootstrap"): { + "dcs": { + Optional("ttl"): int, + Optional("loop_wait"): int, + Optional("retry_timeout"): int, + Optional("maximum_lag_on_failover"): int + }, + "pg_hba": [str], + "initdb": [Or(str, dict)] + }, + Or(*available_dcs): Case({ + "consul": { + Or("host", "url"): Case({ + "host": validate_host_port, + "url": str}) + }, + "etcd": validate_etcd, + "etcd3": validate_etcd, + "exhibitor": { + "hosts": [str], + "port": lambda i: assert_(int(i) <= 65535), + Optional("pool_interval"): int + }, + "raft": { + "self_addr": validate_connect_address, + Optional("bind_addr"): validate_host_port_listen, + "partner_addrs": validate_host_port_list, + Optional("data_dir"): str, + Optional("password"): str + }, + 
"zookeeper": { + "hosts": Or(comma_separated_host_port, [validate_host_port]), + }, + "kubernetes": { + "labels": {}, + Optional("namespace"): str, + Optional("scope_label"): str, + Optional("role_label"): str, + Optional("use_endpoints"): bool, + Optional("pod_ip"): Or(is_ipv4_address, is_ipv6_address), + Optional("ports"): [{"name": str, "port": int}], + }, + }), + "postgresql": { + "listen": validate_host_port_listen_multiple_hosts, + "connect_address": validate_connect_address, + "authentication": { + "replication": userattributes, + "superuser": userattributes, + "rewind": userattributes + }, + "data_dir": validate_data_dir, + Optional("bin_dir"): Directory(contains_executable=["gs_ctl", "gs_initdb", "pg_controldata", "gs_basebackup", + "gaussdb", "gs_isready"]), + Optional("parameters"): { + Optional("unix_socket_directories"): lambda s: assert_(all([isinstance(s, string_types), len(s)])) + }, + Optional("pg_hba"): [str], + Optional("pg_ident"): [str], + Optional("pg_ctl_timeout"): int, + Optional("use_pg_rewind"): bool + }, + Optional("watchdog"): { + Optional("mode"): lambda m: assert_(m in ["off", "automatic", "required"]), + Optional("device"): str + }, + Optional("tags"): { + Optional("nofailover"): bool, + Optional("clonefrom"): bool, + Optional("noloadbalance"): bool, + Optional("replicatefrom"): str, + Optional("nosync"): bool + } +}) diff --git a/patroni-for-openGauss/version.py b/patroni-for-openGauss/version.py new file mode 100644 index 0000000000000000000000000000000000000000..668c3446ee12c61dc54e5f9cacdd3b778505118c --- /dev/null +++ b/patroni-for-openGauss/version.py @@ -0,0 +1 @@ +__version__ = '2.0.2' diff --git a/patroni-for-openGauss/watchdog/__init__.py b/patroni-for-openGauss/watchdog/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4acc855e8846465877a8862da7d88477320f6abc --- /dev/null +++ b/patroni-for-openGauss/watchdog/__init__.py @@ -0,0 +1,2 @@ +from patroni.watchdog.base import WatchdogError, Watchdog +__all__ = ['WatchdogError', 'Watchdog'] diff --git a/patroni-for-openGauss/watchdog/base.py b/patroni-for-openGauss/watchdog/base.py new file mode 100644 index 0000000000000000000000000000000000000000..5c155101ad0870ae02477adbe2afa5e3452c87a9 --- /dev/null +++ b/patroni-for-openGauss/watchdog/base.py @@ -0,0 +1,313 @@ +import abc +import logging +import platform +import six +import sys +from threading import RLock + +from patroni.exceptions import WatchdogError + +__all__ = ['WatchdogError', 'Watchdog'] + +logger = logging.getLogger(__name__) + +MODE_REQUIRED = 'required' # Will not run if a watchdog is not available +MODE_AUTOMATIC = 'automatic' # Will use a watchdog if one is available +MODE_OFF = 'off' # Will not try to use a watchdog + + +def parse_mode(mode): + if mode is False: + return MODE_OFF + mode = mode.lower() + if mode in ['require', 'required']: + return MODE_REQUIRED + elif mode in ['auto', 'automatic']: + return MODE_AUTOMATIC + else: + if mode not in ['off', 'disable', 'disabled']: + logger.warning("Watchdog mode {0} not recognized, disabling watchdog".format(mode)) + return MODE_OFF + + +def synchronized(func): + def wrapped(self, *args, **kwargs): + with self._lock: + return func(self, *args, **kwargs) + return wrapped + + +class WatchdogConfig(object): + """Helper to contain a snapshot of configuration""" + def __init__(self, config): + self.mode = parse_mode(config['watchdog'].get('mode', 'automatic')) + self.ttl = config['ttl'] + self.loop_wait = config['loop_wait'] + self.safety_margin = 
config['watchdog'].get('safety_margin', 5) + self.driver = config['watchdog'].get('driver', 'default') + self.driver_config = dict((k, v) for k, v in config['watchdog'].items() + if k not in ['mode', 'safety_margin', 'driver']) + + def __eq__(self, other): + return isinstance(other, WatchdogConfig) and \ + all(getattr(self, attr) == getattr(other, attr) for attr in + ['mode', 'ttl', 'loop_wait', 'safety_margin', 'driver', 'driver_config']) + + def __ne__(self, other): + return not self == other + + def get_impl(self): + if self.driver == 'testing': # pragma: no cover + from patroni.watchdog.linux import TestingWatchdogDevice + return TestingWatchdogDevice.from_config(self.driver_config) + elif platform.system() == 'Linux' and self.driver == 'default': + from patroni.watchdog.linux import LinuxWatchdogDevice + return LinuxWatchdogDevice.from_config(self.driver_config) + else: + return NullWatchdog() + + @property + def timeout(self): + if self.safety_margin == -1: + return int(self.ttl // 2) + else: + return self.ttl - self.safety_margin + + @property + def timing_slack(self): + return self.timeout - self.loop_wait + + +class Watchdog(object): + """Facade to dynamically manage watchdog implementations and handle config changes. + + When activation fails underlying implementation will be switched to a Null implementation. To avoid log spam + activation will only be retried when watchdog configuration is changed.""" + def __init__(self, config): + self.active_config = self.config = WatchdogConfig(config) + self._lock = RLock() + self.active = False + + if self.config.mode == MODE_OFF: + self.impl = NullWatchdog() + else: + self.impl = self.config.get_impl() + if self.config.mode == MODE_REQUIRED and self.impl.is_null: + logger.error("Configuration requires a watchdog, but watchdog is not supported on this platform.") + sys.exit(1) + + @synchronized + def reload_config(self, config): + self.config = WatchdogConfig(config) + # Turning a watchdog off can always be done immediately + if self.config.mode == MODE_OFF: + if self.active: + self._disable() + self.active_config = self.config + self.impl = NullWatchdog() + # If watchdog is not active we can apply config immediately to show any warnings early. Otherwise we need to + # delay until next time a keepalive is sent so timeout matches up with leader key update. + if not self.active: + if self.config.driver != self.active_config.driver or \ + self.config.driver_config != self.active_config.driver_config: + self.impl = self.config.get_impl() + self.active_config = self.config + + @synchronized + def activate(self): + """Activates the watchdog device with suitable timeouts. While watchdog is active keepalive needs + to be called every time loop_wait expires. + + :returns False if a safe watchdog could not be configured, but is required. + """ + self.active = True + return self._activate() + + def _activate(self): + self.active_config = self.config + + if self.config.timing_slack < 0: + logger.warning('Watchdog not supported because leader TTL {0} is less than 2x loop_wait {1}' + .format(self.config.ttl, self.config.loop_wait)) + self.impl = NullWatchdog() + + try: + self.impl.open() + actual_timeout = self._set_timeout() + except WatchdogError as e: + logger.warning("Could not activate %s: %s", self.impl.describe(), e) + self.impl = NullWatchdog() + + if self.impl.is_running and not self.impl.can_be_disabled: + logger.warning("Watchdog implementation can't be disabled." 
+ " Watchdog will trigger after Patroni loses leader key.") + + if not self.impl.is_running or actual_timeout > self.config.timeout: + if self.config.mode == MODE_REQUIRED: + if self.impl.is_null: + logger.error("Configuration requires watchdog, but watchdog could not be configured.") + else: + logger.error("Configuration requires watchdog, but a safe watchdog timeout {0} could" + " not be configured. Watchdog timeout is {1}.".format( + self.config.timeout, actual_timeout)) + return False + else: + if not self.impl.is_null: + logger.warning("Watchdog timeout {0} seconds does not ensure safe termination within {1} seconds" + .format(actual_timeout, self.config.timeout)) + + if self.is_running: + logger.info("{0} activated with {1} second timeout, timing slack {2} seconds" + .format(self.impl.describe(), actual_timeout, self.config.timing_slack)) + else: + if self.config.mode == MODE_REQUIRED: + logger.error("Configuration requires watchdog, but watchdog could not be activated") + return False + + return True + + def _set_timeout(self): + if self.impl.has_set_timeout(): + self.impl.set_timeout(self.config.timeout) + + # Safety checks for watchdog implementations that don't support configurable timeouts + actual_timeout = self.impl.get_timeout() + if self.impl.is_running and actual_timeout < self.config.loop_wait: + logger.error('loop_wait of {0} seconds is too long for watchdog {1} second timeout' + .format(self.config.loop_wait, actual_timeout)) + if self.impl.can_be_disabled: + logger.info('Disabling watchdog due to unsafe timeout.') + self.impl.close() + self.impl = NullWatchdog() + return None + return actual_timeout + + @synchronized + def disable(self): + self._disable() + self.active = False + + def _disable(self): + try: + if self.impl.is_running and not self.impl.can_be_disabled: + # Give sysadmin some extra time to clean stuff up. + self.impl.keepalive() + logger.warning("Watchdog implementation can't be disabled. System will reboot after " + "{0} seconds when watchdog times out.".format(self.impl.get_timeout())) + self.impl.close() + except WatchdogError as e: + logger.error("Error while disabling watchdog: %s", e) + + @synchronized + def keepalive(self): + try: + if self.active: + self.impl.keepalive() + # In case there are any pending configuration changes apply them now. + if self.active and self.config != self.active_config: + if self.config.mode != MODE_OFF and self.active_config.mode == MODE_OFF: + self.impl = self.config.get_impl() + self._activate() + if self.config.driver != self.active_config.driver \ + or self.config.driver_config != self.active_config.driver_config: + self._disable() + self.impl = self.config.get_impl() + self._activate() + if self.config.timeout != self.active_config.timeout: + self.impl.set_timeout(self.config.timeout) + except WatchdogError as e: + logger.error("Error while sending keepalive: %s", e) + + @property + @synchronized + def is_running(self): + return self.impl.is_running + + @property + @synchronized + def is_healthy(self): + if self.config.mode != MODE_REQUIRED: + return True + return self.config.timing_slack >= 0 and self.impl.is_healthy + + +@six.add_metaclass(abc.ABCMeta) +class WatchdogBase(object): + """A watchdog object when opened requires periodic calls to keepalive. 
+ When keepalive is not called within a timeout the system will be terminated.""" + is_null = False + + @property + def is_running(self): + """Returns True when watchdog is activated and capable of performing it's task.""" + return False + + @property + def is_healthy(self): + """Returns False when calling open() is known to fail.""" + return False + + @property + def can_be_disabled(self): + """Returns True when watchdog will be disabled by calling close(). Some watchdog devices + will keep running no matter what once activated. May raise WatchdogError if called without + calling open() first.""" + return True + + @abc.abstractmethod + def open(self): + """Open watchdog device. + + When watchdog is opened keepalive must be called. Returns nothing on success + or raises WatchdogError if the device could not be opened.""" + + @abc.abstractmethod + def close(self): + """Gracefully close watchdog device.""" + + @abc.abstractmethod + def keepalive(self): + """Resets the watchdog timer. + + Watchdog must be open when keepalive is called.""" + + @abc.abstractmethod + def get_timeout(self): + """Returns the current keepalive timeout in effect.""" + + @staticmethod + def has_set_timeout(): + """Returns True if setting a timeout is supported.""" + return False + + def set_timeout(self, timeout): + """Set the watchdog timer timeout. + + :param timeout: watchdog timeout in seconds""" + raise WatchdogError("Setting timeout is not supported on {0}".format(self.describe())) + + def describe(self): + """Human readable name for this device""" + return self.__class__.__name__ + + @classmethod + def from_config(cls, config): + return cls() + + +class NullWatchdog(WatchdogBase): + """Null implementation when watchdog is not supported.""" + is_null = True + + def open(self): + return + + def close(self): + return + + def keepalive(self): + return + + def get_timeout(self): + # A big enough number to not matter + return 1000000000 diff --git a/patroni-for-openGauss/watchdog/linux.py b/patroni-for-openGauss/watchdog/linux.py new file mode 100644 index 0000000000000000000000000000000000000000..ce80b5c1859f2951aa16eea02ea2a5be8c041729 --- /dev/null +++ b/patroni-for-openGauss/watchdog/linux.py @@ -0,0 +1,235 @@ +import collections +import ctypes +import os +import platform +from patroni.watchdog.base import WatchdogBase, WatchdogError + +# Pythonification of linux/ioctl.h +IOC_NONE = 0 +IOC_WRITE = 1 +IOC_READ = 2 + +IOC_NRBITS = 8 +IOC_TYPEBITS = 8 +IOC_SIZEBITS = 14 +IOC_DIRBITS = 2 + +# Non-generic platform special cases +machine = platform.machine() +if machine in ['mips', 'sparc', 'powerpc', 'ppc64']: # pragma: no cover + IOC_SIZEBITS = 13 + IOC_DIRBITS = 3 + IOC_NONE, IOC_WRITE, IOC_READ = 1, 2, 4 +elif machine == 'parisc': # pragma: no cover + IOC_WRITE, IOC_READ = 2, 1 + +IOC_NRSHIFT = 0 +IOC_TYPESHIFT = IOC_NRSHIFT + IOC_NRBITS +IOC_SIZESHIFT = IOC_TYPESHIFT + IOC_TYPEBITS +IOC_DIRSHIFT = IOC_SIZESHIFT + IOC_SIZEBITS + + +def IOW(type_, nr, size): + return IOC(IOC_WRITE, type_, nr, size) + + +def IOR(type_, nr, size): + return IOC(IOC_READ, type_, nr, size) + + +def IOWR(type_, nr, size): + return IOC(IOC_READ | IOC_WRITE, type_, nr, size) + + +def IOC(dir_, type_, nr, size): + return (dir_ << IOC_DIRSHIFT) \ + | (ord(type_) << IOC_TYPESHIFT) \ + | (nr << IOC_NRSHIFT) \ + | (size << IOC_SIZESHIFT) + + +# Pythonification of linux/watchdog.h + +WATCHDOG_IOCTL_BASE = 'W' + + +class watchdog_info(ctypes.Structure): + _fields_ = [ + ('options', ctypes.c_uint32), # Options the card/driver supports + 
('firmware_version', ctypes.c_uint32), # Firmware version of the card + ('identity', ctypes.c_uint8 * 32), # Identity of the board + ] + + +struct_watchdog_info_size = ctypes.sizeof(watchdog_info) +int_size = ctypes.sizeof(ctypes.c_int) + +WDIOC_GETSUPPORT = IOR(WATCHDOG_IOCTL_BASE, 0, struct_watchdog_info_size) +WDIOC_GETSTATUS = IOR(WATCHDOG_IOCTL_BASE, 1, int_size) +WDIOC_GETBOOTSTATUS = IOR(WATCHDOG_IOCTL_BASE, 2, int_size) +WDIOC_GETTEMP = IOR(WATCHDOG_IOCTL_BASE, 3, int_size) +WDIOC_SETOPTIONS = IOR(WATCHDOG_IOCTL_BASE, 4, int_size) +WDIOC_KEEPALIVE = IOR(WATCHDOG_IOCTL_BASE, 5, int_size) +WDIOC_SETTIMEOUT = IOWR(WATCHDOG_IOCTL_BASE, 6, int_size) +WDIOC_GETTIMEOUT = IOR(WATCHDOG_IOCTL_BASE, 7, int_size) +WDIOC_SETPRETIMEOUT = IOWR(WATCHDOG_IOCTL_BASE, 8, int_size) +WDIOC_GETPRETIMEOUT = IOR(WATCHDOG_IOCTL_BASE, 9, int_size) +WDIOC_GETTIMELEFT = IOR(WATCHDOG_IOCTL_BASE, 10, int_size) + + +WDIOF_UNKNOWN = -1 # Unknown flag error +WDIOS_UNKNOWN = -1 # Unknown status error + +WDIOF = { + "OVERHEAT": 0x0001, # Reset due to CPU overheat + "FANFAULT": 0x0002, # Fan failed + "EXTERN1": 0x0004, # External relay 1 + "EXTERN2": 0x0008, # External relay 2 + "POWERUNDER": 0x0010, # Power bad/power fault + "CARDRESET": 0x0020, # Card previously reset the CPU + "POWEROVER": 0x0040, # Power over voltage + "SETTIMEOUT": 0x0080, # Set timeout (in seconds) + "MAGICCLOSE": 0x0100, # Supports magic close char + "PRETIMEOUT": 0x0200, # Pretimeout (in seconds), get/set + "ALARMONLY": 0x0400, # Watchdog triggers a management or other external alarm not a reboot + "KEEPALIVEPING": 0x8000, # Keep alive ping reply +} + +WDIOS = { + "DISABLECARD": 0x0001, # Turn off the watchdog timer + "ENABLECARD": 0x0002, # Turn on the watchdog timer + "TEMPPANIC": 0x0004, # Kernel panic on temperature trip +} + +# Implementation + + +class WatchdogInfo(collections.namedtuple('WatchdogInfo', 'options,version,identity')): + """Watchdog descriptor from the kernel""" + def __getattr__(self, name): + """Convenience has_XYZ attributes for checking WDIOF bits in options""" + if name.startswith('has_') and name[4:] in WDIOF: + return bool(self.options & WDIOF[name[4:]]) + + raise AttributeError("WatchdogInfo instance has no attribute '{0}'".format(name)) + + +class LinuxWatchdogDevice(WatchdogBase): + DEFAULT_DEVICE = '/dev/watchdog' + + def __init__(self, device): + self.device = device + self._support_cache = None + self._fd = None + + @classmethod + def from_config(cls, config): + device = config.get('device', cls.DEFAULT_DEVICE) + return cls(device) + + @property + def is_running(self): + return self._fd is not None + + @property + def is_healthy(self): + return os.path.exists(self.device) and os.access(self.device, os.W_OK) + + def open(self): + try: + self._fd = os.open(self.device, os.O_WRONLY) + except OSError as e: + raise WatchdogError("Can't open watchdog device: {0}".format(e)) + + def close(self): + if self.is_running: + try: + os.write(self._fd, b'V') + os.close(self._fd) + self._fd = None + except OSError as e: + raise WatchdogError("Error while closing {0}: {1}".format(self.describe(), e)) + + @property + def can_be_disabled(self): + return self.get_support().has_MAGICCLOSE + + def _ioctl(self, func, arg): + """Runs the specified ioctl on the underlying fd. + + Raises WatchdogError if the device is closed. 
+ Raises OSError or IOError (Python 2) when the ioctl fails.""" + if self._fd is None: + raise WatchdogError("Watchdog device is closed") + if os.name != 'nt': + import fcntl + fcntl.ioctl(self._fd, func, arg, True) + + def get_support(self): + if self._support_cache is None: + info = watchdog_info() + try: + self._ioctl(WDIOC_GETSUPPORT, info) + except (WatchdogError, OSError, IOError) as e: + raise WatchdogError("Could not get information about watchdog device: {}".format(e)) + self._support_cache = WatchdogInfo(info.options, + info.firmware_version, + bytearray(info.identity).decode(errors='ignore').rstrip('\x00')) + return self._support_cache + + def describe(self): + dev_str = " at {0}".format(self.device) if self.device != self.DEFAULT_DEVICE else "" + ver_str = "" + identity = "Linux watchdog device" + if self._fd: + try: + _, version, identity = self.get_support() + ver_str = " (firmware {0})".format(version) if version else "" + except WatchdogError: + pass + + return identity + ver_str + dev_str + + def keepalive(self): + try: + os.write(self._fd, b'1') + except OSError as e: + raise WatchdogError("Could not send watchdog keepalive: {0}".format(e)) + + def has_set_timeout(self): + """Returns True if setting a timeout is supported.""" + return self.get_support().has_SETTIMEOUT + + def set_timeout(self, timeout): + timeout = int(timeout) + if not 0 < timeout < 0xFFFF: + raise WatchdogError("Invalid timeout {0}. Supported values are between 1 and 65535".format(timeout)) + try: + self._ioctl(WDIOC_SETTIMEOUT, ctypes.c_int(timeout)) + except (WatchdogError, OSError, IOError) as e: + raise WatchdogError("Could not set timeout on watchdog device: {}".format(e)) + + def get_timeout(self): + timeout = ctypes.c_int() + try: + self._ioctl(WDIOC_GETTIMEOUT, timeout) + except (WatchdogError, OSError, IOError) as e: + raise WatchdogError("Could not get timeout on watchdog device: {}".format(e)) + return timeout.value + + +class TestingWatchdogDevice(LinuxWatchdogDevice): # pragma: no cover + """Converts timeout ioctls to regular writes that can be intercepted from a named pipe.""" + timeout = 60 + + def get_support(self): + return WatchdogInfo(WDIOF['MAGICCLOSE'] | WDIOF['SETTIMEOUT'], 0, "Watchdog test harness") + + def set_timeout(self, timeout): + buf = "Ctimeout={0}\n".format(timeout).encode('utf8') + while len(buf): + buf = buf[os.write(self._fd, buf):] + self.timeout = timeout + + def get_timeout(self): + return self.timeout
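The two pieces added above fit together as follows: the Schema validator, when called on a config dict, prints one Result line per failed check, and WatchdogConfig derives the effective watchdog timeout as ttl minus safety_margin (default 5) and the timing slack as that timeout minus loop_wait. Below is a minimal sketch, assuming the modules from this diff are importable under the upstream patroni package names (patroni.validator, patroni.watchdog.base); the config values are made up for illustration.

# Illustrative sketch only; import paths and config values are assumptions.
from patroni.validator import Schema
from patroni.watchdog.base import WatchdogConfig

# Validate a small config fragment shaped like the "restapi" portion of the schema above.
restapi = Schema({"listen": str, "connect_address": str})
restapi({"listen": "0.0.0.0:8008", "connect_address": "10.0.0.1:8008"})  # valid: prints nothing
restapi({"listen": 8008})  # prints roughly: "listen 8008 is not a string" and "connect_address  is not defined."

# Watchdog timing: timeout = ttl - safety_margin, timing_slack = timeout - loop_wait.
wd = WatchdogConfig({"watchdog": {}, "ttl": 30, "loop_wait": 10})
print(wd.timeout, wd.timing_slack)  # -> 25 15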