diff --git a/integration/copy_data/.gitignore b/integration/copy_data/.gitignore new file mode 100644 index 000000000..b3b7a146e --- /dev/null +++ b/integration/copy_data/.gitignore @@ -0,0 +1 @@ +*.bak.toml diff --git a/integration/copy_data/connect.sh b/integration/copy_data/connect.sh new file mode 100755 index 000000000..e84f967fc --- /dev/null +++ b/integration/copy_data/connect.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +case "$1" in + source) + PGPASSWORD=pgdog psql -h 127.0.0.1 -p 15432 -U pgdog -d pgdog + ;; + shard_0|0) + PGPASSWORD=pgdog psql -h 127.0.0.1 -p 15433 -U pgdog -d pgdog1 + ;; + shard_1|1) + PGPASSWORD=pgdog psql -h 127.0.0.1 -p 15434 -U pgdog -d pgdog2 + ;; + *) + echo "Usage: $0 {source|shard_0|0|shard_1|1}" + exit 1 + ;; +esac diff --git a/integration/copy_data/dev.sh b/integration/copy_data/dev.sh index 6486d4711..fdcffabf1 100644 --- a/integration/copy_data/dev.sh +++ b/integration/copy_data/dev.sh @@ -1,7 +1,8 @@ #!/bin/bash set -e +trap 'kill 0' SIGINT SIGTERM SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -DEFAULT_BIN="${SCRIPT_DIR}/../../target/release/pgdog" +DEFAULT_BIN="${SCRIPT_DIR}/../../target/debug/pgdog" PGDOG_BIN=${PGDOG_BIN:-$DEFAULT_BIN} export PGUSER=pgdog diff --git a/integration/copy_data/docker-compose.yml b/integration/copy_data/docker-compose.yml new file mode 100644 index 000000000..d7cdaf10a --- /dev/null +++ b/integration/copy_data/docker-compose.yml @@ -0,0 +1,41 @@ +services: + source: + image: postgres:18 + command: postgres -c wal_level=logical + environment: + POSTGRES_USER: pgdog + POSTGRES_PASSWORD: pgdog + POSTGRES_DB: pgdog + volumes: + - ./setup.sql:/docker-entrypoint-initdb.d/setup.sql + ports: + - 15432:5432 + networks: + - postgres + + shard_0: + image: postgres:18 + command: postgres -c wal_level=logical + environment: + POSTGRES_USER: pgdog + POSTGRES_PASSWORD: pgdog + POSTGRES_DB: pgdog1 + ports: + - 15433:5432 + networks: + - postgres + + shard_1: + image: postgres:18 + 
command: postgres -c wal_level=logical + environment: + POSTGRES_USER: pgdog + POSTGRES_PASSWORD: pgdog + POSTGRES_DB: pgdog2 + ports: + - 15434:5432 + networks: + - postgres + +networks: + postgres: diff --git a/integration/copy_data/init.sql b/integration/copy_data/init.sql index 5e8dfb34f..9fdd1fb33 100644 --- a/integration/copy_data/init.sql +++ b/integration/copy_data/init.sql @@ -3,6 +3,6 @@ DROP SCHEMA IF EXISTS copy_data CASCADE; \c pgdog2 DROP SCHEMA IF EXISTS copy_data CASCADE; \c pgdog -DROP SCHEMA IF EXISTS copy_data CASCADE; +-- DROP SCHEMA IF EXISTS copy_data CASCADE; SELECT pg_drop_replication_slot(slot_name) FROM pg_replication_slots; \i setup.sql diff --git a/integration/copy_data/loader/loader.py b/integration/copy_data/loader/loader.py new file mode 100755 index 000000000..180da8eb0 --- /dev/null +++ b/integration/copy_data/loader/loader.py @@ -0,0 +1,577 @@ +#!/usr/bin/env python3 +""" +Fast data loader for copy_data schema using PostgreSQL COPY protocol. +Generates ~50-100GB of test data. 
+""" + +import argparse +import json +import random +import string +import sys +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from datetime import datetime, timedelta, timezone +from typing import Generator + +import numpy as np +import psycopg + + +def random_text(length: int) -> str: + return "".join(random.choices(string.ascii_uppercase, k=length)) + + +def random_texts_numpy(count: int, length: int) -> list[str]: + """Generate random texts using numpy for speed.""" + chars = np.random.randint(65, 91, size=(count, length), dtype=np.uint8) + return ["".join(chr(c) for c in row) for row in chars] + + +def random_timestamps(count: int, days_back: int = 730) -> list[str]: + """Generate random timestamps within the last N days.""" + now = datetime.now(timezone.utc) + offsets = np.random.uniform(0, days_back * 86400, count) + return [(now - timedelta(seconds=float(off))).isoformat() for off in offsets] + + +def format_duration(seconds: float) -> str: + if seconds < 60: + return f"{seconds:.1f}s" + elif seconds < 3600: + return f"{seconds / 60:.1f}m" + else: + return f"{seconds / 3600:.1f}h" + + +def create_schema(conninfo: str): + """Create the schema and tables.""" + with psycopg.connect(conninfo) as conn: + conn.execute("CREATE SCHEMA IF NOT EXISTS copy_data") + + # Users table with partitions + conn.execute(""" + CREATE TABLE IF NOT EXISTS copy_data.users ( + id BIGINT NOT NULL, + tenant_id BIGINT NOT NULL, + email VARCHAR NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + settings JSONB NOT NULL DEFAULT '{}'::jsonb, + PRIMARY KEY(id, tenant_id) + ) PARTITION BY HASH(tenant_id) + """) + conn.execute(""" + CREATE TABLE IF NOT EXISTS copy_data.users_0 PARTITION OF copy_data.users + FOR VALUES WITH (MODULUS 2, REMAINDER 0) + """) + conn.execute(""" + CREATE TABLE IF NOT EXISTS copy_data.users_1 PARTITION OF copy_data.users + FOR VALUES WITH (MODULUS 2, REMAINDER 1) + """) + + # Orders table + conn.execute("DROP TABLE IF 
EXISTS copy_data.order_items") + conn.execute("DROP TABLE IF EXISTS copy_data.orders") + conn.execute(""" + CREATE TABLE copy_data.orders ( + id BIGINT PRIMARY KEY, + user_id BIGINT NOT NULL, + tenant_id BIGINT NOT NULL, + amount DOUBLE PRECISION NOT NULL DEFAULT 0.0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + refunded_at TIMESTAMPTZ, + notes TEXT + ) + """) + + # Order items table + conn.execute(""" + CREATE TABLE copy_data.order_items ( + id BIGINT PRIMARY KEY, + user_id BIGINT NOT NULL, + tenant_id BIGINT NOT NULL, + order_id BIGINT NOT NULL, + product_name TEXT NOT NULL, + amount DOUBLE PRECISION NOT NULL DEFAULT 0.0, + quantity INT NOT NULL DEFAULT 1, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + refunded_at TIMESTAMPTZ + ) + """) + + # Log actions table + conn.execute("DROP TABLE IF EXISTS copy_data.log_actions") + conn.execute(""" + CREATE TABLE copy_data.log_actions ( + id BIGINT PRIMARY KEY, + tenant_id BIGINT, + user_id BIGINT, + action VARCHAR(50), + details TEXT, + ip_address VARCHAR(45), + user_agent TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ) + """) + + # With identity table + conn.execute("DROP TABLE IF EXISTS copy_data.with_identity") + conn.execute(""" + CREATE TABLE copy_data.with_identity ( + id BIGINT GENERATED ALWAYS AS IDENTITY, + tenant_id BIGINT NOT NULL, + data TEXT + ) + """) + + conn.execute("TRUNCATE copy_data.users CASCADE") + conn.commit() + print("Schema created successfully") + + +def load_users(conninfo: str, total: int, batch_size: int = 100_000) -> dict: + """Load users table using COPY.""" + start = time.time() + loaded = 0 + + themes = ["light", "dark", "auto"] + languages = ["en", "es", "fr", "de", "ja", "zh"] + timezones = ["UTC", "America/New_York", "Europe/London", "Asia/Tokyo"] + + with psycopg.connect(conninfo) as conn: + with conn.cursor() as cur: + with cur.copy( + "COPY copy_data.users (id, tenant_id, email, created_at, settings) FROM STDIN" + ) as copy: + for batch_start in range(0, total, 
batch_size): + batch_end = min(batch_start + batch_size, total) + batch_count = batch_end - batch_start + + timestamps = random_timestamps(batch_count) + bios = random_texts_numpy(batch_count, 200) + addresses = random_texts_numpy(batch_count, 100) + phones = random_texts_numpy(batch_count, 20) + metadata = random_texts_numpy(batch_count, 100) + + for i, idx in enumerate(range(batch_start + 1, batch_end + 1)): + tenant_id = ((idx - 1) % 1000) + 1 + email = f"user_{idx}_tenant_{tenant_id}@example.com" + settings = json.dumps( + { + "theme": random.choice(themes), + "notifications": random.random() > 0.5, + "preferences": { + "language": random.choice(languages), + "timezone": random.choice(timezones), + "bio": bios[i], + "address": addresses[i], + "phone": phones[i], + "metadata": metadata[i], + }, + } + ) + copy.write_row((idx, tenant_id, email, timestamps[i], settings)) + + loaded = batch_end + elapsed = time.time() - start + rate = loaded / elapsed if elapsed > 0 else 0 + print( + f" users: {loaded:,}/{total:,} ({100*loaded/total:.1f}%) - {rate:,.0f} rows/s" + ) + + conn.commit() + + elapsed = time.time() - start + return {"table": "users", "rows": loaded, "elapsed": elapsed} + + +def load_orders(conninfo: str, total: int, batch_size: int = 100_000) -> dict: + """Load orders table using COPY.""" + start = time.time() + loaded = 0 + + with psycopg.connect(conninfo) as conn: + with conn.cursor() as cur: + with cur.copy( + "COPY copy_data.orders (id, user_id, tenant_id, amount, created_at, refunded_at, notes) FROM STDIN" + ) as copy: + for batch_start in range(0, total, batch_size): + batch_end = min(batch_start + batch_size, total) + batch_count = batch_end - batch_start + + user_ids = np.random.randint(1, 5_000_001, batch_count) + tenant_ids = np.random.randint(1, 1001, batch_count) + amounts = np.round(10 + np.random.random(batch_count) * 990, 2) + timestamps = random_timestamps(batch_count) + refund_flags = np.random.random(batch_count) < 0.05 + refund_times = 
random_timestamps(batch_count, 365) + notes = random_texts_numpy(batch_count, 50) + + for i, idx in enumerate(range(batch_start + 1, batch_end + 1)): + refunded_at = refund_times[i] if refund_flags[i] else None + copy.write_row( + ( + idx, + int(user_ids[i]), + int(tenant_ids[i]), + float(amounts[i]), + timestamps[i], + refunded_at, + notes[i], + ) + ) + + loaded = batch_end + elapsed = time.time() - start + rate = loaded / elapsed if elapsed > 0 else 0 + print( + f" orders: {loaded:,}/{total:,} ({100*loaded/total:.1f}%) - {rate:,.0f} rows/s" + ) + + conn.commit() + + elapsed = time.time() - start + return {"table": "orders", "rows": loaded, "elapsed": elapsed} + + +def load_order_items( + conninfo: str, total: int, max_order_id: int, batch_size: int = 100_000 +) -> dict: + """Load order_items table using COPY.""" + start = time.time() + loaded = 0 + + with psycopg.connect(conninfo) as conn: + with conn.cursor() as cur: + with cur.copy( + "COPY copy_data.order_items (id, user_id, tenant_id, order_id, product_name, amount, quantity, created_at, refunded_at) FROM STDIN" + ) as copy: + for batch_start in range(0, total, batch_size): + batch_end = min(batch_start + batch_size, total) + batch_count = batch_end - batch_start + + user_ids = np.random.randint(1, 5_000_001, batch_count) + tenant_ids = np.random.randint(1, 1001, batch_count) + order_ids = np.random.randint(1, max_order_id + 1, batch_count) + amounts = np.round(5 + np.random.random(batch_count) * 195, 2) + quantities = np.random.randint(1, 6, batch_count) + timestamps = random_timestamps(batch_count) + refund_flags = np.random.random(batch_count) < 0.05 + refund_times = random_timestamps(batch_count, 365) + product_names = random_texts_numpy(batch_count, 30) + + for i, idx in enumerate(range(batch_start + 1, batch_end + 1)): + refunded_at = refund_times[i] if refund_flags[i] else None + copy.write_row( + ( + idx, + int(user_ids[i]), + int(tenant_ids[i]), + int(order_ids[i]), + f"Product {product_names[i]}", + 
float(amounts[i]), + int(quantities[i]), + timestamps[i], + refunded_at, + ) + ) + + loaded = batch_end + elapsed = time.time() - start + rate = loaded / elapsed if elapsed > 0 else 0 + print( + f" order_items: {loaded:,}/{total:,} ({100*loaded/total:.1f}%) - {rate:,.0f} rows/s" + ) + + conn.commit() + + elapsed = time.time() - start + return {"table": "order_items", "rows": loaded, "elapsed": elapsed} + + +def load_log_actions(conninfo: str, total: int, batch_size: int = 500_000) -> dict: + """Load log_actions table using COPY.""" + start = time.time() + loaded = 0 + + actions = [ + "login", + "logout", + "click", + "purchase", + "view", + "error", + "search", + "update", + "delete", + "create", + ] + + with psycopg.connect(conninfo) as conn: + with conn.cursor() as cur: + with cur.copy( + "COPY copy_data.log_actions (id, tenant_id, user_id, action, details, ip_address, user_agent, created_at) FROM STDIN" + ) as copy: + for batch_start in range(0, total, batch_size): + batch_end = min(batch_start + batch_size, total) + batch_count = batch_end - batch_start + + tenant_null_flags = np.random.random(batch_count) < 0.1 + tenant_ids = np.random.randint(1, 1001, batch_count) + user_ids = np.random.randint(1, 5_000_001, batch_count) + action_indices = np.random.randint(0, len(actions), batch_count) + timestamps = random_timestamps(batch_count) + details = random_texts_numpy(batch_count, 50) + user_agents = random_texts_numpy(batch_count, 30) + ip_parts = np.random.randint(0, 256, (batch_count, 4)) + + for i, idx in enumerate(range(batch_start + 1, batch_end + 1)): + tenant_id = None if tenant_null_flags[i] else int(tenant_ids[i]) + ip = f"{ip_parts[i,0]}.{ip_parts[i,1]}.{ip_parts[i,2]}.{ip_parts[i,3]}" + copy.write_row( + ( + idx, + tenant_id, + int(user_ids[i]), + actions[action_indices[i]], + details[i], + ip, + f"Mozilla/5.0 {user_agents[i]}", + timestamps[i], + ) + ) + + loaded = batch_end + elapsed = time.time() - start + rate = loaded / elapsed if elapsed > 0 else 
0 + print( + f" log_actions: {loaded:,}/{total:,} ({100*loaded/total:.1f}%) - {rate:,.0f} rows/s" + ) + + conn.commit() + + elapsed = time.time() - start + return {"table": "log_actions", "rows": loaded, "elapsed": elapsed} + + +def load_with_identity(conninfo: str, total: int, batch_size: int = 100_000) -> dict: + """Load with_identity table using COPY.""" + start = time.time() + loaded = 0 + + with psycopg.connect(conninfo) as conn: + with conn.cursor() as cur: + # For GENERATED ALWAYS AS IDENTITY, we use DEFAULT + with cur.copy( + "COPY copy_data.with_identity (tenant_id, data) FROM STDIN" + ) as copy: + for batch_start in range(0, total, batch_size): + batch_end = min(batch_start + batch_size, total) + batch_count = batch_end - batch_start + + tenant_ids = np.random.randint(1, 1001, batch_count) + data = random_texts_numpy(batch_count, 20) + + for i in range(batch_count): + copy.write_row((int(tenant_ids[i]), data[i])) + + loaded = batch_end + elapsed = time.time() - start + rate = loaded / elapsed if elapsed > 0 else 0 + print( + f" with_identity: {loaded:,}/{total:,} ({100*loaded/total:.1f}%) - {rate:,.0f} rows/s" + ) + + conn.commit() + + elapsed = time.time() - start + return {"table": "with_identity", "rows": loaded, "elapsed": elapsed} + + +def create_indexes(conninfo: str): + """Create indexes after data load.""" + print("Creating indexes...") + start = time.time() + + with psycopg.connect(conninfo) as conn: + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_orders_user_tenant ON copy_data.orders(user_id, tenant_id)" + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_order_items_order ON copy_data.order_items(order_id)" + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_log_actions_tenant ON copy_data.log_actions(tenant_id)" + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_log_actions_created ON copy_data.log_actions(created_at)" + ) + conn.commit() + + print(f"Indexes created in {format_duration(time.time() - start)}") + + +def 
analyze_tables(conninfo: str): + """Analyze tables for query planner.""" + print("Analyzing tables...") + start = time.time() + + with psycopg.connect(conninfo) as conn: + conn.execute("ANALYZE copy_data.users") + conn.execute("ANALYZE copy_data.orders") + conn.execute("ANALYZE copy_data.order_items") + conn.execute("ANALYZE copy_data.log_actions") + conn.execute("ANALYZE copy_data.with_identity") + conn.commit() + + print(f"Analyze completed in {format_duration(time.time() - start)}") + + +def show_table_sizes(conninfo: str): + """Show final table sizes.""" + with psycopg.connect(conninfo) as conn: + result = conn.execute(""" + SELECT + tablename, + pg_size_pretty(pg_total_relation_size('copy_data.' || tablename)) as total_size, + pg_size_pretty(pg_relation_size('copy_data.' || tablename)) as table_size + FROM pg_tables + WHERE schemaname = 'copy_data' + ORDER BY pg_total_relation_size('copy_data.' || tablename) DESC + """).fetchall() + + print("\nTable sizes:") + print("-" * 50) + for row in result: + print(f" {row[0]:20} {row[1]:>12} (data: {row[2]})") + + +def create_publication(conninfo: str): + """Create publication for replication.""" + with psycopg.connect(conninfo) as conn: + conn.execute("DROP PUBLICATION IF EXISTS pgdog") + conn.execute("CREATE PUBLICATION pgdog FOR TABLES IN SCHEMA copy_data") + conn.commit() + print("Publication 'pgdog' created") + + +def main(): + parser = argparse.ArgumentParser( + description="Fast data loader for copy_data schema" + ) + parser.add_argument( + "--conninfo", + default="host=localhost dbname=pgdog user=pgdog password=pgdog", + help="PostgreSQL connection string", + ) + parser.add_argument( + "--scale", + type=float, + default=1.0, + help="Scale factor (1.0 = ~50GB, 0.1 = ~5GB, 2.0 = ~100GB)", + ) + parser.add_argument( + "--parallel", + type=int, + default=4, + help="Number of parallel loaders (default: 4)", + ) + parser.add_argument( + "--skip-schema", action="store_true", help="Skip schema creation" + ) + 
parser.add_argument( + "--skip-indexes", action="store_true", help="Skip index creation" + ) + + args = parser.parse_args() + + # Calculate row counts based on scale + scale = args.scale + users_count = int(5_000_000 * scale) + orders_count = int(20_000_000 * scale) + order_items_count = int(60_000_000 * scale) + log_actions_count = int(300_000_000 * scale) + with_identity_count = int(50_000_000 * scale) + + print(f"Data loader configuration:") + print(f" Connection: {args.conninfo}") + print(f" Scale: {scale}x") + print(f" Parallel loaders: {args.parallel}") + print(f" Users: {users_count:,}") + print(f" Orders: {orders_count:,}") + print(f" Order items: {order_items_count:,}") + print(f" Log actions: {log_actions_count:,}") + print(f" With identity: {with_identity_count:,}") + print() + + total_start = time.time() + + # Create schema + if not args.skip_schema: + create_schema(args.conninfo) + + # Load tables in parallel + print("\nLoading data...") + results = [] + + # Users and orders can run in parallel + # Order items depends on orders (for valid order_ids) + # Log actions and with_identity are independent + + with ProcessPoolExecutor(max_workers=args.parallel) as executor: + futures = {} + + # Phase 1: users and orders in parallel + futures[executor.submit(load_users, args.conninfo, users_count)] = "users" + futures[executor.submit(load_orders, args.conninfo, orders_count)] = "orders" + + # Wait for orders to complete before starting order_items + for future in as_completed(futures): + result = future.result() + results.append(result) + print( + f" {result['table']} completed: {result['rows']:,} rows in {format_duration(result['elapsed'])}" + ) + + # Phase 2: order_items, log_actions, with_identity in parallel + futures = {} + futures[ + executor.submit( + load_order_items, args.conninfo, order_items_count, orders_count + ) + ] = "order_items" + futures[ + executor.submit(load_log_actions, args.conninfo, log_actions_count) + ] = "log_actions" + futures[ + 
executor.submit(load_with_identity, args.conninfo, with_identity_count) + ] = "with_identity" + + for future in as_completed(futures): + result = future.result() + results.append(result) + print( + f" {result['table']} completed: {result['rows']:,} rows in {format_duration(result['elapsed'])}" + ) + + # Create indexes + if not args.skip_indexes: + create_indexes(args.conninfo) + + # Analyze + analyze_tables(args.conninfo) + + # Show sizes + show_table_sizes(args.conninfo) + + # Create publication + create_publication(args.conninfo) + + total_elapsed = time.time() - total_start + total_rows = sum(r["rows"] for r in results) + print(f"\nTotal: {total_rows:,} rows loaded in {format_duration(total_elapsed)}") + print(f"Average rate: {total_rows / total_elapsed:,.0f} rows/s") + + +if __name__ == "__main__": + main() diff --git a/integration/copy_data/loader/requirements.txt b/integration/copy_data/loader/requirements.txt new file mode 100644 index 000000000..c7d450899 --- /dev/null +++ b/integration/copy_data/loader/requirements.txt @@ -0,0 +1,2 @@ +psycopg[binary]>=3.1 +numpy>=1.24 diff --git a/integration/copy_data/loader/run.sh b/integration/copy_data/loader/run.sh new file mode 100755 index 000000000..004721291 --- /dev/null +++ b/integration/copy_data/loader/run.sh @@ -0,0 +1,24 @@ +#!/bin/bash +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +VENV_DIR="$SCRIPT_DIR/.venv" + +# Create virtual environment if it doesn't exist +if [ ! -d "$VENV_DIR" ]; then + echo "Creating virtual environment..." + python3 -m venv "$VENV_DIR" +fi + +# Activate virtual environment +source "$VENV_DIR/bin/activate" + +# Install/upgrade dependencies +echo "Installing dependencies..." +pip install --quiet --upgrade pip +pip install --quiet -r "$SCRIPT_DIR/requirements.txt" + +# Run the loader with all arguments passed through +echo "Running loader..." 
+echo "" +python "$SCRIPT_DIR/loader.py" "$@" diff --git a/integration/copy_data/pgdog.docker.toml b/integration/copy_data/pgdog.docker.toml new file mode 100644 index 000000000..bd3fcea17 --- /dev/null +++ b/integration/copy_data/pgdog.docker.toml @@ -0,0 +1,32 @@ +[general] +resharding_copy_format = "binary" +load_schema = "on" +cutover_timeout_action = "cutover" + +[[databases]] +name = "source" +host = "127.0.0.1" +port = 15432 +database_name = "pgdog" + +[[databases]] +name = "destination" +host = "127.0.0.1" +port = 15433 +database_name = "pgdog1" +shard = 0 + +[[databases]] +name = "destination" +host = "127.0.0.1" +port = 15434 +database_name = "pgdog2" +shard = 1 + +[[sharded_tables]] +database = "destination" +column = "tenant_id" +data_type = "bigint" + +[admin] +password = "pgdog" diff --git a/integration/copy_data/pgdog.toml b/integration/copy_data/pgdog.toml index e918cd26f..36e856411 100644 --- a/integration/copy_data/pgdog.toml +++ b/integration/copy_data/pgdog.toml @@ -22,3 +22,6 @@ shard = 1 database = "destination" column = "tenant_id" data_type = "bigint" + +[admin] +password = "pgdog" diff --git a/integration/copy_data/reset.sh b/integration/copy_data/reset.sh new file mode 100644 index 000000000..f23820519 --- /dev/null +++ b/integration/copy_data/reset.sh @@ -0,0 +1,6 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +pushd $SCRIPT_DIR +cp pgdog.bak.toml pgdog.toml +cp users.bak.toml users.toml +popd diff --git a/integration/copy_data/setup.sql b/integration/copy_data/setup.sql index b2de7dcac..fa6114e5e 100644 --- a/integration/copy_data/setup.sql +++ b/integration/copy_data/setup.sql @@ -15,25 +15,6 @@ CREATE TABLE IF NOT EXISTS copy_data.users_0 PARTITION OF copy_data.users CREATE TABLE IF NOT EXISTS copy_data.users_1 PARTITION OF copy_data.users FOR VALUES WITH (MODULUS 2, REMAINDER 1); -TRUNCATE TABLE copy_data.users; - -INSERT INTO copy_data.users (id, tenant_id, email, created_at, 
settings) -SELECT - gs.id, - ((gs.id - 1) % 20) + 1 AS tenant_id, -- distribute across 20 tenants - format('user_%s_tenant_%s@example.com', gs.id, ((gs.id - 1) % 20) + 1) AS email, - NOW() - (random() * interval '365 days') AS created_at, -- random past date - jsonb_build_object( - 'theme', CASE (random() * 3)::int - WHEN 0 THEN 'light' - WHEN 1 THEN 'dark' - ELSE 'auto' - END, - 'notifications', (random() > 0.5) - ) AS settings -FROM generate_series(1, 10000) AS gs(id); - -DROP TABLE copy_data.orders; CREATE TABLE IF NOT EXISTS copy_data.orders ( id BIGSERIAL PRIMARY KEY, user_id BIGINT NOT NULL, @@ -52,10 +33,38 @@ CREATE TABLE IF NOT EXISTS copy_data.order_items ( refunded_at TIMESTAMPTZ ); --- --- Fix/define schema (safe to run if you're starting fresh) --- --- Adjust/drop statements as needed if the tables already exist. -TRUNCATE TABLE copy_data.order_items CASCADE; -TRUNCATE TABLE copy_data.orders CASCADE; +CREATE TABLE IF NOT EXISTS copy_data.log_actions( + id BIGSERIAL PRIMARY KEY, + tenant_id BIGINT, + action VARCHAR, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE copy_data.with_identity( + id BIGINT PRIMARY KEY GENERATED ALWAYS AS identity, + tenant_id BIGINT NOT NULL +); + +DROP PUBLICATION IF EXISTS pgdog; +CREATE PUBLICATION pgdog FOR TABLES IN SCHEMA copy_data; + + +INSERT INTO copy_data.users (id, tenant_id, email, created_at, settings) +SELECT + gs.id, + ((gs.id - 1) % 20) + 1 AS tenant_id, -- distribute across 20 tenants + format('user_%s_tenant_%s@example.com', gs.id, ((gs.id - 1) % 20) + 1) AS email, + NOW() - (random() * interval '365 days') AS created_at, -- random past date + jsonb_build_object( + 'theme', CASE (random() * 3)::int + WHEN 0 THEN 'light' + WHEN 1 THEN 'dark' + ELSE 'auto' + END, + 'notifications', (random() > 0.5) + ) AS settings +FROM generate_series(1, 10000) AS gs(id); + WITH u AS ( -- Pull the 10k users we inserted earlier @@ -134,13 +143,6 @@ SELECT ir.item_refunded_at FROM items_raw ir; -CREATE TABLE 
copy_data.log_actions( - id BIGSERIAL PRIMARY KEY, - tenant_id BIGINT, - action VARCHAR, - created_at TIMESTAMPTZ NOT NULL DEFAULT now() -); - INSERT INTO copy_data.log_actions (tenant_id, action) SELECT CASE WHEN random() < 0.2 THEN NULL ELSE (floor(random() * 10000) + 1)::bigint END AS tenant_id, @@ -149,13 +151,6 @@ SELECT ] AS action FROM generate_series(1, 10000); -CREATE TABLE copy_data.with_identity( - id BIGINT GENERATED ALWAYS AS identity, - tenant_id BIGINT NOT NULL -); INSERT INTO copy_data.with_identity (tenant_id) SELECT floor(random() * 10000)::bigint FROM generate_series(1, 10000); - -DROP PUBLICATION IF EXISTS pgdog; -CREATE PUBLICATION pgdog FOR TABLES IN SCHEMA copy_data; diff --git a/integration/copy_data/setup_large.sql b/integration/copy_data/setup_large.sql new file mode 100644 index 000000000..b13f99dc8 --- /dev/null +++ b/integration/copy_data/setup_large.sql @@ -0,0 +1,225 @@ +-- Large data setup for testing replication (generates ~50-100GB of data) +-- WARNING: This will take a long time to run and use significant disk space + +CREATE SCHEMA IF NOT EXISTS copy_data; + +-- Helper function to generate random text +CREATE OR REPLACE FUNCTION copy_data.random_text(len INT) RETURNS TEXT AS $$ +SELECT string_agg(chr(65 + (random() * 25)::int), '') +FROM generate_series(1, len); +$$ LANGUAGE SQL; + +-- Users table: 5 million rows with ~500 byte JSONB = ~2.5GB +CREATE TABLE IF NOT EXISTS copy_data.users ( + id BIGINT NOT NULL, + tenant_id BIGINT NOT NULL, + email VARCHAR NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + settings JSONB NOT NULL DEFAULT '{}'::jsonb, + PRIMARY KEY(id, tenant_id) +) PARTITION BY HASH(tenant_id); + +CREATE TABLE IF NOT EXISTS copy_data.users_0 PARTITION OF copy_data.users + FOR VALUES WITH (MODULUS 2, REMAINDER 0); + +CREATE TABLE IF NOT EXISTS copy_data.users_1 PARTITION OF copy_data.users + FOR VALUES WITH (MODULUS 2, REMAINDER 1); + +TRUNCATE TABLE copy_data.users; + +-- Insert users in batches of 100k to 
avoid memory issues +DO $$ +DECLARE + batch_size INT := 100000; + total_users INT := 5000000; + i INT := 0; +BEGIN + WHILE i < total_users LOOP + INSERT INTO copy_data.users (id, tenant_id, email, created_at, settings) + SELECT + gs.id, + ((gs.id - 1) % 1000) + 1 AS tenant_id, + format('user_%s_tenant_%s@example.com', gs.id, ((gs.id - 1) % 1000) + 1) AS email, + NOW() - (random() * interval '730 days') AS created_at, + jsonb_build_object( + 'theme', CASE (random() * 3)::int WHEN 0 THEN 'light' WHEN 1 THEN 'dark' ELSE 'auto' END, + 'notifications', (random() > 0.5), + 'preferences', jsonb_build_object( + 'language', (ARRAY['en', 'es', 'fr', 'de', 'ja', 'zh'])[floor(random() * 6 + 1)::int], + 'timezone', (ARRAY['UTC', 'America/New_York', 'Europe/London', 'Asia/Tokyo'])[floor(random() * 4 + 1)::int], + 'bio', copy_data.random_text(200), + 'address', copy_data.random_text(100), + 'phone', copy_data.random_text(20), + 'metadata', copy_data.random_text(100) + ) + ) AS settings + FROM generate_series(i + 1, LEAST(i + batch_size, total_users)) AS gs(id); + + i := i + batch_size; + RAISE NOTICE 'Inserted % users', i; + END LOOP; +END $$; + +-- Orders table: 20 million rows ~1.5GB +DROP TABLE IF EXISTS copy_data.order_items; +DROP TABLE IF EXISTS copy_data.orders; + +CREATE TABLE copy_data.orders ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL, + tenant_id BIGINT NOT NULL, + amount DOUBLE PRECISION NOT NULL DEFAULT 0.0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + refunded_at TIMESTAMPTZ, + notes TEXT +); + +CREATE TABLE copy_data.order_items ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL, + tenant_id BIGINT NOT NULL, + order_id BIGINT NOT NULL, + product_name TEXT NOT NULL, + amount DOUBLE PRECISION NOT NULL DEFAULT 0.0, + quantity INT NOT NULL DEFAULT 1, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + refunded_at TIMESTAMPTZ +); + +-- Insert orders in batches +DO $$ +DECLARE + batch_size INT := 100000; + total_orders INT := 20000000; + i INT := 0; 
+BEGIN + WHILE i < total_orders LOOP + INSERT INTO copy_data.orders (user_id, tenant_id, amount, created_at, refunded_at, notes) + SELECT + (floor(random() * 5000000) + 1)::bigint AS user_id, + (floor(random() * 1000) + 1)::bigint AS tenant_id, + ROUND((10 + random() * 990)::numeric, 2)::float8 AS amount, + NOW() - (random() * INTERVAL '730 days') AS created_at, + CASE WHEN random() < 0.05 THEN NOW() - (random() * INTERVAL '365 days') ELSE NULL END AS refunded_at, + copy_data.random_text(50) AS notes + FROM generate_series(1, batch_size); + + i := i + batch_size; + RAISE NOTICE 'Inserted % orders', i; + END LOOP; +END $$; + +-- Insert order_items: ~60 million rows (3 per order avg) ~6GB with product names +DO $$ +DECLARE + batch_size INT := 100000; + total_items INT := 60000000; + i INT := 0; +BEGIN + WHILE i < total_items LOOP + INSERT INTO copy_data.order_items (user_id, tenant_id, order_id, product_name, amount, quantity, created_at, refunded_at) + SELECT + (floor(random() * 5000000) + 1)::bigint AS user_id, + (floor(random() * 1000) + 1)::bigint AS tenant_id, + (floor(random() * 20000000) + 1)::bigint AS order_id, + 'Product ' || copy_data.random_text(30) AS product_name, + ROUND((5 + random() * 195)::numeric, 2)::float8 AS amount, + (floor(random() * 5) + 1)::int AS quantity, + NOW() - (random() * INTERVAL '730 days') AS created_at, + CASE WHEN random() < 0.05 THEN NOW() - (random() * INTERVAL '365 days') ELSE NULL END AS refunded_at + FROM generate_series(1, batch_size); + + i := i + batch_size; + RAISE NOTICE 'Inserted % order_items', i; + END LOOP; +END $$; + +-- Log actions: 300 million rows with larger action data ~30GB +DROP TABLE IF EXISTS copy_data.log_actions; +CREATE TABLE copy_data.log_actions( + id BIGSERIAL PRIMARY KEY, + tenant_id BIGINT, + user_id BIGINT, + action VARCHAR(50), + details TEXT, + ip_address VARCHAR(45), + user_agent TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +DO $$ +DECLARE + batch_size INT := 500000; + total_logs 
INT := 300000000; + i INT := 0; +BEGIN + WHILE i < total_logs LOOP + INSERT INTO copy_data.log_actions (tenant_id, user_id, action, details, ip_address, user_agent, created_at) + SELECT + CASE WHEN random() < 0.1 THEN NULL ELSE (floor(random() * 1000) + 1)::bigint END AS tenant_id, + (floor(random() * 5000000) + 1)::bigint AS user_id, + (ARRAY['login', 'logout', 'click', 'purchase', 'view', 'error', 'search', 'update', 'delete', 'create'])[ + floor(random() * 10 + 1)::int + ] AS action, + copy_data.random_text(50) AS details, + format('%s.%s.%s.%s', (random()*255)::int, (random()*255)::int, (random()*255)::int, (random()*255)::int) AS ip_address, + 'Mozilla/5.0 ' || copy_data.random_text(30) AS user_agent, + NOW() - (random() * INTERVAL '730 days') AS created_at + FROM generate_series(1, batch_size); + + i := i + batch_size; + RAISE NOTICE 'Inserted % log_actions', i; + END LOOP; +END $$; + +-- With identity: 50 million rows ~2GB +DROP TABLE IF EXISTS copy_data.with_identity; +CREATE TABLE copy_data.with_identity( + id BIGINT GENERATED ALWAYS AS IDENTITY, + tenant_id BIGINT NOT NULL, + data TEXT +); + +DO $$ +DECLARE + batch_size INT := 100000; + total_rows INT := 50000000; + i INT := 0; +BEGIN + WHILE i < total_rows LOOP + INSERT INTO copy_data.with_identity (tenant_id, data) + SELECT + (floor(random() * 1000) + 1)::bigint, + copy_data.random_text(20) + FROM generate_series(1, batch_size); + + i := i + batch_size; + RAISE NOTICE 'Inserted % with_identity rows', i; + END LOOP; +END $$; + +-- Create indexes after data load for better performance +CREATE INDEX IF NOT EXISTS idx_orders_user_tenant ON copy_data.orders(user_id, tenant_id); +CREATE INDEX IF NOT EXISTS idx_order_items_order ON copy_data.order_items(order_id); +CREATE INDEX IF NOT EXISTS idx_log_actions_tenant ON copy_data.log_actions(tenant_id); +CREATE INDEX IF NOT EXISTS idx_log_actions_created ON copy_data.log_actions(created_at); + +-- Analyze tables for query planner +ANALYZE copy_data.users; 
+ANALYZE copy_data.orders; +ANALYZE copy_data.order_items; +ANALYZE copy_data.log_actions; +ANALYZE copy_data.with_identity; + +-- Show table sizes +SELECT + schemaname, + tablename, + pg_size_pretty(pg_total_relation_size(schemaname || '.' || tablename)) as total_size, + pg_size_pretty(pg_relation_size(schemaname || '.' || tablename)) as table_size +FROM pg_tables +WHERE schemaname = 'copy_data' +ORDER BY pg_total_relation_size(schemaname || '.' || tablename) DESC; + +DROP PUBLICATION IF EXISTS pgdog; +CREATE PUBLICATION pgdog FOR TABLES IN SCHEMA copy_data; diff --git a/pgdog-config/src/core.rs b/pgdog-config/src/core.rs index cb7fd7b09..3d5d8ed15 100644 --- a/pgdog-config/src/core.rs +++ b/pgdog-config/src/core.rs @@ -5,6 +5,7 @@ use std::path::PathBuf; use tracing::{info, warn}; use crate::sharding::ShardedSchema; +use crate::util::random_string; use crate::{ system_catalogs, EnumeratedDatabase, Memory, OmnishardedTable, PassthoughAuth, PreparedStatements, QueryParserEngine, QueryParserLevel, ReadWriteSplit, RewriteMode, Role, @@ -499,6 +500,49 @@ impl Config { result } + + /// Swap database configs between `source` and `destination`. + /// Uses tmp pattern: source -> tmp, destination -> source, tmp -> destination. 
+ pub fn cutover(&mut self, source: &str, destination: &str) { + let tmp = format!("__tmp_{}__", random_string(12)); + + crate::swap_field!(self.databases.iter_mut(), name, source, destination, tmp); + crate::swap_field!( + self.sharded_mappings.iter_mut(), + database, + source, + destination, + tmp + ); + crate::swap_field!( + self.sharded_tables.iter_mut(), + database, + source, + destination, + tmp + ); + crate::swap_field!( + self.omnisharded_tables.iter_mut(), + database, + source, + destination, + tmp + ); + crate::swap_field!( + self.mirroring.iter_mut(), + source_db, + source, + destination, + tmp + ); + crate::swap_field!( + self.mirroring.iter_mut(), + destination_db, + source, + destination, + tmp + ); + } } #[cfg(test)] @@ -863,4 +907,276 @@ tables = ["my_table"] assert!(!db1_tables.iter().any(|t| t.name == "pg_class")); assert!(!db1_tables.iter().any(|t| t.name == "pg_attribute")); } + + #[test] + fn test_cutover_swaps_database_configs() { + let mut config = Config::default(); + config.databases = vec![ + Database { + name: "source_db".to_string(), + host: "source-host".to_string(), + port: 5432, + role: Role::Primary, + ..Default::default() + }, + Database { + name: "destination_db".to_string(), + host: "destination-host".to_string(), + port: 5433, + role: Role::Primary, + ..Default::default() + }, + ]; + + // After cutover: looking up source_db returns destination's config + config.cutover("source_db", "destination_db"); + + assert_eq!(config.databases.len(), 2); + + // source_db should now have destination's config (host, port) + let source = config + .databases + .iter() + .find(|d| d.name == "source_db") + .unwrap(); + assert_eq!( + source.host, "destination-host", + "source_db should now have destination's host after cutover" + ); + assert_eq!( + source.port, 5433, + "source_db should now have destination's port after cutover" + ); + + // destination_db should now have source's config (host, port) + let destination = config + .databases + .iter() 
+ .find(|d| d.name == "destination_db") + .unwrap(); + assert_eq!( + destination.host, "source-host", + "destination_db should now have source's host after cutover" + ); + assert_eq!( + destination.port, 5432, + "destination_db should now have source's port after cutover" + ); + } + + #[test] + fn test_cutover_visual() { + let before = r#" +[[databases]] +name = "source_db" +host = "source-host-0" +port = 5432 +role = "primary" +shard = 0 + +[[databases]] +name = "source_db" +host = "source-host-0-replica" +port = 5432 +role = "replica" +shard = 0 + +[[databases]] +name = "source_db" +host = "source-host-1" +port = 5432 +role = "primary" +shard = 1 + +[[databases]] +name = "source_db" +host = "source-host-1-replica" +port = 5432 +role = "replica" +shard = 1 + +[[databases]] +name = "destination_db" +host = "destination-host-0" +port = 5433 +role = "primary" +shard = 0 + +[[databases]] +name = "destination_db" +host = "destination-host-0-replica" +port = 5433 +role = "replica" +shard = 0 + +[[databases]] +name = "destination_db" +host = "destination-host-1" +port = 5433 +role = "primary" +shard = 1 + +[[databases]] +name = "destination_db" +host = "destination-host-1-replica" +port = 5433 +role = "replica" +shard = 1 + +[[sharded_tables]] +database = "source_db" +name = "users" +column = "id" + +[[sharded_tables]] +database = "destination_db" +name = "users" +column = "id" + +[[mirroring]] +source_db = "source_db" +destination_db = "destination_db" +"#; + + // After name swap: elements stay in place, only names change + // Original source_db entries become destination_db (keeping source's host) + // Original destination_db entries become source_db (keeping destination's host) + let expected_after = r#" +[[databases]] +name = "destination_db" +host = "source-host-0" +port = 5432 +role = "primary" +shard = 0 + +[[databases]] +name = "destination_db" +host = "source-host-0-replica" +port = 5432 +role = "replica" +shard = 0 + +[[databases]] +name = "destination_db" 
+host = "source-host-1" +port = 5432 +role = "primary" +shard = 1 + +[[databases]] +name = "destination_db" +host = "source-host-1-replica" +port = 5432 +role = "replica" +shard = 1 + +[[databases]] +name = "source_db" +host = "destination-host-0" +port = 5433 +role = "primary" +shard = 0 + +[[databases]] +name = "source_db" +host = "destination-host-0-replica" +port = 5433 +role = "replica" +shard = 0 + +[[databases]] +name = "source_db" +host = "destination-host-1" +port = 5433 +role = "primary" +shard = 1 + +[[databases]] +name = "source_db" +host = "destination-host-1-replica" +port = 5433 +role = "replica" +shard = 1 + +[[sharded_tables]] +database = "destination_db" +name = "users" +column = "id" + +[[sharded_tables]] +database = "source_db" +name = "users" +column = "id" + +[[mirroring]] +source_db = "destination_db" +destination_db = "source_db" +"#; + + let mut config: Config = toml::from_str(before).unwrap(); + config.cutover("source_db", "destination_db"); + + let expected: Config = toml::from_str(expected_after).unwrap(); + + assert_eq!(config.databases, expected.databases); + assert_eq!(config.sharded_tables, expected.sharded_tables); + assert_eq!(config.mirroring, expected.mirroring); + } + + #[test] + fn test_cutover_backup_roundtrip() { + let original_toml = r#" +[[databases]] +name = "source_db" +host = "source-host" +port = 5432 +role = "primary" +shard = 0 + +[[databases]] +name = "destination_db" +host = "destination-host" +port = 5433 +role = "primary" +shard = 0 +"#; + + // Parse original config + let original: Config = toml::from_str(original_toml).unwrap(); + + // Simulate backup: serialize original to TOML + let backup_toml = toml::to_string_pretty(&original).unwrap(); + + // Perform cutover + let mut config = original.clone(); + config.cutover("source_db", "destination_db"); + + // Serialize cutover result (what would be written to disk) + let new_toml = toml::to_string_pretty(&config).unwrap(); + + // Verify backup can be parsed back and 
matches original + let restored_backup: Config = toml::from_str(&backup_toml).unwrap(); + assert_eq!(restored_backup.databases, original.databases); + + // Verify new config can be parsed back and has swapped values + let restored_new: Config = toml::from_str(&new_toml).unwrap(); + + // After cutover: source_db should have destination's host + let source = restored_new + .databases + .iter() + .find(|d| d.name == "source_db") + .unwrap(); + assert_eq!(source.host, "destination-host"); + assert_eq!(source.port, 5433); + + // After cutover: destination_db should have source's host + let dest = restored_new + .databases + .iter() + .find(|d| d.name == "destination_db") + .unwrap(); + assert_eq!(dest.host, "source-host"); + assert_eq!(dest.port, 5432); + } } diff --git a/pgdog-config/src/general.rs b/pgdog-config/src/general.rs index 8630232f3..b8f1ba823 100644 --- a/pgdog-config/src/general.rs +++ b/pgdog-config/src/general.rs @@ -5,7 +5,10 @@ use std::path::PathBuf; use std::time::Duration; use crate::pooling::ConnectionRecovery; -use crate::{CopyFormat, LoadSchema, QueryParserEngine, QueryParserLevel, SystemCatalogsBehavior}; +use crate::{ + CopyFormat, CutoverTimeoutAction, LoadSchema, QueryParserEngine, QueryParserLevel, + SystemCatalogsBehavior, +}; use super::auth::{AuthType, PassthoughAuth}; use super::database::{LoadBalancingStrategy, ReadWriteSplit, ReadWriteStrategy}; @@ -212,6 +215,24 @@ pub struct General { /// Load database schema. #[serde(default = "General::load_schema")] pub load_schema: LoadSchema, + /// Cutover maintenance threshold. + #[serde(default = "General::cutover_traffic_stop_threshold")] + pub cutover_traffic_stop_threshold: u64, + /// Cutover lag threshold. + #[serde(default = "General::cutover_replication_lag_threshold")] + pub cutover_replication_lag_threshold: u64, + /// Cutover last transaction delay. 
+ #[serde(default = "General::cutover_last_transaction_delay")] + pub cutover_last_transaction_delay: u64, + /// Cutover timeout: how long to wait before doing a cutover anyway. + #[serde(default = "General::cutover_timeout")] + pub cutover_timeout: u64, + /// Cutover abort timeout: if cutover takes longer than this, abort. + #[serde(default = "General::cutover_timeout_action")] + pub cutover_timeout_action: CutoverTimeoutAction, + /// Cutover save config to disk. + #[serde(default)] + pub cutover_save_config: bool, } impl Default for General { @@ -286,6 +307,12 @@ impl Default for General { resharding_copy_format: CopyFormat::default(), reload_schema_on_ddl: Self::reload_schema_on_ddl(), load_schema: Self::load_schema(), + cutover_replication_lag_threshold: Self::cutover_replication_lag_threshold(), + cutover_traffic_stop_threshold: Self::cutover_traffic_stop_threshold(), + cutover_last_transaction_delay: Self::cutover_last_transaction_delay(), + cutover_timeout: Self::cutover_timeout(), + cutover_timeout_action: Self::cutover_timeout_action(), + cutover_save_config: bool::default(), } } } @@ -375,6 +402,29 @@ impl General { ) } + fn cutover_replication_lag_threshold() -> u64 { + Self::env_or_default("PGDOG_CUTOVER_REPLICATION_LAG_THRESHOLD", 1_000) + // 1KB + } + + fn cutover_traffic_stop_threshold() -> u64 { + Self::env_or_default("PGDOG_CUTOVER_TRAFFIC_STOP_THRESHOLD", 1_000_000) + // 1MB + } + + fn cutover_last_transaction_delay() -> u64 { + Self::env_or_default("PGDOG_CUTOVER_LAST_TRANSACTION_DELAY", 1_000) // 1 second + } + + fn cutover_timeout() -> u64 { + Self::env_or_default("PGDOG_CUTOVER_TIMEOUT", 30_000) + // 30 seconds + } + + fn cutover_timeout_action() -> CutoverTimeoutAction { + Self::env_enum_or_default("PGDOG_CUTOVER_TIMEOUT_ACTION") + } + fn rollback_timeout() -> u64 { Self::env_or_default("PGDOG_ROLLBACK_TIMEOUT", 5_000) } diff --git a/pgdog-config/src/sharding.rs b/pgdog-config/src/sharding.rs index 0d47d5d9b..866184e0a 100644 --- 
a/pgdog-config/src/sharding.rs +++ b/pgdog-config/src/sharding.rs @@ -391,6 +391,25 @@ impl FromStr for LoadSchema { } } +#[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq, Eq, Hash, Default)] +#[serde(rename_all = "snake_case", deny_unknown_fields)] +pub enum CutoverTimeoutAction { + #[default] + Abort, + Cutover, +} + +impl FromStr for CutoverTimeoutAction { + type Err = (); + fn from_str(s: &str) -> Result { + Ok(match s.to_lowercase().as_str() { + "abort" => Self::Abort, + "cutover" => Self::Cutover, + _ => return Err(()), + }) + } +} + #[cfg(test)] mod test { use super::*; diff --git a/pgdog-config/src/users.rs b/pgdog-config/src/users.rs index 9a55fe145..ab36f5af0 100644 --- a/pgdog-config/src/users.rs +++ b/pgdog-config/src/users.rs @@ -67,6 +67,14 @@ impl Users { } } } + + /// Swap user database references between source and destination. + /// Users on source become users on destination, and vice versa. + pub fn cutover(&mut self, source: &str, destination: &str) { + let tmp = format!("__tmp_{}__", random_string(12)); + + crate::swap_field!(self.users.iter_mut(), database, source, destination, tmp); + } } /// User allowed to connect to pgDog. 
@@ -193,3 +201,56 @@ fn admin_password() -> String { format!("_pgdog_{}", pw) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cutover_swaps_user_database_references() { + let mut users = Users { + users: vec![ + User::new("alice", "pass1", "source_db"), + User::new("bob", "pass2", "source_db"), + User::new("alice", "pass3", "destination_db"), + User::new("bob", "pass4", "destination_db"), + ], + ..Default::default() + }; + + // cutover swaps user database references + users.cutover("source_db", "destination_db"); + + assert_eq!(users.users.len(), 4); + + // Users that were on source_db should now be on destination_db + let alice_dest = users + .users + .iter() + .find(|u| u.name == "alice" && u.database == "destination_db") + .unwrap(); + assert_eq!(alice_dest.password(), "pass1"); + + let bob_dest = users + .users + .iter() + .find(|u| u.name == "bob" && u.database == "destination_db") + .unwrap(); + assert_eq!(bob_dest.password(), "pass2"); + + // Users that were on destination_db should now be on source_db + let alice_source = users + .users + .iter() + .find(|u| u.name == "alice" && u.database == "source_db") + .unwrap(); + assert_eq!(alice_source.password(), "pass3"); + + let bob_source = users + .users + .iter() + .find(|u| u.name == "bob" && u.database == "source_db") + .unwrap(); + assert_eq!(bob_source.password(), "pass4"); + } +} diff --git a/pgdog-config/src/util.rs b/pgdog-config/src/util.rs index a7f3c61ec..64c103418 100644 --- a/pgdog-config/src/util.rs +++ b/pgdog-config/src/util.rs @@ -52,3 +52,25 @@ pub fn random_string(n: usize) -> String { .map(char::from) .collect() } + +/// Swap field values using tmp pattern: source -> tmp, dest -> source, tmp -> dest. +#[macro_export] +macro_rules! 
swap_field { + ($iter:expr, $field:ident, $source:expr, $destination:expr, $tmp:expr) => { + $iter.for_each(|item| { + if item.$field == $source { + item.$field = $tmp.clone(); + } + }); + $iter.for_each(|item| { + if item.$field == $destination { + item.$field = $source.to_owned(); + } + }); + $iter.for_each(|item| { + if item.$field == $tmp { + item.$field = $destination.to_owned(); + } + }); + }; +} diff --git a/pgdog-plugin/src/bindings.rs b/pgdog-plugin/src/bindings.rs index 6f47703df..561d24e5b 100644 --- a/pgdog-plugin/src/bindings.rs +++ b/pgdog-plugin/src/bindings.rs @@ -1,338 +1,213 @@ /* automatically generated by rust-bindgen 0.71.1 */ +pub const _STDINT_H: u32 = 1; +pub const _FEATURES_H: u32 = 1; +pub const _DEFAULT_SOURCE: u32 = 1; +pub const __GLIBC_USE_ISOC2Y: u32 = 0; +pub const __GLIBC_USE_ISOC23: u32 = 0; +pub const __USE_ISOC11: u32 = 1; +pub const __USE_ISOC99: u32 = 1; +pub const __USE_ISOC95: u32 = 1; +pub const __USE_POSIX_IMPLICITLY: u32 = 1; +pub const _POSIX_SOURCE: u32 = 1; +pub const _POSIX_C_SOURCE: u32 = 200809; +pub const __USE_POSIX: u32 = 1; +pub const __USE_POSIX2: u32 = 1; +pub const __USE_POSIX199309: u32 = 1; +pub const __USE_POSIX199506: u32 = 1; +pub const __USE_XOPEN2K: u32 = 1; +pub const __USE_XOPEN2K8: u32 = 1; +pub const _ATFILE_SOURCE: u32 = 1; pub const __WORDSIZE: u32 = 64; -pub const __has_safe_buffers: u32 = 1; -pub const __DARWIN_ONLY_64_BIT_INO_T: u32 = 1; -pub const __DARWIN_ONLY_UNIX_CONFORMANCE: u32 = 1; -pub const __DARWIN_ONLY_VERS_1050: u32 = 1; -pub const __DARWIN_UNIX03: u32 = 1; -pub const __DARWIN_64_BIT_INO_T: u32 = 1; -pub const __DARWIN_VERS_1050: u32 = 1; -pub const __DARWIN_NON_CANCELABLE: u32 = 0; -pub const __DARWIN_SUF_EXTSN: &[u8; 14] = b"$DARWIN_EXTSN\0"; -pub const __DARWIN_C_ANSI: u32 = 4096; -pub const __DARWIN_C_FULL: u32 = 900000; -pub const __DARWIN_C_LEVEL: u32 = 900000; -pub const __STDC_WANT_LIB_EXT1__: u32 = 1; -pub const __DARWIN_NO_LONG_LONG: u32 = 0; -pub const 
_DARWIN_FEATURE_64_BIT_INODE: u32 = 1; -pub const _DARWIN_FEATURE_ONLY_64_BIT_INODE: u32 = 1; -pub const _DARWIN_FEATURE_ONLY_VERS_1050: u32 = 1; -pub const _DARWIN_FEATURE_ONLY_UNIX_CONFORMANCE: u32 = 1; -pub const _DARWIN_FEATURE_UNIX_CONFORMANCE: u32 = 3; -pub const __has_ptrcheck: u32 = 0; -pub const USE_CLANG_TYPES: u32 = 0; -pub const __PTHREAD_SIZE__: u32 = 8176; -pub const __PTHREAD_ATTR_SIZE__: u32 = 56; -pub const __PTHREAD_MUTEXATTR_SIZE__: u32 = 8; -pub const __PTHREAD_MUTEX_SIZE__: u32 = 56; -pub const __PTHREAD_CONDATTR_SIZE__: u32 = 8; -pub const __PTHREAD_COND_SIZE__: u32 = 40; -pub const __PTHREAD_ONCE_SIZE__: u32 = 8; -pub const __PTHREAD_RWLOCK_SIZE__: u32 = 192; -pub const __PTHREAD_RWLOCKATTR_SIZE__: u32 = 16; -pub const INT8_MAX: u32 = 127; -pub const INT16_MAX: u32 = 32767; -pub const INT32_MAX: u32 = 2147483647; -pub const INT64_MAX: u64 = 9223372036854775807; +pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1; +pub const __SYSCALL_WORDSIZE: u32 = 64; +pub const __TIMESIZE: u32 = 64; +pub const __USE_TIME_BITS64: u32 = 1; +pub const __USE_MISC: u32 = 1; +pub const __USE_ATFILE: u32 = 1; +pub const __USE_FORTIFY_LEVEL: u32 = 0; +pub const __GLIBC_USE_DEPRECATED_GETS: u32 = 0; +pub const __GLIBC_USE_DEPRECATED_SCANF: u32 = 0; +pub const __GLIBC_USE_C23_STRTOL: u32 = 0; +pub const _STDC_PREDEF_H: u32 = 1; +pub const __STDC_IEC_559__: u32 = 1; +pub const __STDC_IEC_60559_BFP__: u32 = 201404; +pub const __STDC_IEC_559_COMPLEX__: u32 = 1; +pub const __STDC_IEC_60559_COMPLEX__: u32 = 201404; +pub const __STDC_ISO_10646__: u32 = 201706; +pub const __GNU_LIBRARY__: u32 = 6; +pub const __GLIBC__: u32 = 2; +pub const __GLIBC_MINOR__: u32 = 42; +pub const _SYS_CDEFS_H: u32 = 1; +pub const __glibc_c99_flexarr_available: u32 = 1; +pub const __LDOUBLE_REDIRECTS_TO_FLOAT128_ABI: u32 = 0; +pub const __HAVE_GENERIC_SELECTION: u32 = 1; +pub const __GLIBC_USE_LIB_EXT2: u32 = 0; +pub const __GLIBC_USE_IEC_60559_BFP_EXT: u32 = 0; +pub const 
__GLIBC_USE_IEC_60559_BFP_EXT_C23: u32 = 0; +pub const __GLIBC_USE_IEC_60559_EXT: u32 = 0; +pub const __GLIBC_USE_IEC_60559_FUNCS_EXT: u32 = 0; +pub const __GLIBC_USE_IEC_60559_FUNCS_EXT_C23: u32 = 0; +pub const __GLIBC_USE_IEC_60559_TYPES_EXT: u32 = 0; +pub const _BITS_TYPES_H: u32 = 1; +pub const _BITS_TYPESIZES_H: u32 = 1; +pub const __OFF_T_MATCHES_OFF64_T: u32 = 1; +pub const __INO_T_MATCHES_INO64_T: u32 = 1; +pub const __RLIM_T_MATCHES_RLIM64_T: u32 = 1; +pub const __STATFS_MATCHES_STATFS64: u32 = 1; +pub const __KERNEL_OLD_TIMEVAL_MATCHES_TIMEVAL64: u32 = 1; +pub const __FD_SETSIZE: u32 = 1024; +pub const _BITS_TIME64_H: u32 = 1; +pub const _BITS_WCHAR_H: u32 = 1; +pub const _BITS_STDINT_INTN_H: u32 = 1; +pub const _BITS_STDINT_UINTN_H: u32 = 1; +pub const _BITS_STDINT_LEAST_H: u32 = 1; pub const INT8_MIN: i32 = -128; pub const INT16_MIN: i32 = -32768; pub const INT32_MIN: i32 = -2147483648; -pub const INT64_MIN: i64 = -9223372036854775808; +pub const INT8_MAX: u32 = 127; +pub const INT16_MAX: u32 = 32767; +pub const INT32_MAX: u32 = 2147483647; pub const UINT8_MAX: u32 = 255; pub const UINT16_MAX: u32 = 65535; pub const UINT32_MAX: u32 = 4294967295; -pub const UINT64_MAX: i32 = -1; pub const INT_LEAST8_MIN: i32 = -128; pub const INT_LEAST16_MIN: i32 = -32768; pub const INT_LEAST32_MIN: i32 = -2147483648; -pub const INT_LEAST64_MIN: i64 = -9223372036854775808; pub const INT_LEAST8_MAX: u32 = 127; pub const INT_LEAST16_MAX: u32 = 32767; pub const INT_LEAST32_MAX: u32 = 2147483647; -pub const INT_LEAST64_MAX: u64 = 9223372036854775807; pub const UINT_LEAST8_MAX: u32 = 255; pub const UINT_LEAST16_MAX: u32 = 65535; pub const UINT_LEAST32_MAX: u32 = 4294967295; -pub const UINT_LEAST64_MAX: i32 = -1; pub const INT_FAST8_MIN: i32 = -128; -pub const INT_FAST16_MIN: i32 = -32768; -pub const INT_FAST32_MIN: i32 = -2147483648; -pub const INT_FAST64_MIN: i64 = -9223372036854775808; +pub const INT_FAST16_MIN: i64 = -9223372036854775808; +pub const INT_FAST32_MIN: i64 = 
-9223372036854775808; pub const INT_FAST8_MAX: u32 = 127; -pub const INT_FAST16_MAX: u32 = 32767; -pub const INT_FAST32_MAX: u32 = 2147483647; -pub const INT_FAST64_MAX: u64 = 9223372036854775807; +pub const INT_FAST16_MAX: u64 = 9223372036854775807; +pub const INT_FAST32_MAX: u64 = 9223372036854775807; pub const UINT_FAST8_MAX: u32 = 255; -pub const UINT_FAST16_MAX: u32 = 65535; -pub const UINT_FAST32_MAX: u32 = 4294967295; -pub const UINT_FAST64_MAX: i32 = -1; -pub const INTPTR_MAX: u64 = 9223372036854775807; +pub const UINT_FAST16_MAX: i32 = -1; +pub const UINT_FAST32_MAX: i32 = -1; pub const INTPTR_MIN: i64 = -9223372036854775808; +pub const INTPTR_MAX: u64 = 9223372036854775807; pub const UINTPTR_MAX: i32 = -1; -pub const SIZE_MAX: i32 = -1; -pub const RSIZE_MAX: i32 = -1; -pub const WINT_MIN: i32 = -2147483648; -pub const WINT_MAX: u32 = 2147483647; +pub const PTRDIFF_MIN: i64 = -9223372036854775808; +pub const PTRDIFF_MAX: u64 = 9223372036854775807; pub const SIG_ATOMIC_MIN: i32 = -2147483648; pub const SIG_ATOMIC_MAX: u32 = 2147483647; +pub const SIZE_MAX: i32 = -1; +pub const WINT_MIN: u32 = 0; +pub const WINT_MAX: u32 = 4294967295; pub type wchar_t = ::std::os::raw::c_int; -pub type max_align_t = f64; -pub type int_least8_t = i8; -pub type int_least16_t = i16; -pub type int_least32_t = i32; -pub type int_least64_t = i64; -pub type uint_least8_t = u8; -pub type uint_least16_t = u16; -pub type uint_least32_t = u32; -pub type uint_least64_t = u64; -pub type int_fast8_t = i8; -pub type int_fast16_t = i16; -pub type int_fast32_t = i32; -pub type int_fast64_t = i64; -pub type uint_fast8_t = u8; -pub type uint_fast16_t = u16; -pub type uint_fast32_t = u32; -pub type uint_fast64_t = u64; -pub type __int8_t = ::std::os::raw::c_schar; -pub type __uint8_t = ::std::os::raw::c_uchar; -pub type __int16_t = ::std::os::raw::c_short; -pub type __uint16_t = ::std::os::raw::c_ushort; -pub type __int32_t = ::std::os::raw::c_int; -pub type __uint32_t = ::std::os::raw::c_uint; 
-pub type __int64_t = ::std::os::raw::c_longlong; -pub type __uint64_t = ::std::os::raw::c_ulonglong; -pub type __darwin_intptr_t = ::std::os::raw::c_long; -pub type __darwin_natural_t = ::std::os::raw::c_uint; -pub type __darwin_ct_rune_t = ::std::os::raw::c_int; -#[repr(C)] -#[derive(Copy, Clone)] -pub union __mbstate_t { - pub __mbstate8: [::std::os::raw::c_char; 128usize], - pub _mbstateL: ::std::os::raw::c_longlong, -} -#[allow(clippy::unnecessary_operation, clippy::identity_op)] -const _: () = { - ["Size of __mbstate_t"][::std::mem::size_of::<__mbstate_t>() - 128usize]; - ["Alignment of __mbstate_t"][::std::mem::align_of::<__mbstate_t>() - 8usize]; - ["Offset of field: __mbstate_t::__mbstate8"] - [::std::mem::offset_of!(__mbstate_t, __mbstate8) - 0usize]; - ["Offset of field: __mbstate_t::_mbstateL"] - [::std::mem::offset_of!(__mbstate_t, _mbstateL) - 0usize]; -}; -pub type __darwin_mbstate_t = __mbstate_t; -pub type __darwin_ptrdiff_t = ::std::os::raw::c_long; -pub type __darwin_size_t = ::std::os::raw::c_ulong; -pub type __darwin_va_list = __builtin_va_list; -pub type __darwin_wchar_t = ::std::os::raw::c_int; -pub type __darwin_rune_t = __darwin_wchar_t; -pub type __darwin_wint_t = ::std::os::raw::c_int; -pub type __darwin_clock_t = ::std::os::raw::c_ulong; -pub type __darwin_socklen_t = __uint32_t; -pub type __darwin_ssize_t = ::std::os::raw::c_long; -pub type __darwin_time_t = ::std::os::raw::c_long; -pub type __darwin_blkcnt_t = __int64_t; -pub type __darwin_blksize_t = __int32_t; -pub type __darwin_dev_t = __int32_t; -pub type __darwin_fsblkcnt_t = ::std::os::raw::c_uint; -pub type __darwin_fsfilcnt_t = ::std::os::raw::c_uint; -pub type __darwin_gid_t = __uint32_t; -pub type __darwin_id_t = __uint32_t; -pub type __darwin_ino64_t = __uint64_t; -pub type __darwin_ino_t = __darwin_ino64_t; -pub type __darwin_mach_port_name_t = __darwin_natural_t; -pub type __darwin_mach_port_t = __darwin_mach_port_name_t; -pub type __darwin_mode_t = __uint16_t; -pub type 
__darwin_off_t = __int64_t; -pub type __darwin_pid_t = __int32_t; -pub type __darwin_sigset_t = __uint32_t; -pub type __darwin_suseconds_t = __int32_t; -pub type __darwin_uid_t = __uint32_t; -pub type __darwin_useconds_t = __uint32_t; -pub type __darwin_uuid_t = [::std::os::raw::c_uchar; 16usize]; -pub type __darwin_uuid_string_t = [::std::os::raw::c_char; 37usize]; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct __darwin_pthread_handler_rec { - pub __routine: ::std::option::Option, - pub __arg: *mut ::std::os::raw::c_void, - pub __next: *mut __darwin_pthread_handler_rec, -} -#[allow(clippy::unnecessary_operation, clippy::identity_op)] -const _: () = { - ["Size of __darwin_pthread_handler_rec"] - [::std::mem::size_of::<__darwin_pthread_handler_rec>() - 24usize]; - ["Alignment of __darwin_pthread_handler_rec"] - [::std::mem::align_of::<__darwin_pthread_handler_rec>() - 8usize]; - ["Offset of field: __darwin_pthread_handler_rec::__routine"] - [::std::mem::offset_of!(__darwin_pthread_handler_rec, __routine) - 0usize]; - ["Offset of field: __darwin_pthread_handler_rec::__arg"] - [::std::mem::offset_of!(__darwin_pthread_handler_rec, __arg) - 8usize]; - ["Offset of field: __darwin_pthread_handler_rec::__next"] - [::std::mem::offset_of!(__darwin_pthread_handler_rec, __next) - 16usize]; -}; #[repr(C)] +#[repr(align(16))] #[derive(Debug, Copy, Clone)] -pub struct _opaque_pthread_attr_t { - pub __sig: ::std::os::raw::c_long, - pub __opaque: [::std::os::raw::c_char; 56usize], +pub struct max_align_t { + pub __clang_max_align_nonce1: ::std::os::raw::c_longlong, + pub __bindgen_padding_0: u64, + pub __clang_max_align_nonce2: u128, } #[allow(clippy::unnecessary_operation, clippy::identity_op)] const _: () = { - ["Size of _opaque_pthread_attr_t"][::std::mem::size_of::<_opaque_pthread_attr_t>() - 64usize]; - ["Alignment of _opaque_pthread_attr_t"] - [::std::mem::align_of::<_opaque_pthread_attr_t>() - 8usize]; - ["Offset of field: _opaque_pthread_attr_t::__sig"] - 
[::std::mem::offset_of!(_opaque_pthread_attr_t, __sig) - 0usize]; - ["Offset of field: _opaque_pthread_attr_t::__opaque"] - [::std::mem::offset_of!(_opaque_pthread_attr_t, __opaque) - 8usize]; -}; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct _opaque_pthread_cond_t { - pub __sig: ::std::os::raw::c_long, - pub __opaque: [::std::os::raw::c_char; 40usize], -} -#[allow(clippy::unnecessary_operation, clippy::identity_op)] -const _: () = { - ["Size of _opaque_pthread_cond_t"][::std::mem::size_of::<_opaque_pthread_cond_t>() - 48usize]; - ["Alignment of _opaque_pthread_cond_t"] - [::std::mem::align_of::<_opaque_pthread_cond_t>() - 8usize]; - ["Offset of field: _opaque_pthread_cond_t::__sig"] - [::std::mem::offset_of!(_opaque_pthread_cond_t, __sig) - 0usize]; - ["Offset of field: _opaque_pthread_cond_t::__opaque"] - [::std::mem::offset_of!(_opaque_pthread_cond_t, __opaque) - 8usize]; -}; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct _opaque_pthread_condattr_t { - pub __sig: ::std::os::raw::c_long, - pub __opaque: [::std::os::raw::c_char; 8usize], -} -#[allow(clippy::unnecessary_operation, clippy::identity_op)] -const _: () = { - ["Size of _opaque_pthread_condattr_t"] - [::std::mem::size_of::<_opaque_pthread_condattr_t>() - 16usize]; - ["Alignment of _opaque_pthread_condattr_t"] - [::std::mem::align_of::<_opaque_pthread_condattr_t>() - 8usize]; - ["Offset of field: _opaque_pthread_condattr_t::__sig"] - [::std::mem::offset_of!(_opaque_pthread_condattr_t, __sig) - 0usize]; - ["Offset of field: _opaque_pthread_condattr_t::__opaque"] - [::std::mem::offset_of!(_opaque_pthread_condattr_t, __opaque) - 8usize]; -}; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct _opaque_pthread_mutex_t { - pub __sig: ::std::os::raw::c_long, - pub __opaque: [::std::os::raw::c_char; 56usize], -} -#[allow(clippy::unnecessary_operation, clippy::identity_op)] -const _: () = { - ["Size of _opaque_pthread_mutex_t"][::std::mem::size_of::<_opaque_pthread_mutex_t>() - 64usize]; - 
["Alignment of _opaque_pthread_mutex_t"] - [::std::mem::align_of::<_opaque_pthread_mutex_t>() - 8usize]; - ["Offset of field: _opaque_pthread_mutex_t::__sig"] - [::std::mem::offset_of!(_opaque_pthread_mutex_t, __sig) - 0usize]; - ["Offset of field: _opaque_pthread_mutex_t::__opaque"] - [::std::mem::offset_of!(_opaque_pthread_mutex_t, __opaque) - 8usize]; -}; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct _opaque_pthread_mutexattr_t { - pub __sig: ::std::os::raw::c_long, - pub __opaque: [::std::os::raw::c_char; 8usize], -} -#[allow(clippy::unnecessary_operation, clippy::identity_op)] -const _: () = { - ["Size of _opaque_pthread_mutexattr_t"] - [::std::mem::size_of::<_opaque_pthread_mutexattr_t>() - 16usize]; - ["Alignment of _opaque_pthread_mutexattr_t"] - [::std::mem::align_of::<_opaque_pthread_mutexattr_t>() - 8usize]; - ["Offset of field: _opaque_pthread_mutexattr_t::__sig"] - [::std::mem::offset_of!(_opaque_pthread_mutexattr_t, __sig) - 0usize]; - ["Offset of field: _opaque_pthread_mutexattr_t::__opaque"] - [::std::mem::offset_of!(_opaque_pthread_mutexattr_t, __opaque) - 8usize]; -}; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct _opaque_pthread_once_t { - pub __sig: ::std::os::raw::c_long, - pub __opaque: [::std::os::raw::c_char; 8usize], -} -#[allow(clippy::unnecessary_operation, clippy::identity_op)] -const _: () = { - ["Size of _opaque_pthread_once_t"][::std::mem::size_of::<_opaque_pthread_once_t>() - 16usize]; - ["Alignment of _opaque_pthread_once_t"] - [::std::mem::align_of::<_opaque_pthread_once_t>() - 8usize]; - ["Offset of field: _opaque_pthread_once_t::__sig"] - [::std::mem::offset_of!(_opaque_pthread_once_t, __sig) - 0usize]; - ["Offset of field: _opaque_pthread_once_t::__opaque"] - [::std::mem::offset_of!(_opaque_pthread_once_t, __opaque) - 8usize]; -}; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct _opaque_pthread_rwlock_t { - pub __sig: ::std::os::raw::c_long, - pub __opaque: [::std::os::raw::c_char; 192usize], -} 
-#[allow(clippy::unnecessary_operation, clippy::identity_op)] -const _: () = { - ["Size of _opaque_pthread_rwlock_t"] - [::std::mem::size_of::<_opaque_pthread_rwlock_t>() - 200usize]; - ["Alignment of _opaque_pthread_rwlock_t"] - [::std::mem::align_of::<_opaque_pthread_rwlock_t>() - 8usize]; - ["Offset of field: _opaque_pthread_rwlock_t::__sig"] - [::std::mem::offset_of!(_opaque_pthread_rwlock_t, __sig) - 0usize]; - ["Offset of field: _opaque_pthread_rwlock_t::__opaque"] - [::std::mem::offset_of!(_opaque_pthread_rwlock_t, __opaque) - 8usize]; -}; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct _opaque_pthread_rwlockattr_t { - pub __sig: ::std::os::raw::c_long, - pub __opaque: [::std::os::raw::c_char; 16usize], -} -#[allow(clippy::unnecessary_operation, clippy::identity_op)] -const _: () = { - ["Size of _opaque_pthread_rwlockattr_t"] - [::std::mem::size_of::<_opaque_pthread_rwlockattr_t>() - 24usize]; - ["Alignment of _opaque_pthread_rwlockattr_t"] - [::std::mem::align_of::<_opaque_pthread_rwlockattr_t>() - 8usize]; - ["Offset of field: _opaque_pthread_rwlockattr_t::__sig"] - [::std::mem::offset_of!(_opaque_pthread_rwlockattr_t, __sig) - 0usize]; - ["Offset of field: _opaque_pthread_rwlockattr_t::__opaque"] - [::std::mem::offset_of!(_opaque_pthread_rwlockattr_t, __opaque) - 8usize]; + ["Size of max_align_t"][::std::mem::size_of::() - 32usize]; + ["Alignment of max_align_t"][::std::mem::align_of::() - 16usize]; + ["Offset of field: max_align_t::__clang_max_align_nonce1"] + [::std::mem::offset_of!(max_align_t, __clang_max_align_nonce1) - 0usize]; + ["Offset of field: max_align_t::__clang_max_align_nonce2"] + [::std::mem::offset_of!(max_align_t, __clang_max_align_nonce2) - 16usize]; }; +pub type __u_char = ::std::os::raw::c_uchar; +pub type __u_short = ::std::os::raw::c_ushort; +pub type __u_int = ::std::os::raw::c_uint; +pub type __u_long = ::std::os::raw::c_ulong; +pub type __int8_t = ::std::os::raw::c_schar; +pub type __uint8_t = ::std::os::raw::c_uchar; +pub 
type __int16_t = ::std::os::raw::c_short; +pub type __uint16_t = ::std::os::raw::c_ushort; +pub type __int32_t = ::std::os::raw::c_int; +pub type __uint32_t = ::std::os::raw::c_uint; +pub type __int64_t = ::std::os::raw::c_long; +pub type __uint64_t = ::std::os::raw::c_ulong; +pub type __int_least8_t = __int8_t; +pub type __uint_least8_t = __uint8_t; +pub type __int_least16_t = __int16_t; +pub type __uint_least16_t = __uint16_t; +pub type __int_least32_t = __int32_t; +pub type __uint_least32_t = __uint32_t; +pub type __int_least64_t = __int64_t; +pub type __uint_least64_t = __uint64_t; +pub type __quad_t = ::std::os::raw::c_long; +pub type __u_quad_t = ::std::os::raw::c_ulong; +pub type __intmax_t = ::std::os::raw::c_long; +pub type __uintmax_t = ::std::os::raw::c_ulong; +pub type __dev_t = ::std::os::raw::c_ulong; +pub type __uid_t = ::std::os::raw::c_uint; +pub type __gid_t = ::std::os::raw::c_uint; +pub type __ino_t = ::std::os::raw::c_ulong; +pub type __ino64_t = ::std::os::raw::c_ulong; +pub type __mode_t = ::std::os::raw::c_uint; +pub type __nlink_t = ::std::os::raw::c_ulong; +pub type __off_t = ::std::os::raw::c_long; +pub type __off64_t = ::std::os::raw::c_long; +pub type __pid_t = ::std::os::raw::c_int; #[repr(C)] #[derive(Debug, Copy, Clone)] -pub struct _opaque_pthread_t { - pub __sig: ::std::os::raw::c_long, - pub __cleanup_stack: *mut __darwin_pthread_handler_rec, - pub __opaque: [::std::os::raw::c_char; 8176usize], +pub struct __fsid_t { + pub __val: [::std::os::raw::c_int; 2usize], } #[allow(clippy::unnecessary_operation, clippy::identity_op)] const _: () = { - ["Size of _opaque_pthread_t"][::std::mem::size_of::<_opaque_pthread_t>() - 8192usize]; - ["Alignment of _opaque_pthread_t"][::std::mem::align_of::<_opaque_pthread_t>() - 8usize]; - ["Offset of field: _opaque_pthread_t::__sig"] - [::std::mem::offset_of!(_opaque_pthread_t, __sig) - 0usize]; - ["Offset of field: _opaque_pthread_t::__cleanup_stack"] - [::std::mem::offset_of!(_opaque_pthread_t, 
__cleanup_stack) - 8usize]; - ["Offset of field: _opaque_pthread_t::__opaque"] - [::std::mem::offset_of!(_opaque_pthread_t, __opaque) - 16usize]; + ["Size of __fsid_t"][::std::mem::size_of::<__fsid_t>() - 8usize]; + ["Alignment of __fsid_t"][::std::mem::align_of::<__fsid_t>() - 4usize]; + ["Offset of field: __fsid_t::__val"][::std::mem::offset_of!(__fsid_t, __val) - 0usize]; }; -pub type __darwin_pthread_attr_t = _opaque_pthread_attr_t; -pub type __darwin_pthread_cond_t = _opaque_pthread_cond_t; -pub type __darwin_pthread_condattr_t = _opaque_pthread_condattr_t; -pub type __darwin_pthread_key_t = ::std::os::raw::c_ulong; -pub type __darwin_pthread_mutex_t = _opaque_pthread_mutex_t; -pub type __darwin_pthread_mutexattr_t = _opaque_pthread_mutexattr_t; -pub type __darwin_pthread_once_t = _opaque_pthread_once_t; -pub type __darwin_pthread_rwlock_t = _opaque_pthread_rwlock_t; -pub type __darwin_pthread_rwlockattr_t = _opaque_pthread_rwlockattr_t; -pub type __darwin_pthread_t = *mut _opaque_pthread_t; -pub type intmax_t = ::std::os::raw::c_long; -pub type uintmax_t = ::std::os::raw::c_ulong; +pub type __clock_t = ::std::os::raw::c_long; +pub type __rlim_t = ::std::os::raw::c_ulong; +pub type __rlim64_t = ::std::os::raw::c_ulong; +pub type __id_t = ::std::os::raw::c_uint; +pub type __time_t = ::std::os::raw::c_long; +pub type __useconds_t = ::std::os::raw::c_uint; +pub type __suseconds_t = ::std::os::raw::c_long; +pub type __suseconds64_t = ::std::os::raw::c_long; +pub type __daddr_t = ::std::os::raw::c_int; +pub type __key_t = ::std::os::raw::c_int; +pub type __clockid_t = ::std::os::raw::c_int; +pub type __timer_t = *mut ::std::os::raw::c_void; +pub type __blksize_t = ::std::os::raw::c_long; +pub type __blkcnt_t = ::std::os::raw::c_long; +pub type __blkcnt64_t = ::std::os::raw::c_long; +pub type __fsblkcnt_t = ::std::os::raw::c_ulong; +pub type __fsblkcnt64_t = ::std::os::raw::c_ulong; +pub type __fsfilcnt_t = ::std::os::raw::c_ulong; +pub type __fsfilcnt64_t = 
::std::os::raw::c_ulong; +pub type __fsword_t = ::std::os::raw::c_long; +pub type __ssize_t = ::std::os::raw::c_long; +pub type __syscall_slong_t = ::std::os::raw::c_long; +pub type __syscall_ulong_t = ::std::os::raw::c_ulong; +pub type __loff_t = __off64_t; +pub type __caddr_t = *mut ::std::os::raw::c_char; +pub type __intptr_t = ::std::os::raw::c_long; +pub type __socklen_t = ::std::os::raw::c_uint; +pub type __sig_atomic_t = ::std::os::raw::c_int; +pub type int_least8_t = __int_least8_t; +pub type int_least16_t = __int_least16_t; +pub type int_least32_t = __int_least32_t; +pub type int_least64_t = __int_least64_t; +pub type uint_least8_t = __uint_least8_t; +pub type uint_least16_t = __uint_least16_t; +pub type uint_least32_t = __uint_least32_t; +pub type uint_least64_t = __uint_least64_t; +pub type int_fast8_t = ::std::os::raw::c_schar; +pub type int_fast16_t = ::std::os::raw::c_long; +pub type int_fast32_t = ::std::os::raw::c_long; +pub type int_fast64_t = ::std::os::raw::c_long; +pub type uint_fast8_t = ::std::os::raw::c_uchar; +pub type uint_fast16_t = ::std::os::raw::c_ulong; +pub type uint_fast32_t = ::std::os::raw::c_ulong; +pub type uint_fast64_t = ::std::os::raw::c_ulong; +pub type intmax_t = __intmax_t; +pub type uintmax_t = __uintmax_t; #[doc = " Wrapper around Rust's [`&str`], without allocating memory, unlike [`std::ffi::CString`].\n The caller must use it as a Rust string. 
This is not a C-string."] #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -449,4 +324,3 @@ const _: () = { ["Offset of field: PdRoute::shard"][::std::mem::offset_of!(PdRoute, shard) - 0usize]; ["Offset of field: PdRoute::read_write"][::std::mem::offset_of!(PdRoute, read_write) - 8usize]; }; -pub type __builtin_va_list = *mut ::std::os::raw::c_char; diff --git a/pgdog-stats/src/replication.rs b/pgdog-stats/src/replication.rs index 8601d2cd6..ee5a3ce56 100644 --- a/pgdog-stats/src/replication.rs +++ b/pgdog-stats/src/replication.rs @@ -7,7 +7,9 @@ use pgdog_postgres_types::Error; use pgdog_postgres_types::{Format, FromDataType, TimestampTz}; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Default, Copy, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] +#[derive( + Debug, Clone, Default, Copy, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, +)] pub struct Lsn { pub high: i64, pub low: i64, diff --git a/pgdog/src/admin/copy_data.rs b/pgdog/src/admin/copy_data.rs new file mode 100644 index 000000000..13c4fd307 --- /dev/null +++ b/pgdog/src/admin/copy_data.rs @@ -0,0 +1,77 @@ +//! COPY_DATA command. + +use tokio::spawn; +use tracing::info; + +use crate::backend::replication::logical::admin::{Task, TaskType}; +use crate::backend::replication::orchestrator::Orchestrator; +use crate::backend::replication::AsyncTasks; + +use super::prelude::*; + +pub struct CopyData { + pub from_database: String, + pub to_database: String, + pub publication: String, + pub replication_slot: Option, +} + +#[async_trait] +impl Command for CopyData { + fn name(&self) -> String { + "COPY_DATA".into() + } + + fn parse(sql: &str) -> Result { + let parts = sql.split(" ").collect::>(); + + match parts[..] 
{ + ["copy_data", from_database, to_database, publication] => Ok(Self { + from_database: from_database.to_owned(), + to_database: to_database.to_owned(), + publication: publication.to_owned(), + replication_slot: None, + }), + ["copy_data", from_database, to_database, publication, replication_slot] => Ok(Self { + from_database: from_database.to_owned(), + to_database: to_database.to_owned(), + publication: publication.to_owned(), + replication_slot: Some(replication_slot.to_owned()), + }), + _ => Err(Error::Syntax), + } + } + + async fn execute(&self) -> Result, Error> { + info!( + r#"copy_data "{}" to "{}", publication="{}""#, + self.from_database, self.to_database, self.publication + ); + + let mut orchestrator = Orchestrator::new( + &self.from_database, + &self.to_database, + &self.publication, + self.replication_slot.clone(), + )?; + + let slot_name = orchestrator.replication_slot().to_owned(); + + let task_id = Task::register(TaskType::CopyData(spawn(async move { + orchestrator.load_schema().await?; + orchestrator.data_sync().await?; + AsyncTasks::insert(TaskType::Replication(orchestrator.replicate().await?)); + + Ok(()) + }))); + + let mut dr = DataRow::new(); + dr.add(task_id.to_string()).add(slot_name); + + Ok(vec![ + RowDescription::new(&[Field::text("task_id"), Field::text("replication_slot")]) + .message()?, + dr.message()?, + ]) + } +} diff --git a/pgdog/src/admin/cutover.rs b/pgdog/src/admin/cutover.rs new file mode 100644 index 000000000..c5f8f8435 --- /dev/null +++ b/pgdog/src/admin/cutover.rs @@ -0,0 +1,33 @@ +use crate::backend::replication::logical::admin::AsyncTasks; + +use super::prelude::*; + +pub struct Cutover; + +#[async_trait] +impl Command for Cutover { + fn name(&self) -> String { + "CUTOVER".into() + } + + fn parse(sql: &str) -> Result { + let parts: Vec<&str> = sql.split_whitespace().collect(); + + match parts[..] 
{ + ["cutover"] => Ok(Cutover), + _ => Err(Error::Syntax), + } + } + + async fn execute(&self) -> Result, Error> { + AsyncTasks::cutover()?; + + let mut dr = DataRow::new(); + dr.add("OK"); + + Ok(vec![ + RowDescription::new(&[Field::text("cutover")]).message()?, + dr.message()?, + ]) + } +} diff --git a/pgdog/src/admin/error.rs b/pgdog/src/admin/error.rs index c9ca3f29b..e11573cda 100644 --- a/pgdog/src/admin/error.rs +++ b/pgdog/src/admin/error.rs @@ -36,6 +36,15 @@ pub enum Error { #[error("address is not valid")] InvalidAddress, + + #[error("{0}")] + Replication(Box), +} + +impl From for Error { + fn from(err: crate::backend::replication::logical::Error) -> Self { + Error::Replication(Box::new(err)) + } } impl From for Error { diff --git a/pgdog/src/admin/mod.rs b/pgdog/src/admin/mod.rs index 521119c8c..6be79945f 100644 --- a/pgdog/src/admin/mod.rs +++ b/pgdog/src/admin/mod.rs @@ -5,6 +5,8 @@ use async_trait::async_trait; use crate::net::messages::Message; pub mod ban; +pub mod copy_data; +pub mod cutover; pub mod error; pub mod healthcheck; pub mod maintenance_mode; @@ -15,7 +17,10 @@ pub mod prelude; pub mod probe; pub mod reconnect; pub mod reload; +pub mod replicate; pub mod reset_query_cache; +pub mod reshard; +pub mod schema_sync; pub mod server; pub mod set; pub mod setup_schema; @@ -30,12 +35,17 @@ pub mod show_pools; pub mod show_prepared_statements; pub mod show_query_cache; pub mod show_replication; +pub mod show_replication_slots; +pub mod show_schema_sync; pub mod show_server_memory; pub mod show_servers; pub mod show_stats; +pub mod show_table_copies; +pub mod show_tasks; pub mod show_transactions; pub mod show_version; pub mod shutdown; +pub mod stop_task; pub use error::Error; diff --git a/pgdog/src/admin/parser.rs b/pgdog/src/admin/parser.rs index 5d94ab445..215ab1230 100644 --- a/pgdog/src/admin/parser.rs +++ b/pgdog/src/admin/parser.rs @@ -1,16 +1,18 @@ //! Admin command parser. 
use super::{ - ban::Ban, healthcheck::Healthcheck, maintenance_mode::MaintenanceMode, pause::Pause, - prelude::Message, probe::Probe, reconnect::Reconnect, reload::Reload, - reset_query_cache::ResetQueryCache, set::Set, setup_schema::SetupSchema, + ban::Ban, copy_data::CopyData, cutover::Cutover, healthcheck::Healthcheck, + maintenance_mode::MaintenanceMode, pause::Pause, prelude::Message, probe::Probe, + reconnect::Reconnect, reload::Reload, replicate::Replicate, reset_query_cache::ResetQueryCache, + reshard::Reshard, schema_sync::SchemaSync, set::Set, setup_schema::SetupSchema, show_client_memory::ShowClientMemory, show_clients::ShowClients, show_config::ShowConfig, show_instance_id::ShowInstanceId, show_lists::ShowLists, show_mirrors::ShowMirrors, show_peers::ShowPeers, show_pools::ShowPools, show_prepared_statements::ShowPreparedStatements, show_query_cache::ShowQueryCache, show_replication::ShowReplication, + show_replication_slots::ShowReplicationSlots, show_schema_sync::ShowSchemaSync, show_server_memory::ShowServerMemory, show_servers::ShowServers, show_stats::ShowStats, - show_transactions::ShowTransactions, show_version::ShowVersion, shutdown::Shutdown, Command, - Error, + show_table_copies::ShowTableCopies, show_tasks::ShowTasks, show_transactions::ShowTransactions, + show_version::ShowVersion, shutdown::Shutdown, stop_task::StopTask, Command, Error, }; use tracing::debug; @@ -39,11 +41,21 @@ pub enum ParseResult { ShowReplication(ShowReplication), ShowServerMemory(ShowServerMemory), ShowClientMemory(ShowClientMemory), + ShowTableCopies(ShowTableCopies), + ShowReplicationSlots(ShowReplicationSlots), + ShowSchemaSync(ShowSchemaSync), Set(Set), Ban(Ban), Probe(Probe), MaintenanceMode(MaintenanceMode), Healthcheck(Healthcheck), + Reshard(Reshard), + SchemaSync(SchemaSync), + CopyData(CopyData), + Replicate(Replicate), + ShowTasks(ShowTasks), + StopTask(StopTask), + Cutover(Cutover), } impl ParseResult { @@ -74,11 +86,21 @@ impl ParseResult { 
ShowReplication(show_replication) => show_replication.execute().await, ShowServerMemory(show_server_memory) => show_server_memory.execute().await, ShowClientMemory(show_client_memory) => show_client_memory.execute().await, + ShowTableCopies(show_table_copies) => show_table_copies.execute().await, + ShowReplicationSlots(cmd) => cmd.execute().await, + ShowSchemaSync(cmd) => cmd.execute().await, Set(set) => set.execute().await, Ban(ban) => ban.execute().await, Probe(probe) => probe.execute().await, MaintenanceMode(maintenance_mode) => maintenance_mode.execute().await, Healthcheck(healthcheck) => healthcheck.execute().await, + Reshard(reshard) => reshard.execute().await, + SchemaSync(cmd) => cmd.execute().await, + CopyData(cmd) => cmd.execute().await, + Replicate(cmd) => cmd.execute().await, + ShowTasks(cmd) => cmd.execute().await, + StopTask(cmd) => cmd.execute().await, + Cutover(cmd) => cmd.execute().await, } } @@ -109,11 +131,21 @@ impl ParseResult { ShowReplication(show_replication) => show_replication.name(), ShowServerMemory(show_server_memory) => show_server_memory.name(), ShowClientMemory(show_client_memory) => show_client_memory.name(), + ShowTableCopies(show_table_copies) => show_table_copies.name(), + ShowReplicationSlots(cmd) => cmd.name(), + ShowSchemaSync(cmd) => cmd.name(), Set(set) => set.name(), Ban(ban) => ban.name(), Probe(probe) => probe.name(), MaintenanceMode(maintenance_mode) => maintenance_mode.name(), Healthcheck(healthcheck) => healthcheck.name(), + Reshard(reshard) => reshard.name(), + SchemaSync(cmd) => cmd.name(), + CopyData(cmd) => cmd.name(), + Replicate(cmd) => cmd.name(), + ShowTasks(cmd) => cmd.name(), + StopTask(cmd) => cmd.name(), + Cutover(cmd) => cmd.name(), } } } @@ -163,6 +195,12 @@ impl Parser { "lists" => ParseResult::ShowLists(ShowLists::parse(&sql)?), "prepared" => ParseResult::ShowPrepared(ShowPreparedStatements::parse(&sql)?), "replication" => ParseResult::ShowReplication(ShowReplication::parse(&sql)?), + 
"replication_slots" => { + ParseResult::ShowReplicationSlots(ShowReplicationSlots::parse(&sql)?) + } + "schema_sync" => ParseResult::ShowSchemaSync(ShowSchemaSync::parse(&sql)?), + "table_copies" => ParseResult::ShowTableCopies(ShowTableCopies::parse(&sql)?), + "tasks" => ParseResult::ShowTasks(ShowTasks::parse(&sql)?), command => { debug!("unknown admin show command: '{}'", command); return Err(Error::Syntax); @@ -182,6 +220,12 @@ impl Parser { return Err(Error::Syntax); } }, + "reshard" => ParseResult::Reshard(Reshard::parse(&sql)?), + "schema_sync" => ParseResult::SchemaSync(SchemaSync::parse(&sql)?), + "copy_data" => ParseResult::CopyData(CopyData::parse(&sql)?), + "replicate" => ParseResult::Replicate(Replicate::parse(&sql)?), + "stop_task" => ParseResult::StopTask(StopTask::parse(&sql)?), + "cutover" => ParseResult::Cutover(Cutover::parse(&sql)?), "probe" => ParseResult::Probe(Probe::parse(&sql)?), "maintenance" => ParseResult::MaintenanceMode(MaintenanceMode::parse(&sql)?), // TODO: This is not ready yet. We have a race and @@ -229,4 +273,10 @@ mod tests { let result = Parser::parse("SHOW CLIENT MEMORY;"); assert!(matches!(result, Ok(ParseResult::ShowClientMemory(_)))); } + + #[test] + fn parses_cutover_command() { + let result = Parser::parse("CUTOVER"); + assert!(matches!(result, Ok(ParseResult::Cutover(_)))); + } } diff --git a/pgdog/src/admin/replicate.rs b/pgdog/src/admin/replicate.rs new file mode 100644 index 000000000..64497798a --- /dev/null +++ b/pgdog/src/admin/replicate.rs @@ -0,0 +1,67 @@ +//! REPLICATE command. 
+ +use tracing::info; + +use crate::backend::replication::logical::admin::{Task, TaskType}; +use crate::backend::replication::orchestrator::Orchestrator; + +use super::prelude::*; + +pub struct Replicate { + pub from_database: String, + pub to_database: String, + pub publication: String, + pub replication_slot: Option, +} + +#[async_trait] +impl Command for Replicate { + fn name(&self) -> String { + "REPLICATE".into() + } + + fn parse(sql: &str) -> Result { + let parts = sql.split(" ").collect::>(); + + match parts[..] { + ["replicate", from_database, to_database, publication] => Ok(Self { + from_database: from_database.to_owned(), + to_database: to_database.to_owned(), + publication: publication.to_owned(), + replication_slot: None, + }), + ["replicate", from_database, to_database, publication, replication_slot] => Ok(Self { + from_database: from_database.to_owned(), + to_database: to_database.to_owned(), + publication: publication.to_owned(), + replication_slot: Some(replication_slot.to_owned()), + }), + _ => Err(Error::Syntax), + } + } + + async fn execute(&self) -> Result, Error> { + info!( + r#"replicate "{}" to "{}", publication="{}""#, + self.from_database, self.to_database, self.publication + ); + + let orchestrator = Orchestrator::new( + &self.from_database, + &self.to_database, + &self.publication, + self.replication_slot.clone(), + )?; + + let waiter = orchestrator.replicate().await?; + let task_id = Task::register(TaskType::Replication(waiter)); + + let mut dr = DataRow::new(); + dr.add(task_id.to_string()); + + Ok(vec![ + RowDescription::new(&[Field::text("task_id")]).message()?, + dr.message()?, + ]) + } +} diff --git a/pgdog/src/admin/reshard.rs b/pgdog/src/admin/reshard.rs new file mode 100644 index 000000000..afce53780 --- /dev/null +++ b/pgdog/src/admin/reshard.rs @@ -0,0 +1,58 @@ +//! RESHARD command. 
+ +use tracing::info; + +use crate::backend::replication::orchestrator::Orchestrator; + +use super::prelude::*; + +pub struct Reshard { + pub from_database: String, + pub to_database: String, + pub publication: String, + pub replication_slot: Option, +} + +#[async_trait] +impl Command for Reshard { + fn name(&self) -> String { + "RESHARD".into() + } + + fn parse(sql: &str) -> Result { + let parts = sql.split(" ").collect::>(); + + match parts[..] { + ["reshard", from_database, to_database, publication] => Ok(Self { + from_database: from_database.to_owned(), + to_database: to_database.to_owned(), + publication: publication.to_owned(), + replication_slot: None, + }), + ["reshard", from_database, to_database, publication, replication_slot] => Ok(Self { + from_database: from_database.to_owned(), + to_database: to_database.to_owned(), + publication: publication.to_owned(), + replication_slot: Some(replication_slot.to_owned()), + }), + _ => Err(Error::Syntax), + } + } + + async fn execute(&self) -> Result, Error> { + info!( + r#"resharding "{}" to "{}", publication="{}""#, + self.from_database, self.to_database, self.publication + ); + let mut orchestrator = Orchestrator::new( + &self.from_database, + &self.to_database, + &self.publication, + self.replication_slot.clone(), + )?; + + orchestrator.replicate_and_cutover().await?; + + Ok(vec![]) + } +} diff --git a/pgdog/src/admin/schema_sync.rs b/pgdog/src/admin/schema_sync.rs new file mode 100644 index 000000000..a73bdf767 --- /dev/null +++ b/pgdog/src/admin/schema_sync.rs @@ -0,0 +1,106 @@ +//! SCHEMA_SYNC command. 
+ +use tokio::spawn; +use tracing::info; + +use crate::backend::replication::logical::admin::{Task, TaskType}; +use crate::backend::replication::orchestrator::Orchestrator; + +use super::prelude::*; + +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum SchemaSyncPhase { + Pre, + Post, +} + +pub struct SchemaSync { + pub from_database: String, + pub to_database: String, + pub publication: String, + pub replication_slot: Option, + pub phase: SchemaSyncPhase, +} + +#[async_trait] +impl Command for SchemaSync { + fn name(&self) -> String { + match self.phase { + SchemaSyncPhase::Pre => "SCHEMA_SYNC PRE".into(), + SchemaSyncPhase::Post => "SCHEMA_SYNC POST".into(), + } + } + + fn parse(sql: &str) -> Result { + let parts = sql.split(" ").collect::>(); + + match parts[..] { + ["schema_sync", phase, from_database, to_database, publication] => Ok(Self { + from_database: from_database.to_owned(), + to_database: to_database.to_owned(), + publication: publication.to_owned(), + replication_slot: None, + phase: parse_phase(phase)?, + }), + ["schema_sync", phase, from_database, to_database, publication, replication_slot] => { + Ok(Self { + from_database: from_database.to_owned(), + to_database: to_database.to_owned(), + publication: publication.to_owned(), + replication_slot: Some(replication_slot.to_owned()), + phase: parse_phase(phase)?, + }) + } + _ => Err(Error::Syntax), + } + } + + async fn execute(&self) -> Result, Error> { + let phase_name = match self.phase { + SchemaSyncPhase::Pre => "pre", + SchemaSyncPhase::Post => "post", + }; + + info!( + r#"schema_sync {} "{}" to "{}", publication="{}""#, + phase_name, self.from_database, self.to_database, self.publication + ); + + let mut orchestrator = Orchestrator::new( + &self.from_database, + &self.to_database, + &self.publication, + self.replication_slot.clone(), + )?; + + let phase = self.phase; + let handle = spawn(async move { + orchestrator.load_schema().await?; + + match phase { + SchemaSyncPhase::Pre => 
orchestrator.schema_sync_pre(true).await, + SchemaSyncPhase::Post => orchestrator.schema_sync_post(true).await, + }?; + + Ok(()) + }); + + let task_id = Task::register(TaskType::SchemaSync(handle)); + + let mut dr = DataRow::new(); + dr.add(task_id.to_string()); + + Ok(vec![ + RowDescription::new(&[Field::text("task_id")]).message()?, + dr.message()?, + ]) + } +} + +fn parse_phase(phase: &str) -> Result { + match phase { + "pre" => Ok(SchemaSyncPhase::Pre), + "post" => Ok(SchemaSyncPhase::Post), + _ => Err(Error::Syntax), + } +} diff --git a/pgdog/src/admin/show_replication_slots.rs b/pgdog/src/admin/show_replication_slots.rs new file mode 100644 index 000000000..761971172 --- /dev/null +++ b/pgdog/src/admin/show_replication_slots.rs @@ -0,0 +1,78 @@ +use std::time::SystemTime; + +use chrono::{DateTime, Local}; + +use crate::{ + backend::replication::logical::status::ReplicationSlots, + net::{data_row::Data, ToDataRowColumn}, + util::{format_bytes, format_time}, +}; + +use super::prelude::*; + +pub struct ShowReplicationSlots; + +#[async_trait] +impl Command for ShowReplicationSlots { + fn name(&self) -> String { + "SHOW REPLICATION_SLOTS".into() + } + + fn parse(_sql: &str) -> Result { + Ok(ShowReplicationSlots {}) + } + + async fn execute(&self) -> Result, Error> { + let rd = RowDescription::new(&[ + Field::text("host"), + Field::bigint("port"), + Field::text("database_name"), + Field::text("name"), + Field::text("lsn"), + Field::text("lag"), + Field::bigint("lag_bytes"), + Field::bool("copy_data"), + Field::text("last_transaction"), + Field::bigint("last_transaction_ms"), + ]); + let mut messages = vec![rd.message()?]; + let now = SystemTime::now(); + + for entry in ReplicationSlots::get().iter() { + let slot = entry.value(); + + let last_transaction_ms = slot + .last_transaction + .and_then(|t| now.duration_since(t).ok()) + .map(|d| d.as_millis() as i64); + + let last_transaction_str = slot + .last_transaction + .map(|t| format_time(DateTime::::from(t))); + + 
let mut row = DataRow::new(); + row.add(&slot.address.host) + .add(slot.address.port as i64) + .add(&slot.address.database_name) + .add(slot.name.as_str()) + .add(slot.lsn.to_string().as_str()) + .add(format_bytes(slot.lag as u64).as_str()) + .add(slot.lag) + .add(slot.copy_data) + .add(if let Some(ref s) = last_transaction_str { + s.as_str().to_data_row_column() + } else { + Data::null() + }) + .add(if let Some(ms) = last_transaction_ms { + ms.to_data_row_column() + } else { + Data::null() + }); + + messages.push(row.message()?); + } + + Ok(messages) + } +} diff --git a/pgdog/src/admin/show_schema_sync.rs b/pgdog/src/admin/show_schema_sync.rs new file mode 100644 index 000000000..fd77199a5 --- /dev/null +++ b/pgdog/src/admin/show_schema_sync.rs @@ -0,0 +1,71 @@ +use std::time::SystemTime; + +use chrono::{DateTime, Local}; + +use crate::{ + backend::replication::logical::status::SchemaStatements, + util::{format_time, human_duration_display}, +}; + +use super::prelude::*; + +pub struct ShowSchemaSync; + +#[async_trait] +impl Command for ShowSchemaSync { + fn name(&self) -> String { + "SHOW SCHEMA_SYNC".into() + } + + fn parse(_sql: &str) -> Result { + Ok(ShowSchemaSync {}) + } + + async fn execute(&self) -> Result, Error> { + let rd = RowDescription::new(&[ + Field::text("database"), + Field::text("user"), + Field::bigint("shard"), + Field::text("kind"), + Field::text("sync_state"), + Field::text("started_at"), + Field::text("elapsed"), + Field::bigint("elapsed_ms"), + Field::text("table_schema"), + Field::text("table_name"), + Field::text("sql"), + ]); + let mut messages = vec![rd.message()?]; + let now = SystemTime::now(); + + for entry in SchemaStatements::get().iter() { + let stmt = entry.key(); + + let elapsed = now.duration_since(stmt.started_at).unwrap_or_default(); + let elapsed_ms = elapsed.as_millis() as i64; + let elapsed_human = human_duration_display(elapsed); + + let kind = stmt.kind.to_string(); + let sync_state = stmt.sync_state.to_string(); + let 
started_at: DateTime = stmt.started_at.into(); + let started_at = format_time(started_at); + + let mut row = DataRow::new(); + row.add(stmt.user.database.as_str()) + .add(stmt.user.user.as_str()) + .add(stmt.shard as i64) + .add(kind.as_str()) + .add(sync_state.as_str()) + .add(started_at.as_str()) + .add(elapsed_human.as_str()) + .add(elapsed_ms) + .add(stmt.table_schema.as_deref().unwrap_or("")) + .add(stmt.table_name.as_deref().unwrap_or("")) + .add(stmt.sql.as_str()); + + messages.push(row.message()?); + } + + Ok(messages) + } +} diff --git a/pgdog/src/admin/show_table_copies.rs b/pgdog/src/admin/show_table_copies.rs new file mode 100644 index 000000000..83935cbcb --- /dev/null +++ b/pgdog/src/admin/show_table_copies.rs @@ -0,0 +1,79 @@ +use std::time::SystemTime; + +use crate::backend::replication::logical::status::TableCopies; +use crate::util::{format_bytes, human_duration_display, number_human}; + +use super::prelude::*; + +pub struct ShowTableCopies; + +#[async_trait] +impl Command for ShowTableCopies { + fn name(&self) -> String { + "SHOW TABLE_COPIES".into() + } + + fn parse(_sql: &str) -> Result { + Ok(ShowTableCopies {}) + } + + async fn execute(&self) -> Result, Error> { + let rd = RowDescription::new(&[ + Field::text("schema"), + Field::text("table"), + Field::text("status"), + Field::bigint("rows"), + Field::text("rows_human"), + Field::bigint("bytes"), + Field::text("bytes_human"), + Field::bigint("bytes_per_sec"), + Field::text("bytes_per_sec_human"), + Field::text("elapsed"), + Field::bigint("elapsed_ms"), + Field::text("sql"), + ]); + let mut messages = vec![rd.message()?]; + let now = SystemTime::now(); + + let table_copies = TableCopies::get(); + let mut entries: Vec<_> = table_copies.iter().collect(); + entries.sort_by_key(|e| if e.value().bytes == 0 { 1 } else { 0 }); + + for entry in entries { + let key = entry.key(); + let state = entry.value(); + + let elapsed = now.duration_since(state.last_update).unwrap_or_default(); + let elapsed_ms = 
elapsed.as_millis() as i64; + let elapsed_human = human_duration_display(elapsed); + + let status = if state.bytes == 0 { + "waiting" + } else { + "running" + }; + + let rows_human = number_human(state.rows as u64); + let bytes_human = format_bytes(state.bytes as u64); + let bytes_per_sec_human = format_bytes(state.bytes_per_sec as u64); + + let mut row = DataRow::new(); + row.add(key.schema.as_str()) + .add(key.table.as_str()) + .add(status) + .add(state.rows as i64) + .add(rows_human.as_str()) + .add(state.bytes as i64) + .add(bytes_human.as_str()) + .add(state.bytes_per_sec as i64) + .add(bytes_per_sec_human.as_str()) + .add(elapsed_human.as_str()) + .add(elapsed_ms) + .add(state.sql.as_str()); + + messages.push(row.message()?); + } + + Ok(messages) + } +} diff --git a/pgdog/src/admin/show_tasks.rs b/pgdog/src/admin/show_tasks.rs new file mode 100644 index 000000000..3b56d6aaa --- /dev/null +++ b/pgdog/src/admin/show_tasks.rs @@ -0,0 +1,51 @@ +use std::time::SystemTime; + +use chrono::{DateTime, Local}; + +use crate::backend::replication::logical::admin::AsyncTasks; +use crate::util::{format_time, human_duration_display}; + +use super::prelude::*; + +pub struct ShowTasks; + +#[async_trait] +impl Command for ShowTasks { + fn name(&self) -> String { + "SHOW TASKS".into() + } + + fn parse(_sql: &str) -> Result { + Ok(ShowTasks) + } + + async fn execute(&self) -> Result, Error> { + let rd = RowDescription::new(&[ + Field::bigint("id"), + Field::text("type"), + Field::text("started_at"), + Field::text("elapsed"), + Field::bigint("elapsed_ms"), + ]); + let mut messages = vec![rd.message()?]; + let now = SystemTime::now(); + + for (id, task_kind, started_at) in AsyncTasks::get().iter() { + let elapsed = now.duration_since(started_at).unwrap_or_default(); + let elapsed_ms = elapsed.as_millis() as i64; + let elapsed_str = human_duration_display(elapsed); + + let started_at_str = format_time(DateTime::::from(started_at)); + + let mut row = DataRow::new(); + row.add(id as 
i64) + .add(task_kind.to_string().as_str()) + .add(started_at_str.as_str()) + .add(elapsed_str.as_str()) + .add(elapsed_ms); + messages.push(row.message()?); + } + + Ok(messages) + } +} diff --git a/pgdog/src/admin/stop_task.rs b/pgdog/src/admin/stop_task.rs new file mode 100644 index 000000000..70be0b857 --- /dev/null +++ b/pgdog/src/admin/stop_task.rs @@ -0,0 +1,56 @@ +use crate::backend::replication::logical::admin::{AsyncTasks, TaskKind}; +use crate::net::messages::{ErrorResponse, NoticeResponse}; + +use super::prelude::*; + +pub struct StopTask { + task_id: u64, +} + +#[async_trait] +impl Command for StopTask { + fn name(&self) -> String { + "STOP_TASK".into() + } + + fn parse(sql: &str) -> Result { + let parts: Vec<&str> = sql.split_whitespace().collect(); + + match parts[..] { + ["stop_task", id] => { + let task_id = id.parse().map_err(|_| Error::Syntax)?; + Ok(StopTask { task_id }) + } + _ => Err(Error::Syntax), + } + } + + async fn execute(&self) -> Result, Error> { + let task_kind = AsyncTasks::remove(self.task_id); + + let mut messages = vec![]; + + if task_kind == Some(TaskKind::CopyData) { + let notice = NoticeResponse::from(ErrorResponse { + severity: "WARNING".into(), + code: "01000".into(), + message: "replication slot was not dropped and requires manual cleanup".into(), + ..Default::default() + }); + messages.push(notice.message()?); + } + + let result = match task_kind { + Some(_) => "OK", + None => "task not found", + }; + + let mut dr = DataRow::new(); + dr.add(result); + + messages.push(RowDescription::new(&[Field::text("stop_task")]).message()?); + messages.push(dr.message()?); + + Ok(messages) + } +} diff --git a/pgdog/src/backend/databases.rs b/pgdog/src/backend/databases.rs index d8028c353..7607ea975 100644 --- a/pgdog/src/backend/databases.rs +++ b/pgdog/src/backend/databases.rs @@ -1,9 +1,11 @@ //! Databases behind pgDog. 
use std::collections::{hash_map::Entry, HashMap}; +use std::ops::Deref; use std::sync::Arc; use arc_swap::ArcSwap; +use futures::future::try_join_all; use once_cell::sync::Lazy; use parking_lot::lock_api::MutexGuard; use parking_lot::{Mutex, RawMutex}; @@ -104,6 +106,20 @@ pub fn shutdown() { databases().shutdown(); } +/// Cancel all queries running on a database. +pub async fn cancel_all(database: &str) -> Result<(), Error> { + let clusters: Vec<_> = databases() + .all() + .iter() + .filter(|(user, _)| user.database == database) + .map(|(_, cluster)| cluster.clone()) + .collect(); + + try_join_all(clusters.iter().map(|cluster| cluster.cancel_all())).await?; + + Ok(()) +} + /// Re-create pools from config. pub fn reload() -> Result<(), Error> { let old_config = config(); @@ -157,6 +173,72 @@ pub(crate) fn add(mut user: crate::config::User) { } } +/// Swap database configs between source and destination. +/// Both databases keep their names, but their configs (host, port, etc.) are exchanged. +/// User database references are also swapped. +/// Persists changes to disk (best effort). 
+pub async fn cutover(source: &str, destination: &str) -> Result<(), Error> { + use tokio::fs::{copy, write}; + + let config = { + let _lock = lock(); + + let mut config = config().deref().clone(); + + config.config.cutover(source, destination); + config.users.cutover(source, destination); + + let databases = from_config(&config); + + replace_databases(databases, true)?; + + config + }; + + info!(r#"databases swapped: "{}" <-> "{}""#, source, destination); + + if config.config.general.cutover_save_config { + if let Err(err) = copy( + &config.config_path, + config.config_path.clone().with_extension("bak.toml"), + ) + .await + { + warn!( + "{} is read-only, skipping config persistence (err: {})", + config + .config_path + .parent() + .map(|path| path.to_owned()) + .unwrap_or_default() + .display(), + err + ); + return Ok(()); + } + + copy( + &config.users_path, + &config.users_path.clone().with_extension("bak.toml"), + ) + .await?; + + write( + &config.config_path, + toml::to_string_pretty(&config.config)?.as_bytes(), + ) + .await?; + + write( + &config.users_path, + toml::to_string_pretty(&config.users)?.as_bytes(), + ) + .await?; + } + + Ok(()) +} + /// Database/user pair that identifies a database cluster pool. #[derive(Debug, PartialEq, Hash, Eq, Clone, Default)] pub struct User { @@ -196,15 +278,6 @@ impl ToUser for (&str, Option<&str>) { } } -// impl ToUser for &pgdog_config::User { -// fn to_user(&self) -> User { -// User { -// user: self.name.clone(), -// database: self.database.clone(), -// } -// } -// } - /// Databases. 
#[derive(Default, Clone)] pub struct Databases { @@ -1570,4 +1643,99 @@ mod tests { assert_eq!(databases.all().len(), 1); } + + #[tokio::test] + async fn test_cutover_persists_to_disk() { + use tempfile::TempDir; + use tokio::fs; + + let temp_dir = TempDir::new().unwrap(); + let config_path = temp_dir.path().join("pgdog.toml"); + let users_path = temp_dir.path().join("users.toml"); + + let original_config = r#" +[[databases]] +name = "source_db" +host = "source-host" +port = 5432 +role = "primary" + +[[databases]] +name = "destination_db" +host = "destination-host" +port = 5433 +role = "primary" +"#; + + let original_users = r#" +[[users]] +name = "testuser" +database = "source_db" +password = "testpass" +"#; + + fs::write(&config_path, original_config).await.unwrap(); + fs::write(&users_path, original_users).await.unwrap(); + + // Load config from temp files and set in global state + let mut config = crate::config::ConfigAndUsers::load(&config_path, &users_path).unwrap(); + config.config.general.cutover_save_config = true; + crate::config::set(config).unwrap(); + + // Call the actual cutover function + cutover("source_db", "destination_db").await.unwrap(); + + // Verify backup files contain original content + let backup_config_str = fs::read_to_string(config_path.with_extension("bak.toml")) + .await + .unwrap(); + let backup_config: crate::config::Config = toml::from_str(&backup_config_str).unwrap(); + let backup_source = backup_config + .databases + .iter() + .find(|d| d.name == "source_db") + .unwrap(); + assert_eq!(backup_source.host, "source-host"); + assert_eq!(backup_source.port, 5432); + let backup_dest = backup_config + .databases + .iter() + .find(|d| d.name == "destination_db") + .unwrap(); + assert_eq!(backup_dest.host, "destination-host"); + assert_eq!(backup_dest.port, 5433); + + let backup_users_str = fs::read_to_string(users_path.with_extension("bak.toml")) + .await + .unwrap(); + let backup_users: crate::config::Users = 
toml::from_str(&backup_users_str).unwrap(); + assert_eq!(backup_users.users.len(), 1); + assert_eq!(backup_users.users[0].name, "testuser"); + assert_eq!(backup_users.users[0].database, "source_db"); + + // Verify new config files have swapped values + let new_config_str = fs::read_to_string(&config_path).await.unwrap(); + let new_config: crate::config::Config = toml::from_str(&new_config_str).unwrap(); + let new_source = new_config + .databases + .iter() + .find(|d| d.name == "source_db") + .unwrap(); + assert_eq!(new_source.host, "destination-host"); + assert_eq!(new_source.port, 5433); + let new_dest = new_config + .databases + .iter() + .find(|d| d.name == "destination_db") + .unwrap(); + assert_eq!(new_dest.host, "source-host"); + assert_eq!(new_dest.port, 5432); + + // Verify users were swapped + let new_users_str = fs::read_to_string(&users_path).await.unwrap(); + let new_users: crate::config::Users = toml::from_str(&new_users_str).unwrap(); + assert_eq!(new_users.users.len(), 1); + assert_eq!(new_users.users[0].name, "testuser"); + assert_eq!(new_users.users[0].database, "destination_db"); + } } diff --git a/pgdog/src/backend/error.rs b/pgdog/src/backend/error.rs index f24651985..dee0c2c91 100644 --- a/pgdog/src/backend/error.rs +++ b/pgdog/src/backend/error.rs @@ -134,6 +134,9 @@ pub enum Error { #[error("unsupported aggregation {function}: {reason}")] UnsupportedAggregation { function: String, reason: String }, + + #[error("toml: {0}")] + TomlSer(#[from] toml::ser::Error), } impl From for Error { diff --git a/pgdog/src/backend/maintenance_mode.rs b/pgdog/src/backend/maintenance_mode.rs index 8d82ba231..34ffb4ced 100644 --- a/pgdog/src/backend/maintenance_mode.rs +++ b/pgdog/src/backend/maintenance_mode.rs @@ -2,6 +2,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use once_cell::sync::Lazy; use tokio::sync::{futures::Notified, Notify}; +use tracing::warn; static MAINTENANCE_MODE: Lazy = Lazy::new(|| MaintenanceMode { notify: Notify::new(), @@ -23,11 
+24,18 @@ pub(crate) fn waiter() -> Option> { pub fn start() { MAINTENANCE_MODE.on.store(true, Ordering::Relaxed); + warn!("maintenance mode is on"); } pub fn stop() { MAINTENANCE_MODE.on.store(false, Ordering::Relaxed); MAINTENANCE_MODE.notify.notify_waiters(); + warn!("maintenance mode is off"); +} + +#[cfg(test)] +pub fn is_on() -> bool { + MAINTENANCE_MODE.on.load(Ordering::Relaxed) } #[derive(Debug)] diff --git a/pgdog/src/backend/pool/address.rs b/pgdog/src/backend/pool/address.rs index 7b86fc0c9..63b141163 100644 --- a/pgdog/src/backend/pool/address.rs +++ b/pgdog/src/backend/pool/address.rs @@ -8,7 +8,7 @@ use crate::backend::{pool::dns_cache::DnsCache, Error}; use crate::config::{config, Database, User}; /// Server address. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, Eq, Hash)] pub struct Address { /// Server host. pub host: String, diff --git a/pgdog/src/backend/pool/cluster.rs b/pgdog/src/backend/pool/cluster.rs index 33c0798a9..0b7c1e25a 100644 --- a/pgdog/src/backend/pool/cluster.rs +++ b/pgdog/src/backend/pool/cluster.rs @@ -1,5 +1,6 @@ //! A collection of replicas and a primary. +use futures::future::try_join_all; use parking_lot::Mutex; use pgdog_config::{ LoadSchema, PreparedStatements, QueryParserEngine, QueryParserLevel, Rewrite, RewriteMode, @@ -579,6 +580,21 @@ impl Cluster { self.readiness.online.store(false, Ordering::Relaxed); } + /// Send a cancellation request for all running queries. + pub(crate) async fn cancel_all(&self) -> Result<(), Error> { + let pools: Vec<_> = self + .shards() + .iter() + .flat_map(|shard| shard.pools()) + .collect(); + + try_join_all(pools.iter().map(|pool| pool.cancel_all())) + .await + .map_err(|_| Error::FastShutdown)?; + + Ok(()) + } + /// Is the cluster online? 
pub(crate) fn online(&self) -> bool { self.readiness.online.load(Ordering::Relaxed) diff --git a/pgdog/src/backend/pool/error.rs b/pgdog/src/backend/pool/error.rs index c1fddccb8..5f6d1338f 100644 --- a/pgdog/src/backend/pool/error.rs +++ b/pgdog/src/backend/pool/error.rs @@ -73,4 +73,7 @@ pub enum Error { #[error("mapping missing: {0}")] MappingMissing(usize), + + #[error("fast shutdown failed")] + FastShutdown, } diff --git a/pgdog/src/backend/pool/inner.rs b/pgdog/src/backend/pool/inner.rs index a93c31f39..f5601ca4b 100644 --- a/pgdog/src/backend/pool/inner.rs +++ b/pgdog/src/backend/pool/inner.rs @@ -115,6 +115,11 @@ impl Inner { self.taken.len() } + /// Get backend IDs for all currently checked out servers. + pub(super) fn checked_out_server_ids(&self) -> Vec { + self.taken.servers() + } + /// Find the server currently linked to this client, if any. #[inline] pub(super) fn peer(&self, client_id: &BackendKeyData) -> Option { diff --git a/pgdog/src/backend/pool/pool_impl.rs b/pgdog/src/backend/pool/pool_impl.rs index 77f9c3d2a..597663aa3 100644 --- a/pgdog/src/backend/pool/pool_impl.rs +++ b/pgdog/src/backend/pool/pool_impl.rs @@ -4,6 +4,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::Duration; +use futures::future::try_join_all; use once_cell::sync::{Lazy, OnceCell}; use parking_lot::RwLock; use parking_lot::{lock_api::MutexGuard, Mutex, RawMutex}; @@ -326,6 +327,18 @@ impl Pool { guard.dump_idle(); } + /// Send a cancellation request for all running queries. + pub async fn cancel_all(&self) -> Result<(), Error> { + let taken = self.lock().checked_out_server_ids(); + let addr = self.addr().clone(); + + try_join_all(taken.iter().map(|id| Server::cancel(&addr, id))) + .await + .map_err(|_| Error::FastShutdown)?; + + Ok(()) + } + /// Resume the pool. 
pub fn resume(&self) { { diff --git a/pgdog/src/backend/pool/taken.rs b/pgdog/src/backend/pool/taken.rs index 62c66740b..b791a9c1c 100644 --- a/pgdog/src/backend/pool/taken.rs +++ b/pgdog/src/backend/pool/taken.rs @@ -58,6 +58,13 @@ impl Taken { self.client_server.get(client).copied() } + pub(super) fn servers(&self) -> Vec { + self.client_server + .iter() + .map(|(_, server)| server.clone()) + .collect() + } + #[cfg(test)] pub(super) fn clear(&mut self) { self.taken.clear(); diff --git a/pgdog/src/backend/replication/logical/admin.rs b/pgdog/src/backend/replication/logical/admin.rs new file mode 100644 index 000000000..362b87482 --- /dev/null +++ b/pgdog/src/backend/replication/logical/admin.rs @@ -0,0 +1,351 @@ +use std::{ + fmt, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::SystemTime, +}; + +use crate::backend::replication::orchestrator::ReplicationWaiter; + +use super::Error; +use dashmap::DashMap; +use once_cell::sync::Lazy; +use tokio::{ + select, spawn, + sync::{oneshot, Notify}, + task::JoinHandle, +}; +use tracing::error; + +static TASKS: Lazy = Lazy::new(AsyncTasks::default); + +pub struct Task; + +impl Task { + pub(crate) fn register(task: TaskType) -> u64 { + AsyncTasks::insert(task) + } +} + +pub enum TaskType { + SchemaSync(JoinHandle>), + CopyData(JoinHandle>), + Replication(ReplicationWaiter), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TaskKind { + SchemaSync, + CopyData, + Replication, +} + +impl fmt::Display for TaskKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TaskKind::SchemaSync => write!(f, "schema_sync"), + TaskKind::CopyData => write!(f, "copy_data"), + TaskKind::Replication => write!(f, "replication"), + } + } +} + +pub struct TaskInfo { + #[allow(dead_code)] + abort_tx: oneshot::Sender<()>, + cutover: Arc, + pub task_kind: TaskKind, + pub started_at: SystemTime, +} + +#[derive(Clone, Default)] +pub struct AsyncTasks { + tasks: Arc>, + counter: Arc, +} + +impl 
AsyncTasks { + pub fn get() -> Self { + TASKS.clone() + } + + /// Perform cutover. + pub fn cutover() -> Result<(), Error> { + let this = Self::get(); + let task = this + .tasks + .iter() + .find(|t| t.task_kind == TaskKind::Replication) + .ok_or(Error::NotReplication)?; + + task.cutover.notify_one(); + + Ok(()) + } + + pub fn insert(task: TaskType) -> u64 { + let this = Self::get(); + let id = this.counter.fetch_add(1, Ordering::SeqCst); + let (abort_tx, abort_rx) = oneshot::channel(); + + match task { + TaskType::SchemaSync(handle) => { + this.tasks.insert( + id, + TaskInfo { + abort_tx, + cutover: Arc::new(Notify::new()), + task_kind: TaskKind::SchemaSync, + started_at: SystemTime::now(), + }, + ); + let abort_handle = handle.abort_handle(); + spawn(async move { + select! { + _ = abort_rx => { + abort_handle.abort(); + } + result = handle => { + match result { + Ok(Ok(())) => {} + Ok(Err(err)) => error!("[task: {}] {}", id, err), + Err(err) => error!("[task: {}] {}", id, err), + } + } + } + AsyncTasks::get().tasks.remove(&id); + }); + } + + TaskType::CopyData(handle) => { + this.tasks.insert( + id, + TaskInfo { + abort_tx, + cutover: Arc::new(Notify::new()), + task_kind: TaskKind::CopyData, + started_at: SystemTime::now(), + }, + ); + let abort_handle = handle.abort_handle(); + spawn(async move { + select! { + _ = abort_rx => { + abort_handle.abort(); + } + result = handle => { + match result { + Ok(Ok(())) => {} + Ok(Err(err)) => error!("[task: {}] {}", id, err), + Err(err) => error!("[task: {}] {}", id, err), + } + } + } + AsyncTasks::get().tasks.remove(&id); + }); + } + + TaskType::Replication(mut waiter) => { + let cutover = Arc::new(Notify::new()); + + this.tasks.insert( + id, + TaskInfo { + abort_tx, + cutover: cutover.clone(), + task_kind: TaskKind::Replication, + started_at: SystemTime::now(), + }, + ); + + spawn(async move { + select! 
{ + _ = abort_rx => { + waiter.stop(); + } + + _ = cutover.notified() => { + if let Err(err) = waiter.cutover().await { + error!("[task: {}] {}", id, err); + } + } + + result = waiter.wait() => { + if let Err(err) = result { + error!("[task: {}] {}", id, err); + } + } + } + + AsyncTasks::get().tasks.remove(&id); + }); + } + } + + id + } + + pub fn remove(id: u64) -> Option { + // Dropping the sender signals abort to the waiting task + Self::get() + .tasks + .remove(&id) + .map(|(_, info)| info.task_kind) + } + + pub fn iter(&self) -> impl Iterator + '_ { + self.tasks + .iter() + .map(|e| (*e.key(), e.value().task_kind, e.value().started_at)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + use tokio::time::sleep; + + #[test] + fn test_task_kind_display() { + assert_eq!(TaskKind::SchemaSync.to_string(), "schema_sync"); + assert_eq!(TaskKind::CopyData.to_string(), "copy_data"); + assert_eq!(TaskKind::Replication.to_string(), "replication"); + } + + #[tokio::test] + async fn test_task_registration_and_removal() { + // Create a task that completes immediately + let handle = spawn(async { Ok::<(), Error>(()) }); + let id = Task::register(TaskType::SchemaSync(handle)); + + // Task should be visible briefly + // Give it a moment to register + sleep(Duration::from_millis(10)).await; + + // Try to remove it - it may already be gone if it completed + let result = AsyncTasks::remove(id); + // Either we removed it, or it already completed and removed itself + assert!(result.is_none() || result == Some(TaskKind::SchemaSync)); + } + + #[tokio::test] + async fn test_task_abort_via_remove() { + // Create a long-running task + let handle = spawn(async { + sleep(Duration::from_secs(60)).await; + Ok::<(), Error>(()) + }); + let id = Task::register(TaskType::CopyData(handle)); + + // Give it time to register + sleep(Duration::from_millis(10)).await; + + // Remove should abort the task + let result = AsyncTasks::remove(id); + assert_eq!(result, 
Some(TaskKind::CopyData)); + + // Task should be gone now + sleep(Duration::from_millis(50)).await; + let result = AsyncTasks::remove(id); + assert_eq!(result, None); + } + + #[tokio::test] + async fn test_task_iter() { + // Create multiple tasks + let handle1 = spawn(async { + sleep(Duration::from_secs(60)).await; + Ok::<(), Error>(()) + }); + let handle2 = spawn(async { + sleep(Duration::from_secs(60)).await; + Ok::<(), Error>(()) + }); + + let id1 = Task::register(TaskType::SchemaSync(handle1)); + let id2 = Task::register(TaskType::CopyData(handle2)); + + sleep(Duration::from_millis(10)).await; + + // Should see both tasks + let tasks: Vec<_> = AsyncTasks::get().iter().collect(); + let task_ids: Vec<_> = tasks.iter().map(|(id, _, _)| *id).collect(); + assert!(task_ids.contains(&id1)); + assert!(task_ids.contains(&id2)); + + // Verify task kinds + for (id, kind, _) in &tasks { + if *id == id1 { + assert_eq!(*kind, TaskKind::SchemaSync); + } else if *id == id2 { + assert_eq!(*kind, TaskKind::CopyData); + } + } + + // Cleanup + AsyncTasks::remove(id1); + AsyncTasks::remove(id2); + } + + #[tokio::test] + async fn test_task_auto_cleanup_on_completion() { + // Create a task that completes quickly + let handle = spawn(async { + sleep(Duration::from_millis(10)).await; + Ok::<(), Error>(()) + }); + let id = Task::register(TaskType::SchemaSync(handle)); + + // Wait for task to complete and cleanup + sleep(Duration::from_millis(100)).await; + + // Task should have removed itself + let result = AsyncTasks::remove(id); + assert_eq!(result, None); + } + + #[tokio::test] + async fn test_cutover_fails_without_replication_task() { + // Create a non-replication task + let handle = spawn(async { + sleep(Duration::from_secs(60)).await; + Ok::<(), Error>(()) + }); + let id = Task::register(TaskType::SchemaSync(handle)); + sleep(Duration::from_millis(10)).await; + + // Cutover should fail because there's no replication task + let result = AsyncTasks::cutover(); + 
assert!(matches!(result, Err(Error::NotReplication)), "{:?}", result); + + // Cleanup + AsyncTasks::remove(id); + } + + #[tokio::test] + async fn test_cutover_returns_not_found_when_no_replication_task() { + // Register several non-replication tasks + let mut task_ids = vec![]; + for _ in 0..5 { + let handle = spawn(async { + sleep(Duration::from_secs(60)).await; + Ok::<(), Error>(()) + }); + task_ids.push(Task::register(TaskType::SchemaSync(handle))); + } + + sleep(Duration::from_millis(10)).await; + + // With only non-replication tasks, cutover should return TaskNotFound + let result = AsyncTasks::cutover(); + assert!(matches!(result, Err(Error::NotReplication))); + + // Cleanup + for id in task_ids { + AsyncTasks::remove(id); + } + } +} diff --git a/pgdog/src/backend/replication/logical/error.rs b/pgdog/src/backend/replication/logical/error.rs index 4c7451cee..91c60c6c6 100644 --- a/pgdog/src/backend/replication/logical/error.rs +++ b/pgdog/src/backend/replication/logical/error.rs @@ -86,6 +86,39 @@ pub enum Error { #[error("router returned incorrect command")] IncorrectCommand, + + #[error("schema: {0}")] + SchemaSync(Box), + + #[error("schema isn't loaded")] + NoSchema, + + #[error("config wasn't updated with new cluster")] + NoNewCluster, + + #[error("tokio: {0}")] + JoinError(#[from] tokio::task::JoinError), + + #[error("copy for table {0} been aborted")] + CopyAborted(PublicationTable), + + #[error("data sync has been aborted")] + DataSyncAborted, + + #[error("replication has been aborted")] + ReplicationAborted, + + #[error("waiter has no publisher")] + NoPublisher, + + #[error("cutover abort timeout")] + AbortTimeout, + + #[error("task not found")] + TaskNotFound, + + #[error("task is not a replication task")] + NotReplication, } impl From for Error { @@ -93,3 +126,9 @@ impl From for Error { Self::PgError(Box::new(value)) } } + +impl From for Error { + fn from(value: crate::backend::schema::sync::error::Error) -> Self { + Self::SchemaSync(Box::new(value)) 
+ } +} diff --git a/pgdog/src/backend/replication/logical/mod.rs b/pgdog/src/backend/replication/logical/mod.rs index 34822bf3d..c587f6e3d 100644 --- a/pgdog/src/backend/replication/logical/mod.rs +++ b/pgdog/src/backend/replication/logical/mod.rs @@ -1,10 +1,22 @@ +pub mod admin; pub mod copy_statement; pub mod error; +pub mod orchestrator; pub mod publisher; +pub mod status; pub mod subscriber; +pub use admin::*; pub use copy_statement::CopyStatement; pub use error::Error; -pub use publisher::publisher_impl::Publisher; +pub use publisher::publisher_impl::{Publisher, Waiter}; pub use subscriber::{CopySubscriber, StreamSubscriber}; + +use crate::{ + backend::{ + databases::{databases, reload_from_existing}, + schema::sync::SyncState, + }, + config::config, +}; diff --git a/pgdog/src/backend/replication/logical/orchestrator.rs b/pgdog/src/backend/replication/logical/orchestrator.rs new file mode 100644 index 000000000..75175f303 --- /dev/null +++ b/pgdog/src/backend/replication/logical/orchestrator.rs @@ -0,0 +1,651 @@ +use crate::{ + backend::{ + databases::{cancel_all, cutover}, + maintenance_mode, + schema::sync::{pg_dump::PgDumpOutput, PgDump}, + Cluster, + }, + util::{format_bytes, human_duration, random_string}, +}; +use pgdog_config::{ConfigAndUsers, CutoverTimeoutAction}; +use std::{fmt::Display, sync::Arc, time::Duration}; +use tokio::{ + select, + sync::Mutex, + time::{interval, Instant}, +}; +use tracing::{info, warn}; + +use super::*; + +#[derive(Debug, Clone)] +pub(crate) struct Orchestrator { + source: Cluster, + destination: Cluster, + publication: String, + schema: Option, + publisher: Arc>, + replication_slot: String, +} + +impl Orchestrator { + /// Create new orchestrator. 
+ pub(crate) fn new( + source: &str, + destination: &str, + publication: &str, + replication_slot: Option, + ) -> Result { + let source = databases().schema_owner(source)?; + let destination = databases().schema_owner(destination)?; + + let replication_slot = replication_slot + .unwrap_or(format!("__pgdog_repl_{}", random_string(19).to_lowercase())); + + let mut orchestrator = Self { + source, + destination, + publication: publication.to_owned(), + schema: None, + publisher: Arc::new(Mutex::new(Publisher::default())), + replication_slot, + }; + + orchestrator.refresh_publisher(); + + Ok(orchestrator) + } + + fn refresh(&mut self) -> Result<(), Error> { + self.source = databases().schema_owner(&self.source.identifier().database)?; + self.destination = databases().schema_owner(&self.destination.identifier().database)?; + self.refresh_publisher(); + + Ok(()) + } + + fn refresh_publisher(&mut self) { + let publisher = Publisher::new( + &self.source, + &self.publication, + config().config.general.query_parser_engine, + self.replication_slot.clone(), + ); + self.publisher = Arc::new(Mutex::new(publisher)); + } + + pub(crate) fn replication_slot(&self) -> &str { + &self.replication_slot + } + + pub(crate) async fn load_schema(&mut self) -> Result<(), Error> { + let pg_dump = PgDump::new(&self.source, &self.publication); + let output = pg_dump.dump().await?; + self.schema = Some(output); + + Ok(()) + } + + /// Schema getter. + pub(crate) fn schema(&self) -> Result<&PgDumpOutput, Error> { + self.schema.as_ref().ok_or(Error::NoSchema) + } + + pub(crate) async fn schema_sync_pre(&mut self, ignore_errors: bool) -> Result<(), Error> { + let schema = self.schema.as_ref().ok_or(Error::NoSchema)?; + + schema + .restore(&self.destination, ignore_errors, SyncState::PreData) + .await?; + + // Schema changed on the destination. 
+ reload_from_existing()?; + + self.destination = databases().schema_owner(&self.destination.identifier().database)?; + self.source = databases().schema_owner(&self.source.identifier().database)?; + self.destination.wait_schema_loaded().await; + + self.refresh_publisher(); + + Ok(()) + } + + /// Remove any blockers for reverse replication. + pub(crate) async fn schema_sync_post_cutover( + &mut self, + ignore_errors: bool, + ) -> Result<(), Error> { + let schema = self.schema.as_ref().ok_or(Error::NoSchema)?; + + schema + .restore(&self.destination, ignore_errors, SyncState::PostCutover) + .await?; + + Ok(()) + } + + pub(crate) async fn data_sync(&self) -> Result<(), Error> { + let mut publisher = self.publisher.lock().await; + + // Run data sync for all tables in parallel using multiple replicas, + // if available. + publisher.data_sync(&self.destination).await?; + + Ok(()) + } + + /// Replicate forever. + /// + /// Useful for CLI interface only, since this will never stop. + /// + pub(crate) async fn replicate(&self) -> Result { + let mut publisher = self.publisher.lock().await; + let waiter = publisher.replicate(&self.destination).await?; + Ok(ReplicationWaiter { + orchestrator: self.clone(), + waiter, + config: config(), + }) + } + + /// Request replication stop. + pub(crate) async fn request_stop(&self) { + self.publisher.lock().await.request_stop(); + } + + /// Perform the entire flow in one swoop. + pub(crate) async fn replicate_and_cutover(&mut self) -> Result<(), Error> { + // Load the schema from source. + self.load_schema().await?; + + // Sync the schema to destination. + self.schema_sync_pre(true).await?; + + // Sync the data to destination. + self.data_sync().await?; + + // Create secondary indexes on destination. + self.schema_sync_post(true).await?; + + // Start replication to catch up and cutover once done. 
+ self.replicate().await?.cutover().await?; + + Ok(()) + } + + pub(crate) async fn schema_sync_post(&mut self, ignore_errors: bool) -> Result<(), Error> { + let schema = self.schema.as_ref().ok_or(Error::NoSchema)?; + + schema + .restore(&self.destination, ignore_errors, SyncState::PostData) + .await?; + + Ok(()) + } + + pub(crate) async fn schema_sync_cutover(&self, ignore_errors: bool) -> Result<(), Error> { + // Sequences won't be used in a sharded database. + let schema = self.schema.as_ref().ok_or(Error::NoSchema)?; + + schema + .restore(&self.destination, ignore_errors, SyncState::Cutover) + .await?; + + Ok(()) + } + + /// Get the largest replication lag out of all the shards. + async fn replication_lag(&self) -> u64 { + let lag = self.publisher.lock().await.replication_lag(); + lag.iter().map(|(_, lag)| *lag).max().unwrap_or_default() as u64 + } + + pub(crate) async fn cleanup(&mut self) -> Result<(), Error> { + let mut guard = self.publisher.lock().await; + guard.cleanup().await?; + + Ok(()) + } +} + +#[derive(Debug)] +pub struct ReplicationWaiter { + orchestrator: Orchestrator, + waiter: Waiter, + config: Arc, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum CutoverReason { + Lag, + Timeout, + LastTransaction, +} + +impl Display for CutoverReason { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Lag => write!(f, "lag"), + Self::Timeout => write!(f, "timeout"), + Self::LastTransaction => write!(f, "last_transaction"), + } + } +} + +impl ReplicationWaiter { + pub(crate) async fn wait(&mut self) -> Result<(), Error> { + self.waiter.wait().await + } + + pub(crate) fn stop(&self) { + self.waiter.stop(); + } + + /// Wait for replication to catch up. 
+ async fn wait_for_replication(&mut self) -> Result<(), Error> { + let traffic_stop = self.config.config.general.cutover_traffic_stop_threshold; + + info!( + "[cutover] started, waiting for traffic stop threshold={}", + format_bytes(traffic_stop) + ); + + // Check once a second how far we got. + let mut check = interval(Duration::from_secs(1)); + + loop { + select! { + _ = check.tick() => {} + + // In case replication breaks now. + res = self.waiter.wait() => { + res?; + } + } + + let lag = self.orchestrator.replication_lag().await; + + info!("[cutover] replication lag: {}", format_bytes(lag as u64)); + + // Time to go. + if lag <= traffic_stop { + info!( + "[cutover] stopping traffic, lag={}, threshold={}", + format_bytes(lag), + format_bytes(traffic_stop), + ); + + // Pause traffic. + maintenance_mode::start(); + + // Cancel any running queries. + cancel_all(&self.orchestrator.source.identifier().database).await?; + + break; + // TODO: wait for clients to all stop. + } + } + + Ok(()) + } + + async fn should_cutover(&self, elapsed: Duration) -> Option { + let cutover_timeout = Duration::from_millis(self.config.config.general.cutover_timeout); + let cutover_threshold = self.config.config.general.cutover_replication_lag_threshold; + let last_transaction_delay = + Duration::from_millis(self.config.config.general.cutover_last_transaction_delay); + + let lag = self.orchestrator.replication_lag().await; + let last_transaction = self.orchestrator.publisher.lock().await.last_transaction(); + let cutover_timeout_exceeded = elapsed >= cutover_timeout; + + if cutover_timeout_exceeded { + Some(CutoverReason::Timeout) + } else if lag <= cutover_threshold { + Some(CutoverReason::Lag) + } else if last_transaction.map_or(true, |t| t > last_transaction_delay) { + Some(CutoverReason::LastTransaction) + } else { + None + } + } + + /// Wait for cutover. 
+ async fn wait_for_cutover(&mut self) -> Result<(), Error> { + let cutover_threshold = self.config.config.general.cutover_replication_lag_threshold; + let last_transaction_delay = + Duration::from_millis(self.config.config.general.cutover_last_transaction_delay); + let cutover_timeout = Duration::from_millis(self.config.config.general.cutover_timeout); + let cutover_timeout_action = self.config.config.general.cutover_timeout_action; + + info!( + "[cutover] waiting for first cutover threshold: timeout={}, transaction={}, lag={}", + human_duration(cutover_timeout), + human_duration(last_transaction_delay), + format_bytes(cutover_threshold) + ); + + // Check more frequently. + let mut check = interval(Duration::from_millis(50)); + let mut log = interval(Duration::from_secs(1)); + // Abort clock starts now. + let start = Instant::now(); + + loop { + select! { + _ = check.tick() => {} + + _ = log.tick() => { + info!("[cutover] lag={}, last_transaction={}, timeout={}", + human_duration(cutover_timeout), + human_duration(last_transaction_delay), + format_bytes(cutover_threshold) + ); + } + + // In case replication breaks now. + res = self.waiter.wait() => { + res?; + } + } + + let elapsed = start.elapsed(); + let cutover_reason = self.should_cutover(elapsed).await; + match cutover_reason { + Some(CutoverReason::Timeout) => { + if cutover_timeout_action == CutoverTimeoutAction::Abort { + maintenance_mode::stop(); + warn!("[cutover] abort timeout reached, resuming traffic"); + return Err(Error::AbortTimeout); + } + } + + None => continue, + Some(reason) => { + info!("[cutover] performing cutover now, reason: {}", reason); + break; + } + } + } + + Ok(()) + } + + /// Perform traffic cutover between source and destination. + pub(crate) async fn cutover(&mut self) -> Result<(), Error> { + self.wait_for_replication().await?; + self.wait_for_cutover().await?; + + // We're going, point of no return. 
+ self.orchestrator.publisher.lock().await.request_stop(); + ok_or_abort!(self.waiter.wait().await); + ok_or_abort!(self.orchestrator.schema_sync_cutover(true).await); + // Traffic is about to go to the new cluster. + // If this fails, we'll resume traffic to the old cluster instead + // and the whole thing needs to be done from scratch. + ok_or_abort!( + cutover( + &self.orchestrator.source.identifier().database, + &self.orchestrator.destination.identifier().database, + ) + .await + ); + + // Source is now destination and vice versa. + ok_or_abort!(self.orchestrator.refresh()); + + info!("[cutover] setting up reverse replication"); + + // Fix any reverse replication blockers. + ok_or_abort!(self.orchestrator.schema_sync_post_cutover(true).await); + + // Create reverse replication in case we need to rollback. + let waiter = ok_or_abort!(self.orchestrator.replicate().await); + + // Let it run in the background. + AsyncTasks::insert(TaskType::Replication(waiter)); + + // It's not safe to resume traffic. + info!("[cutover] complete, resuming traffic"); + + // Point traffic to the other database and resume. + maintenance_mode::stop(); + + Ok(()) + } +} + +macro_rules! 
ok_or_abort { + ($expr:expr) => { + match $expr { + Ok(res) => res, + Err(err) => { + maintenance_mode::stop(); + return Err(err.into()); + } + } + }; +} + +use ok_or_abort; + +#[cfg(test)] +mod tests { + use super::*; + use crate::backend::pool::Cluster; + use pgdog_config::ConfigAndUsers; + use std::sync::Arc; + use tokio::time::Instant; + + impl Orchestrator { + fn new_test(config: &ConfigAndUsers) -> Self { + let cluster = Cluster::new_test(config); + Self { + source: cluster.clone(), + destination: cluster, + publication: "test_pub".to_owned(), + schema: None, + publisher: Arc::new(Mutex::new(Publisher::default())), + replication_slot: "test_slot".to_owned(), + } + } + } + + impl ReplicationWaiter { + fn new_test(orchestrator: Orchestrator, config: Arc) -> Self { + Self { + orchestrator, + waiter: Waiter::new_test(), + config, + } + } + } + + #[tokio::test] + async fn test_wait_for_replication_exits_when_lag_below_threshold() { + // Ensure maintenance mode is off at start + maintenance_mode::stop(); + assert!(!maintenance_mode::is_on()); + + let mut config = ConfigAndUsers::default(); + config.config.general.cutover_traffic_stop_threshold = 1000; + + let orchestrator = Orchestrator::new_test(&config); + + // Set replication lag below threshold + orchestrator + .publisher + .lock() + .await + .set_replication_lag(0, 500); + + let config = Arc::new(config); + let mut waiter = ReplicationWaiter::new_test(orchestrator, config); + + // Should exit immediately since lag (500) <= threshold (1000) + let result = waiter.wait_for_replication().await; + assert!(result.is_ok()); + + // Maintenance mode should be on after wait_for_replication + assert!(maintenance_mode::is_on()); + + // Clean up maintenance mode + maintenance_mode::stop(); + assert!(!maintenance_mode::is_on()); + } + + #[tokio::test] + async fn test_wait_for_cutover_exits_when_lag_below_threshold() { + let mut config = ConfigAndUsers::default(); + config.config.general.cutover_replication_lag_threshold = 
100; + config.config.general.cutover_timeout = 10000; + + let orchestrator = Orchestrator::new_test(&config); + + // Set replication lag below cutover threshold + orchestrator + .publisher + .lock() + .await + .set_replication_lag(0, 50); + + let config = Arc::new(config); + let mut waiter = ReplicationWaiter::new_test(orchestrator, config); + + // should_cutover returns Lag when lag is below threshold + let result = waiter.should_cutover(Duration::from_millis(100)).await; + assert_eq!(result, Some(CutoverReason::Lag)); + + // Should exit immediately since lag (50) <= threshold (100) + let result = waiter.wait_for_cutover().await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_wait_for_cutover_exits_when_last_transaction_old() { + let mut config = ConfigAndUsers::default(); + config.config.general.cutover_replication_lag_threshold = 10; + config.config.general.cutover_last_transaction_delay = 100; + config.config.general.cutover_timeout = 10000; + + let orchestrator = Orchestrator::new_test(&config); + + { + let publisher = orchestrator.publisher.lock().await; + // Set lag above threshold so we don't exit on that condition + publisher.set_replication_lag(0, 1000); + // Set last_transaction to a time in the past (> 100ms ago) + publisher.set_last_transaction(Some(Instant::now() - Duration::from_millis(200))); + } + + let config = Arc::new(config); + let mut waiter = ReplicationWaiter::new_test(orchestrator, config); + + // should_cutover returns LastTransaction when last transaction is old + let result = waiter.should_cutover(Duration::from_millis(100)).await; + assert_eq!(result, Some(CutoverReason::LastTransaction)); + + // Should exit because last_transaction (200ms) > threshold (100ms) + let result = waiter.wait_for_cutover().await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_should_cutover_when_no_transaction() { + let mut config = ConfigAndUsers::default(); + config.config.general.cutover_replication_lag_threshold = 10; 
+ config.config.general.cutover_last_transaction_delay = 100; + config.config.general.cutover_timeout = 10000; + + let orchestrator = Orchestrator::new_test(&config); + + { + let publisher = orchestrator.publisher.lock().await; + // Set lag above threshold so we don't exit on that condition + publisher.set_replication_lag(0, 1000); + // No transaction set (None) + publisher.set_last_transaction(None); + } + + let config = Arc::new(config); + let waiter = ReplicationWaiter::new_test(orchestrator, config); + + // should_cutover returns LastTransaction when there's no transaction + let result = waiter.should_cutover(Duration::from_millis(100)).await; + assert_eq!(result, Some(CutoverReason::LastTransaction)); + } + + #[tokio::test] + async fn test_should_not_cutover_when_lag_above_threshold_and_recent_transaction() { + let mut config = ConfigAndUsers::default(); + config.config.general.cutover_timeout = 10000; + config.config.general.cutover_replication_lag_threshold = 100; + config.config.general.cutover_last_transaction_delay = 500; + + let orchestrator = Orchestrator::new_test(&config); + + { + let publisher = orchestrator.publisher.lock().await; + // Lag above threshold + publisher.set_replication_lag(0, 1000); + // Recent transaction (50ms ago, threshold is 500ms) + publisher.set_last_transaction(Some(Instant::now() - Duration::from_millis(50))); + } + + let config = Arc::new(config); + let waiter = ReplicationWaiter::new_test(orchestrator, config); + + // Not timed out (100ms elapsed, timeout is 10000ms) + let result = waiter.should_cutover(Duration::from_millis(100)).await; + assert_eq!(result, None); + } + + #[tokio::test] + async fn test_should_not_cutover_when_timeout_not_reached() { + let mut config = ConfigAndUsers::default(); + config.config.general.cutover_timeout = 1000; + config.config.general.cutover_replication_lag_threshold = 10; + config.config.general.cutover_last_transaction_delay = 500; + + let orchestrator = Orchestrator::new_test(&config); + + 
{ + let publisher = orchestrator.publisher.lock().await; + // Lag above threshold + publisher.set_replication_lag(0, 1000); + // Recent transaction + publisher.set_last_transaction(Some(Instant::now() - Duration::from_millis(100))); + } + + let config = Arc::new(config); + let waiter = ReplicationWaiter::new_test(orchestrator, config); + + // Elapsed is 999ms, timeout is 1000ms - should not trigger timeout + let result = waiter.should_cutover(Duration::from_millis(999)).await; + assert_eq!(result, None); + } + + #[tokio::test] + async fn test_should_not_cutover_when_lag_just_above_threshold() { + let mut config = ConfigAndUsers::default(); + config.config.general.cutover_timeout = 10000; + config.config.general.cutover_replication_lag_threshold = 100; + config.config.general.cutover_last_transaction_delay = 500; + + let orchestrator = Orchestrator::new_test(&config); + + { + let publisher = orchestrator.publisher.lock().await; + // Lag just above threshold (101 > 100) + publisher.set_replication_lag(0, 101); + // Recent transaction + publisher.set_last_transaction(Some(Instant::now() - Duration::from_millis(50))); + } + + let config = Arc::new(config); + let waiter = ReplicationWaiter::new_test(orchestrator, config); + + let result = waiter.should_cutover(Duration::from_millis(100)).await; + assert_eq!(result, None); + } +} diff --git a/pgdog/src/backend/replication/logical/publisher/abort.rs b/pgdog/src/backend/replication/logical/publisher/abort.rs new file mode 100644 index 000000000..6b0e40c7f --- /dev/null +++ b/pgdog/src/backend/replication/logical/publisher/abort.rs @@ -0,0 +1,18 @@ +use tokio::sync::mpsc::UnboundedSender; + +use super::super::Error; +use super::*; + +pub struct AbortSignal { + tx: UnboundedSender>, +} + +impl AbortSignal { + pub fn new(tx: UnboundedSender>) -> Self { + Self { tx } + } + + pub async fn aborted(&self) { + self.tx.closed().await + } +} diff --git a/pgdog/src/backend/replication/logical/publisher/mod.rs 
b/pgdog/src/backend/replication/logical/publisher/mod.rs index ca57361eb..0185c8443 100644 --- a/pgdog/src/backend/replication/logical/publisher/mod.rs +++ b/pgdog/src/backend/replication/logical/publisher/mod.rs @@ -1,11 +1,13 @@ pub mod slot; pub use slot::*; +pub mod abort; pub mod copy; pub mod parallel_sync; pub mod progress; pub mod publisher_impl; pub mod queries; pub mod table; +pub use abort::*; pub use copy::*; pub use parallel_sync::ParallelSyncManager; pub use queries::*; diff --git a/pgdog/src/backend/replication/logical/publisher/parallel_sync.rs b/pgdog/src/backend/replication/logical/publisher/parallel_sync.rs index 809e17ce1..748c0095a 100644 --- a/pgdog/src/backend/replication/logical/publisher/parallel_sync.rs +++ b/pgdog/src/backend/replication/logical/publisher/parallel_sync.rs @@ -12,9 +12,15 @@ use tokio::{ Semaphore, }, }; +use tracing::info; use super::super::Error; -use crate::backend::{pool::Address, replication::publisher::Table, Cluster, Pool}; +use super::AbortSignal; +use crate::backend::{ + pool::Address, + replication::{publisher::Table, status::TableCopy}, + Cluster, Pool, +}; struct ParallelSync { table: Table, @@ -28,6 +34,9 @@ impl ParallelSync { // Run parallel sync. pub fn run(mut self) { spawn(async move { + // Record copy in queue before waiting for permit. + let tracker = TableCopy::new(&self.table.table.schema, &self.table.table.name); + // This won't acquire until we have at least 1 available permit. // Permit will be given back when this task completes. 
let _permit = self @@ -36,7 +45,17 @@ impl ParallelSync { .await .map_err(|_| Error::ParallelConnection)?; - let result = match self.table.data_sync(&self.addr, &self.dest).await { + if self.tx.is_closed() { + return Err(Error::DataSyncAborted); + } + + let abort = AbortSignal::new(self.tx.clone()); + + let result = match self + .table + .data_sync(&self.addr, &self.dest, abort, &tracker) + .await + { Ok(_) => Ok(self.table), Err(err) => Err(err), }; @@ -75,6 +94,11 @@ impl ParallelSyncManager { /// Run parallel table sync and return table LSNs when everything is done. pub async fn run(self) -> Result, Error> { + info!( + "starting parallel table copy using {} replicas", + self.replicas.len() + ); + let mut replicas_iter = self.replicas.iter(); // Loop through replicas, one at a time. // This works around Rust iterators not having a "rewind" function. @@ -103,8 +127,7 @@ impl ParallelSyncManager { drop(tx); while let Some(table) = rx.recv().await { - let table = table?; - tables.push(table); + tables.push(table?); } Ok(tables) diff --git a/pgdog/src/backend/replication/logical/publisher/publisher_impl.rs b/pgdog/src/backend/replication/logical/publisher/publisher_impl.rs index 79c97c9ad..02ebc93d1 100644 --- a/pgdog/src/backend/replication/logical/publisher/publisher_impl.rs +++ b/pgdog/src/backend/replication/logical/publisher/publisher_impl.rs @@ -1,9 +1,14 @@ use std::collections::HashMap; +use std::sync::Arc; use std::time::Duration; +use parking_lot::Mutex; use pgdog_config::QueryParserEngine; +use tokio::sync::Notify; +use tokio::task::JoinHandle; +use tokio::time::Instant; use tokio::{select, spawn, time::interval}; -use tracing::{debug, error, info}; +use tracing::{debug, info}; use super::super::{publisher::Table, Error}; use super::ReplicationSlot; @@ -18,7 +23,7 @@ use crate::backend::{pool::Request, Cluster}; use crate::config::Role; use crate::net::replication::ReplicationMeta; -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Publisher { /// 
Destination cluster. cluster: Cluster, @@ -30,6 +35,14 @@ pub struct Publisher { slots: HashMap, /// Query parser engine. query_parser_engine: QueryParserEngine, + /// Replication lag. + replication_lag: Arc>>, + /// Last transaction. + last_transaction: Arc>>, + /// Stop signal. + stop: Arc, + /// Slot name. + slot_name: String, } impl Publisher { @@ -37,6 +50,7 @@ impl Publisher { cluster: &Cluster, publication: &str, query_parser_engine: QueryParserEngine, + slot_name: String, ) -> Self { Self { cluster: cluster.clone(), @@ -44,9 +58,17 @@ impl Publisher { tables: HashMap::new(), slots: HashMap::new(), query_parser_engine, + replication_lag: Arc::new(Mutex::new(HashMap::new())), + stop: Arc::new(Notify::new()), + last_transaction: Arc::new(Mutex::new(None)), + slot_name, } } + pub fn replication_slot(&self) -> &str { + &self.slot_name + } + /// Synchronize tables for all shards. pub async fn sync_tables(&mut self) -> Result<(), Error> { for (number, shard) in self.cluster.shards().iter().enumerate() { @@ -68,12 +90,16 @@ impl Publisher { /// If you're doing a cross-shard transaction, parts of it can be lost. /// /// TODO: Add support for 2-phase commit. - async fn create_slots(&mut self, slot_name: Option) -> Result<(), Error> { + async fn create_slots(&mut self) -> Result<(), Error> { for (number, shard) in self.cluster.shards().iter().enumerate() { let addr = shard.primary(&Request::default()).await?.addr().clone(); - let mut slot = - ReplicationSlot::replication(&self.publication, &addr, slot_name.clone()); + let mut slot = ReplicationSlot::replication( + &self.publication, + &addr, + Some(self.slot_name.clone()), + number, + ); slot.create_slot().await?; self.slots.insert(number, slot); @@ -86,11 +112,7 @@ impl Publisher { /// /// This uses a dedicated replication slot which will survive crashes and reboots. /// N.B.: The slot needs to be manually dropped! 
- pub async fn replicate( - &mut self, - dest: &Cluster, - slot_name: Option, - ) -> Result<(), Error> { + pub async fn replicate(&mut self, dest: &Cluster) -> Result { // Replicate shards in parallel. let mut streams = vec![]; @@ -101,7 +123,7 @@ impl Publisher { // Create replication slots if we haven't already. if self.slots.is_empty() { - self.create_slots(slot_name).await?; + self.create_slots().await?; } for (number, _) in self.cluster.shards().iter().enumerate() { @@ -122,6 +144,9 @@ impl Publisher { stream.set_current_lsn(slot.lsn().lsn); let mut check_lag = interval(Duration::from_secs(1)); + let replication_lag = self.replication_lag.clone(); + let stop = self.stop.clone(); + let last_transaction = self.last_transaction.clone(); // Replicate in parallel. let handle = spawn(async move { @@ -130,6 +155,10 @@ impl Publisher { loop { select! { + _ = stop.notified() => { + slot.stop_replication().await?; + } + // This is cancellation-safe. replication_data = slot.replicate(Duration::MAX) => { let replication_data = replication_data?; @@ -151,6 +180,7 @@ impl Publisher { } else { if let Some(status_update) = stream.handle(data).await? { slot.status_update(status_update).await?; + *last_transaction.lock() = Some(Instant::now()); } stream.lsn() }; @@ -167,11 +197,10 @@ impl Publisher { _ = check_lag.tick() => { let lag = slot.replication_lag().await?; - info!( - "replication lag at {} bytes [{}]", - lag, - slot.server()?.addr() - ); + let mut guard = replication_lag.lock(); + guard.insert(number, lag); + + } } } @@ -182,34 +211,49 @@ impl Publisher { streams.push(handle); } - for (shard, stream) in streams.into_iter().enumerate() { - if let Err(err) = stream.await.unwrap() { - error!("error replicating from shard {}: {}", shard, err); - return Err(err); - } - } + Ok(Waiter { + streams, + stop: self.stop.clone(), + }) + } - Ok(()) + /// Request the publisher to stop replication. 
+ pub fn request_stop(&self) { + self.stop.notify_one(); + } + + /// Get current replication lag. + pub fn replication_lag(&self) -> HashMap { + self.replication_lag.lock().clone() + } + + /// Get how long ago last transaction was committed. + pub fn last_transaction(&self) -> Option { + self.last_transaction + .lock() + .clone() + .map(|last| last.elapsed()) } /// Sync data from all tables in a publication from one shard to N shards, /// re-sharding the cluster in the process. /// /// TODO: Parallelize shard syncs. - pub async fn data_sync( - &mut self, - dest: &Cluster, - replicate: bool, - slot_name: Option, - ) -> Result<(), Error> { + pub async fn data_sync(&mut self, dest: &Cluster) -> Result<(), Error> { // Create replication slots. - self.create_slots(slot_name.clone()).await?; + self.create_slots().await?; for (number, shard) in self.cluster.shards().iter().enumerate() { let mut primary = shard.primary(&Request::default()).await?; let tables = Table::load(&self.publication, &mut primary, self.query_parser_engine).await?; + info!( + "table sync starting for {} tables, shard={}", + tables.len(), + number + ); + let include_primary = !shard.has_replicas(); let resharding_only = shard .pools() @@ -245,11 +289,56 @@ impl Publisher { self.tables.insert(number, tables); } - if replicate { - // Replicate changes. - self.replicate(dest, slot_name).await?; + Ok(()) + } + + /// Cleanup after replication. 
+ pub async fn cleanup(&mut self) -> Result<(), Error> { + for slot in self.slots.values_mut() { + slot.drop_slot().await?; } Ok(()) } } + +#[cfg(test)] +impl Publisher { + pub fn set_replication_lag(&self, shard: usize, lag: i64) { + self.replication_lag.lock().insert(shard, lag); + } + + pub fn set_last_transaction(&self, instant: Option) { + *self.last_transaction.lock() = instant; + } +} + +#[derive(Debug)] +pub struct Waiter { + streams: Vec>>, + stop: Arc, +} + +impl Waiter { + pub fn stop(&self) { + self.stop.notify_one(); + } + + pub async fn wait(&mut self) -> Result<(), Error> { + for stream in &mut self.streams { + stream.await??; + } + + Ok(()) + } +} + +#[cfg(test)] +impl Waiter { + pub fn new_test() -> Self { + Self { + streams: vec![], + stop: Arc::new(Notify::new()), + } + } +} diff --git a/pgdog/src/backend/replication/logical/publisher/queries.rs b/pgdog/src/backend/replication/logical/publisher/queries.rs index fef0e5c9b..724ed664f 100644 --- a/pgdog/src/backend/replication/logical/publisher/queries.rs +++ b/pgdog/src/backend/replication/logical/publisher/queries.rs @@ -33,7 +33,7 @@ LEFT JOIN pg_namespace pn ON pn.oid = p.relnamespace ORDER BY n.nspname, c.relname;"; /// Table included in a publication. -#[derive(Debug, Clone, PartialEq, Default)] +#[derive(Debug, Clone, PartialEq, Default, Eq, Hash)] pub struct PublicationTable { pub schema: String, pub name: String, @@ -99,7 +99,7 @@ INNER JOIN pg_catalog.pg_namespace n ON (c.relnamespace = n.oid) WHERE n.nspname = $1 AND c.relname = $2"; /// Identifies the columns part of the replica identity for a table. 
-#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct ReplicaIdentity { pub oid: i32, pub identity: String, @@ -146,7 +146,7 @@ FROM ON (i.indexrelid = pg_get_replica_identity_index($1)) WHERE a.attnum > 0::pg_catalog.int2 AND NOT a.attisdropped AND a.attgenerated = '' AND a.attrelid = $2 ORDER BY a.attnum"; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct PublicationTableColumn { pub oid: i32, pub name: String, diff --git a/pgdog/src/backend/replication/logical/publisher/slot.rs b/pgdog/src/backend/replication/logical/publisher/slot.rs index afbe51832..0c095e9e5 100644 --- a/pgdog/src/backend/replication/logical/publisher/slot.rs +++ b/pgdog/src/backend/replication/logical/publisher/slot.rs @@ -1,3 +1,4 @@ +use super::super::status::ReplicationSlot as ReplicationSlotTracker; use super::super::Error; use crate::{ backend::{self, pool::Address, ConnectReason, Server, ServerOptions}, @@ -9,7 +10,7 @@ use crate::{ }; use std::{fmt::Display, str::FromStr, time::Duration}; use tokio::time::timeout; -use tracing::{debug, trace}; +use tracing::{debug, info, trace, warn}; pub use pgdog_stats::Lsn; @@ -47,12 +48,19 @@ pub struct ReplicationSlot { server: Option, kind: SlotKind, server_meta: Option, + tracker: Option, } impl ReplicationSlot { /// Create replication slot used for streaming the WAL. 
- pub fn replication(publication: &str, address: &Address, name: Option) -> Self { - let name = name.unwrap_or(format!("__pgdog_repl_{}", random_string(19).to_lowercase())); + pub fn replication( + publication: &str, + address: &Address, + name: Option, + shard: usize, + ) -> Self { + let name = name.unwrap_or(format!("__pgdog_repl_{}", random_string(18).to_lowercase())); + let name = format!("{}_{}", name, shard); Self { address: address.clone(), @@ -64,6 +72,7 @@ impl ReplicationSlot { server: None, kind: SlotKind::Replication, server_meta: None, + tracker: None, } } @@ -81,6 +90,7 @@ impl ReplicationSlot { server: None, kind: SlotKind::DataSync, server_meta: None, + tracker: None, } } @@ -125,8 +135,15 @@ impl ReplicationSlot { ); let mut lag: Vec = self.server_meta().await?.fetch_all(&query).await?; - lag.pop() - .ok_or(Error::MissingReplicationSlot(self.name.clone())) + let lag = lag + .pop() + .ok_or(Error::MissingReplicationSlot(self.name.clone()))?; + + if let Some(ref tracker) = self.tracker { + tracker.update_lag(lag); + } + + Ok(lag) } pub fn server(&mut self) -> Result<&mut Server, Error> { @@ -139,6 +156,11 @@ impl ReplicationSlot { self.connect().await?; } + info!( + "creating replication slot \"{}\" [{}]", + self.name, self.address + ); + if self.kind == SlotKind::DataSync { self.server()? 
.execute("BEGIN READ ONLY ISOLATION LEVEL REPEATABLE READ") @@ -180,8 +202,14 @@ impl ReplicationSlot { .ok_or(Error::MissingData)?; let lsn = Lsn::from_str(&lsn)?; self.lsn = lsn; + self.tracker = Some(ReplicationSlotTracker::new( + &self.name, + &self.lsn, + self.dropped, + &self.address, + )); - debug!( + info!( "replication slot \"{}\" at lsn {} created [{}]", self.name, self.lsn, self.address, ); @@ -201,8 +229,14 @@ impl ReplicationSlot { { let lsn = Lsn::from_str(&lsn)?; self.lsn = lsn; - - debug!( + self.tracker = Some(ReplicationSlotTracker::new( + &self.name, + &self.lsn, + self.dropped, + &self.address, + )); + + info!( "using existing replication slot \"{}\" at lsn {} [{}]", self.name, self.lsn, self.address, ); @@ -226,11 +260,12 @@ impl ReplicationSlot { let drop_slot = self.drop_slot_query(true); self.server()?.execute(&drop_slot).await?; - debug!( + warn!( "replication slot \"{}\" dropped [{}]", self.name, self.address ); self.dropped = true; + self.tracker.take().map(|slot| slot.dropped()); Ok(()) } @@ -306,6 +341,10 @@ impl ReplicationSlot { self.server()?.addr() ); + self.tracker + .as_ref() + .map(|tracker| tracker.update_lsn(&Lsn::from_i64(status_update.last_flushed))); + self.server()? .send_one(&status_update.wrapped()?.into()) .await?; @@ -326,6 +365,11 @@ impl ReplicationSlot { pub fn lsn(&self) -> Lsn { self.lsn } + + /// Slot name. 
+ pub fn name(&self) -> &str { + &self.name + } } #[derive(Debug, Clone)] @@ -406,6 +450,7 @@ mod test { "test_slot_replication", addr, Some("test_slot_replication".into()), + 0, ); let _ = slot.create_slot().await.unwrap(); slot.connect().await.unwrap(); diff --git a/pgdog/src/backend/replication/logical/publisher/table.rs b/pgdog/src/backend/replication/logical/publisher/table.rs index 7152ebaf7..29f5db307 100644 --- a/pgdog/src/backend/replication/logical/publisher/table.rs +++ b/pgdog/src/backend/replication/logical/publisher/table.rs @@ -3,21 +3,27 @@ use std::time::Duration; use pgdog_config::QueryParserEngine; +use tokio::select; +use tracing::error; use crate::backend::pool::Address; use crate::backend::replication::publisher::progress::Progress; use crate::backend::replication::publisher::Lsn; +use crate::backend::replication::status::TableCopy; use crate::backend::{Cluster, Server}; use crate::config::config; use crate::net::replication::StatusUpdate; +use crate::util::escape_identifier; use super::super::{subscriber::CopySubscriber, Error}; -use super::{Copy, PublicationTable, PublicationTableColumn, ReplicaIdentity, ReplicationSlot}; +use super::{ + AbortSignal, Copy, PublicationTable, PublicationTableColumn, ReplicaIdentity, ReplicationSlot, +}; use tracing::info; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Table { /// Name of the table publication. 
pub publication: String, @@ -74,7 +80,7 @@ impl Table { "({})", self.columns .iter() - .map(|c| format!("\"{}\"", c.name.as_str())) + .map(|c| format!("\"{}\"", escape_identifier(&c.name))) .collect::>() .join(", ") ); @@ -93,14 +99,14 @@ impl Table { self.columns .iter() .filter(|c| c.identity) - .map(|c| format!("\"{}\"", c.name.as_str())) + .map(|c| format!("\"{}\"", escape_identifier(&c.name))) .collect::>() .join(", "), self.columns .iter() .enumerate() .filter(|(_, c)| !c.identity) - .map(|(i, c)| format!("\"{}\" = ${}", c.name, i + 1)) + .map(|(i, c)| format!("\"{}\" = ${}", escape_identifier(&c.name), i + 1)) .collect::>() .join(", ") ) @@ -110,8 +116,8 @@ impl Table { format!( "INSERT INTO \"{}\".\"{}\" {} {} {}", - self.table.destination_schema(), - self.table.destination_name(), + escape_identifier(self.table.destination_schema()), + escape_identifier(self.table.destination_name()), names, values, on_conflict @@ -125,7 +131,7 @@ impl Table { .iter() .enumerate() .filter(|(_, c)| !c.identity) - .map(|(i, c)| format!("\"{}\" = ${}", c.name, i + 1)) + .map(|(i, c)| format!("\"{}\" = ${}", escape_identifier(&c.name), i + 1)) .collect::>() .join(", "); @@ -134,14 +140,14 @@ impl Table { .iter() .enumerate() .filter(|(_, c)| c.identity) - .map(|(i, c)| format!("\"{}\" = ${}", c.name, i + 1)) + .map(|(i, c)| format!("\"{}\" = ${}", escape_identifier(&c.name), i + 1)) .collect::>() .join(" AND "); format!( "UPDATE \"{}\".\"{}\" SET {} WHERE {}", - self.table.destination_schema(), - self.table.destination_name(), + escape_identifier(self.table.destination_schema()), + escape_identifier(self.table.destination_name()), set_clause, where_clause ) @@ -154,14 +160,14 @@ impl Table { .iter() .enumerate() .filter(|(_, c)| c.identity) - .map(|(i, c)| format!("\"{}\" = ${}", c.name, i + 1)) + .map(|(i, c)| format!("\"{}\" = ${}", escape_identifier(&c.name), i + 1)) .collect::>() .join(" AND "); format!( "DELETE FROM \"{}\".\"{}\" WHERE {}", - 
self.table.destination_schema(), - self.table.destination_name(), + escape_identifier(self.table.destination_schema()), + escape_identifier(self.table.destination_name()), where_clause ) } @@ -178,7 +184,13 @@ impl Table { Ok(()) } - pub async fn data_sync(&mut self, source: &Address, dest: &Cluster) -> Result { + pub async fn data_sync( + &mut self, + source: &Address, + dest: &Cluster, + abort: AbortSignal, + tracker: &TableCopy, + ) -> Result { info!( "data sync for \"{}\".\"{}\" started [{}]", self.table.schema, self.table.name, source @@ -189,6 +201,8 @@ impl Table { // Subscriber uses COPY [...] FROM STDIN. let copy = Copy::new(self, config().config.general.resharding_copy_format); + tracker.update_sql(©.statement().copy_out()); + // Create new standalone connection for the copy. // let mut server = Server::connect(source, ServerOptions::new_replication()).await?; let mut copy_sub = CopySubscriber::new(copy.statement(), dest, self.query_parser_engine)?; @@ -208,8 +222,19 @@ impl Table { let progress = Progress::new_data_sync(&self.table); while let Some(data_row) = copy.data(slot.server()?).await? { - copy_sub.copy_data(data_row).await?; - progress.update(copy_sub.bytes_sharded(), slot.lsn().lsn); + select! 
{ + _ = abort.aborted() => { + error!("aborting data sync for table {}", self.table); + + return Err(Error::CopyAborted(self.table.clone())) + }, + + result = copy_sub.copy_data(data_row) => { + let (rows, bytes) = result?; + progress.update(copy_sub.bytes_sharded(), slot.lsn().lsn); + tracker.update_progress(bytes, rows); + } + } } copy_sub.copy_done().await?; @@ -246,6 +271,109 @@ mod test { use super::*; + fn make_table(columns: Vec<(&str, bool)>) -> Table { + Table { + publication: "test".to_string(), + table: PublicationTable { + schema: "public".to_string(), + name: "test_table".to_string(), + attributes: "".to_string(), + parent_schema: "".to_string(), + parent_name: "".to_string(), + }, + identity: ReplicaIdentity { + oid: 1, + identity: "".to_string(), + kind: "".to_string(), + }, + columns: columns + .into_iter() + .map(|(name, identity)| PublicationTableColumn { + oid: 1, + name: name.to_string(), + type_oid: 23, + identity, + }) + .collect(), + lsn: Lsn::default(), + query_parser_engine: QueryParserEngine::default(), + } + } + + #[test] + fn test_sql_generation_simple() { + let table = make_table(vec![("id", true), ("name", false), ("value", false)]); + + let insert = table.insert(false); + assert!(pg_query::parse(&insert).is_ok(), "insert: {}", insert); + + let upsert = table.insert(true); + assert!(pg_query::parse(&upsert).is_ok(), "upsert: {}", upsert); + + let update = table.update(); + assert!(pg_query::parse(&update).is_ok(), "update: {}", update); + + let delete = table.delete(); + assert!(pg_query::parse(&delete).is_ok(), "delete: {}", delete); + } + + #[test] + fn test_sql_generation_quoted_column() { + let table = make_table(vec![("id", true), ("has\"quote", false), ("normal", false)]); + + let insert = table.insert(false); + assert!(pg_query::parse(&insert).is_ok(), "insert: {}", insert); + + let upsert = table.insert(true); + assert!(pg_query::parse(&upsert).is_ok(), "upsert: {}", upsert); + + let update = table.update(); + 
assert!(pg_query::parse(&update).is_ok(), "update: {}", update); + + let delete = table.delete(); + assert!(pg_query::parse(&delete).is_ok(), "delete: {}", delete); + } + + #[test] + fn test_sql_generation_special_chars() { + let table = make_table(vec![ + ("id", true), + ("col with spaces", false), + ("UPPER", false), + ]); + + let insert = table.insert(false); + assert!(pg_query::parse(&insert).is_ok(), "insert: {}", insert); + + let upsert = table.insert(true); + assert!(pg_query::parse(&upsert).is_ok(), "upsert: {}", upsert); + + let update = table.update(); + assert!(pg_query::parse(&update).is_ok(), "update: {}", update); + + let delete = table.delete(); + assert!(pg_query::parse(&delete).is_ok(), "delete: {}", delete); + } + + #[test] + fn test_sql_generation_quoted_table_name() { + let mut table = make_table(vec![("id", true), ("value", false)]); + table.table.name = "table\"with\"quotes".to_string(); + table.table.schema = "schema\"quote".to_string(); + + let insert = table.insert(false); + assert!(pg_query::parse(&insert).is_ok(), "insert: {}", insert); + + let upsert = table.insert(true); + assert!(pg_query::parse(&upsert).is_ok(), "upsert: {}", upsert); + + let update = table.update(); + assert!(pg_query::parse(&update).is_ok(), "update: {}", update); + + let delete = table.delete(); + assert!(pg_query::parse(&delete).is_ok(), "delete: {}", delete); + } + #[tokio::test] async fn test_publication() { crate::logger(); diff --git a/pgdog/src/backend/replication/logical/status.rs b/pgdog/src/backend/replication/logical/status.rs new file mode 100644 index 000000000..1ef701365 --- /dev/null +++ b/pgdog/src/backend/replication/logical/status.rs @@ -0,0 +1,302 @@ +use std::{ops::Deref, sync::Arc, time::SystemTime}; + +use dashmap::{DashMap, DashSet}; +use once_cell::sync::Lazy; +use pgdog_stats::Lsn; + +use crate::backend::{ + databases::User, + pool::Address, + schema::sync::{Statement, SyncState}, + Cluster, +}; + +/// Status of table copies. 
+static COPIES: Lazy<TableCopies> = Lazy::new(|| TableCopies::default());
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct TableCopy {
+    pub(crate) schema: String,
+    pub(crate) table: String,
+}
+
+impl TableCopy {
+    pub(crate) fn new(schema: &str, table: &str) -> Self {
+        let copy = Self {
+            schema: schema.to_owned(),
+            table: table.to_owned(),
+        };
+        TableCopies::get().insert(
+            copy.clone(),
+            TableCopyState {
+                last_update: SystemTime::now(),
+                ..Default::default()
+            },
+        );
+        copy
+    }
+
+    pub(crate) fn update_progress(&self, bytes: usize, rows: usize) {
+        if let Some(mut state) = TableCopies::get().get_mut(self) {
+            state.bytes += bytes;
+            state.rows += rows;
+            let elapsed = SystemTime::now()
+                .duration_since(state.last_update)
+                .unwrap_or_default()
+                .as_secs();
+            if elapsed > 0 {
+                state.bytes_per_sec = state.bytes / elapsed as usize;
+            }
+        }
+    }
+
+    pub(crate) fn update_sql(&self, sql: &str) {
+        if let Some(mut state) = TableCopies::get().get_mut(self) {
+            state.sql = sql.to_owned();
+        }
+    }
+}
+
+impl Drop for TableCopy {
+    fn drop(&mut self) {
+        COPIES.copies.remove(self);
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TableCopyState {
+    pub(crate) sql: String,
+    pub(crate) rows: usize,
+    pub(crate) bytes: usize,
+    pub(crate) bytes_per_sec: usize,
+    pub(crate) last_update: SystemTime,
+}
+
+impl Default for TableCopyState {
+    fn default() -> Self {
+        Self {
+            sql: String::default(),
+            rows: 0,
+            bytes: 0,
+            bytes_per_sec: 0,
+            last_update: SystemTime::now(),
+        }
+    }
+}
+
+#[derive(Default, Clone)]
+pub struct TableCopies {
+    copies: Arc<DashMap<TableCopy, TableCopyState>>,
+}
+
+impl Deref for TableCopies {
+    type Target = DashMap<TableCopy, TableCopyState>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.copies
+    }
+}
+
+impl TableCopies {
+    pub(crate) fn get() -> Self {
+        COPIES.clone()
+    }
+}
+
+static REPLICATION_SLOTS: Lazy<ReplicationSlots> = Lazy::new(ReplicationSlots::default);
+
+/// Replication slot.
+#[derive(Debug, Clone)] +pub struct ReplicationSlot { + pub(crate) name: String, + pub(crate) lsn: Lsn, + pub(crate) lag: i64, + pub(crate) copy_data: bool, + pub(crate) address: Address, + pub(crate) last_transaction: Option, +} + +impl ReplicationSlot { + pub(crate) fn new(name: &str, lsn: &Lsn, copy_data: bool, address: &Address) -> Self { + let slot = Self { + name: name.to_owned(), + lsn: lsn.clone(), + copy_data, + lag: 0, + address: address.clone(), + last_transaction: None, + }; + + ReplicationSlots::get().insert(name.to_owned(), slot.clone()); + + slot + } + + pub(crate) fn update_lsn(&self, lsn: &Lsn) { + if let Some(mut slot) = ReplicationSlots::get().get_mut(&self.name) { + slot.lsn = lsn.clone(); + slot.last_transaction = Some(SystemTime::now()); + } + } + + pub(crate) fn update_lag(&self, lag: i64) { + if let Some(mut slot) = ReplicationSlots::get().get_mut(&self.name) { + slot.lag = lag; + } + } + + pub(crate) fn dropped(&self) { + ReplicationSlots::get().remove(&self.name); + } +} + +impl Drop for ReplicationSlot { + fn drop(&mut self) { + if self.copy_data { + self.dropped(); + } + } +} + +#[derive(Default, Clone, Debug)] +pub struct ReplicationSlots { + slots: Arc>, +} + +impl ReplicationSlots { + pub(crate) fn get() -> Self { + REPLICATION_SLOTS.clone() + } +} + +impl Deref for ReplicationSlots { + type Target = Arc>; + + fn deref(&self) -> &Self::Target { + &self.slots + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum StatementKind { + Table, + Index, + Statement, +} + +impl std::fmt::Display for StatementKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Table => write!(f, "table"), + Self::Index => write!(f, "index"), + Self::Statement => write!(f, "statement"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Hash, Eq)] +pub struct SchemaStatement { + pub(crate) user: User, + pub(crate) shard: usize, + pub(crate) sql: String, + pub(crate) kind: StatementKind, + pub(crate) 
sync_state: SyncState, + pub(crate) started_at: SystemTime, + pub(crate) table_schema: Option, + pub(crate) table_name: Option, +} + +impl SchemaStatement { + pub(crate) fn new( + cluster: &Cluster, + stmt: &Statement<'_>, + shard: usize, + sync_state: SyncState, + ) -> Self { + let user = cluster.identifier().deref().clone(); + + let stmt = match stmt { + Statement::Index { table, sql, .. } => Self { + user, + shard, + sql: sql.clone(), + kind: StatementKind::Index, + sync_state, + started_at: SystemTime::now(), + table_schema: table.schema.map(|s| s.to_string()), + table_name: Some(table.name.to_owned()), + }, + Statement::Table { table, sql } => Self { + user, + shard, + sql: sql.clone(), + kind: StatementKind::Table, + sync_state, + started_at: SystemTime::now(), + table_schema: table.schema.map(|s| s.to_string()), + table_name: Some(table.name.to_owned()), + }, + Statement::Other { sql, .. } => Self { + user, + shard, + sql: sql.clone(), + kind: StatementKind::Statement, + sync_state, + started_at: SystemTime::now(), + table_schema: None, + table_name: None, + }, + Statement::SequenceOwner { sql, .. } => Self { + user, + shard, + sql: sql.to_string(), + kind: StatementKind::Statement, + sync_state, + started_at: SystemTime::now(), + table_schema: None, + table_name: None, + }, + Statement::SequenceSetMax { sql, .. 
} => Self { + user, + shard, + sql: sql.clone(), + kind: StatementKind::Statement, + sync_state, + started_at: SystemTime::now(), + table_schema: None, + table_name: None, + }, + }; + + SchemaStatements::get().insert(stmt.clone()); + + stmt + } +} + +impl Drop for SchemaStatement { + fn drop(&mut self) { + SchemaStatements::get().remove(self); + } +} + +#[derive(Default, Debug, Clone)] +pub struct SchemaStatements { + stmts: Arc>, +} + +impl SchemaStatements { + pub(crate) fn get() -> Self { + SCHEMA_STATEMENTS.clone() + } +} + +impl Deref for SchemaStatements { + type Target = Arc>; + + fn deref(&self) -> &Self::Target { + &self.stmts + } +} + +static SCHEMA_STATEMENTS: Lazy = Lazy::new(SchemaStatements::default); diff --git a/pgdog/src/backend/replication/logical/subscriber/copy.rs b/pgdog/src/backend/replication/logical/subscriber/copy.rs index 28741b152..5c22ee4da 100644 --- a/pgdog/src/backend/replication/logical/subscriber/copy.rs +++ b/pgdog/src/backend/replication/logical/subscriber/copy.rs @@ -160,19 +160,22 @@ impl CopySubscriber { } /// Send data to subscriber, buffered. - pub async fn copy_data(&mut self, data: CopyData) -> Result<(), Error> { + pub async fn copy_data(&mut self, data: CopyData) -> Result<(usize, usize), Error> { self.buffer.push(data); if self.buffer.len() == BUFFER_SIZE { - self.flush().await? + return self.flush().await; } - Ok(()) + Ok((0, 0)) } - async fn flush(&mut self) -> Result<(), Error> { + async fn flush(&mut self) -> Result<(usize, usize), Error> { let result = self.copy.shard(&self.buffer)?; self.buffer.clear(); + let rows = result.len(); + let bytes = result.iter().map(|row| row.len()).sum::(); + for row in &result { for (shard, server) in self.connections.iter_mut().enumerate() { match row.shard() { @@ -193,7 +196,7 @@ impl CopySubscriber { self.bytes_sharded += result.iter().map(|c| c.len()).sum::(); - Ok(()) + Ok((rows, bytes)) } /// Total amount of bytes shaded. 
diff --git a/pgdog/src/backend/schema/sync/mod.rs b/pgdog/src/backend/schema/sync/mod.rs index 29c97f673..56cf942f5 100644 --- a/pgdog/src/backend/schema/sync/mod.rs +++ b/pgdog/src/backend/schema/sync/mod.rs @@ -3,5 +3,6 @@ pub mod error; pub mod pg_dump; pub mod progress; +pub use config::ShardConfig; pub use error::Error; -pub use pg_dump::Statement; +pub use pg_dump::{PgDump, Statement, SyncState}; diff --git a/pgdog/src/backend/schema/sync/pg_dump.rs b/pgdog/src/backend/schema/sync/pg_dump.rs index 9fd639bc9..ba6f80807 100644 --- a/pgdog/src/backend/schema/sync/pg_dump.rs +++ b/pgdog/src/backend/schema/sync/pg_dump.rs @@ -17,7 +17,12 @@ use tracing::{info, trace, warn}; use super::{progress::Progress, Error}; use crate::{ - backend::{self, pool::Request, replication::publisher::PublicationTable, Cluster}, + backend::{ + self, + pool::Request, + replication::{publisher::PublicationTable, status::SchemaStatement}, + Cluster, + }, config::config, frontend::router::parser::{sequence::Sequence, Column, Table}, }; @@ -220,11 +225,23 @@ pub struct PgDumpOutput { original: String, } -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum SyncState { PreData, PostData, Cutover, + PostCutover, +} + +impl std::fmt::Display for SyncState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::PreData => write!(f, "pre_data"), + Self::PostData => write!(f, "post_data"), + Self::Cutover => write!(f, "cutover"), + Self::PostCutover => write!(f, "post_cutover"), + } + } } #[derive(Debug)] @@ -809,6 +826,10 @@ impl PgDumpOutput { AlterTableType::AtAddIdentity => { if state == SyncState::Cutover { + // Add identity constraint during cutover + result.push(original.into()); + + // Set sequence to max(column) value if let Some(ref node) = cmd.def { if let Some(NodeEnum::Constraint( ref constraint, @@ -863,9 +884,19 @@ impl PgDumpOutput { } } } - } else if state == SyncState::PreData { - 
result.push(original.into()); + } else if state == SyncState::PostCutover { + // Drop identity constraint after cutover + if let Some(ref relation) = stmt.relation { + let sql = format!( + "ALTER TABLE \"{}\".\"{}\" ALTER COLUMN \"{}\" DROP IDENTITY IF EXISTS", + crate::util::escape_identifier(schema_name(relation)), + crate::util::escape_identifier(&relation.relname), + crate::util::escape_identifier(&cmd.name) + ); + result.push(sql.into()); + } } + // Skip identity constraint in PreData - it will be added in Cutover } // AlterTableType::AtChangeOwner => { // continue; // Don't change owners, for now. @@ -889,10 +920,34 @@ impl PgDumpOutput { } } + NodeEnum::CreatePublicationStmt(stmt) => { + if state == SyncState::PreData { + // DROP first for idempotency + result.push(Statement::Other { + sql: format!( + "DROP PUBLICATION IF EXISTS \"{}\"", + crate::util::escape_identifier(&stmt.pubname) + ), + idempotent: true, + }); + result.push(Statement::Other { + sql: original.to_string(), + idempotent: false, + }); + } + } + + NodeEnum::AlterPublicationStmt(_) => { + if state == SyncState::PreData { + result.push(Statement::Other { + sql: original.to_string(), + idempotent: false, + }); + } + } + // Skip these. 
- NodeEnum::CreatePublicationStmt(_) - | NodeEnum::CreateSubscriptionStmt(_) - | NodeEnum::AlterPublicationStmt(_) + NodeEnum::CreateSubscriptionStmt(_) | NodeEnum::AlterSubscriptionStmt(_) => (), NodeEnum::AlterSeqStmt(stmt) => { @@ -938,6 +993,19 @@ impl PgDumpOutput { let table = stmt.relation.as_ref().map(Table::from).unwrap_or_default(); + let index_schema = stmt + .relation + .as_ref() + .map(|r| schema_name(r)) + .unwrap_or("public"); + result.push(Statement::Other { + sql: format!( + "DROP INDEX IF EXISTS \"{}\".\"{}\"", + index_schema, stmt.idxname + ), + idempotent: true, + }); + result.push(Statement::Index { table, name: stmt.idxname.as_str(), @@ -1042,6 +1110,8 @@ impl PgDumpOutput { for stmt in &stmts { progress.next(stmt); + let _tracker = SchemaStatement::new(dest, stmt, num, state); + if let Err(err) = primary.execute(stmt.deref()).await { if let backend::Error::ExecutionError(ref err) = err { let code = &err.code; @@ -1158,9 +1228,23 @@ ALTER TABLE ONLY public.users stmts: parse.protobuf, original: q.to_string(), }; + + // Identity constraints should be skipped in PreData + let statements = output.statements(SyncState::PreData).unwrap(); + assert!(statements.is_empty()); + + // Identity constraints should be added in Cutover let statements = output.statements(SyncState::Cutover).unwrap(); - match statements.first() { - Some(Statement::SequenceSetMax { sequence, sql }) => { + assert_eq!(statements.len(), 2); + + // First statement is the original identity constraint + assert!(statements[0] + .deref() + .contains("ADD GENERATED ALWAYS AS IDENTITY")); + + // Second statement is the sequence setval + match &statements[1] { + Statement::SequenceSetMax { sequence, sql } => { assert_eq!(sequence.table.name, "users_id_seq"); assert_eq!( sequence.table.schema().map(|schema| schema.name), @@ -1171,15 +1255,38 @@ ALTER TABLE ONLY public.users r#"SELECT setval('"public"."users_id_seq"', COALESCE((SELECT MAX("id") FROM "public"."users"), 1), true);"# ); } - _ 
=> panic!("not a set sequence max"), } - let statements = output.statements(SyncState::PreData).unwrap(); - assert!(!statements.is_empty()); + let statements = output.statements(SyncState::PostData).unwrap(); assert!(statements.is_empty()); } + #[test] + fn test_generated_identity_post_cutover() { + let q = "ALTER TABLE public.users ALTER COLUMN id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME public.users_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 + );"; + let parse = pg_query::parse(q).unwrap(); + let output = PgDumpOutput { + stmts: parse.protobuf, + original: q.to_string(), + }; + + // PostCutover should drop identity constraints + let statements = output.statements(SyncState::PostCutover).unwrap(); + assert_eq!(statements.len(), 1); + assert!(statements[0].deref().contains("DROP IDENTITY IF EXISTS")); + assert!(statements[0].deref().contains("public")); + assert!(statements[0].deref().contains("users")); + assert!(statements[0].deref().contains("id")); + } + #[test] fn test_integer_primary_key_columns() { let query = r#" @@ -1382,6 +1489,45 @@ ALTER TABLE ONLY parent ATTACH PARTITION parent_2024 FOR VALUES FROM ('2024-01-0 assert!(post_data.is_empty()); } + #[test] + fn test_create_publication_restored() { + let q = "CREATE PUBLICATION my_pub FOR TABLE users, orders;"; + let output = PgDumpOutput { + stmts: parse(q).unwrap().protobuf, + original: q.to_owned(), + }; + + let statements = output.statements(SyncState::PreData).unwrap(); + + // Should have DROP and CREATE statements + assert_eq!(statements.len(), 2); + assert_eq!( + statements[0].deref(), + "DROP PUBLICATION IF EXISTS \"my_pub\"" + ); + assert_eq!( + statements[1].deref(), + "CREATE PUBLICATION my_pub FOR TABLE users, orders" + ); + } + + #[test] + fn test_alter_publication_add_table_restored() { + // pg_dump outputs publication tables as ALTER PUBLICATION ... 
ADD TABLE + let q = "ALTER PUBLICATION my_pub ADD TABLE ONLY public.users;"; + let output = PgDumpOutput { + stmts: parse(q).unwrap().protobuf, + original: q.to_owned(), + }; + + let statements = output.statements(SyncState::PreData).unwrap(); + + assert_eq!(statements.len(), 1); + assert!(statements[0] + .deref() + .contains("ALTER PUBLICATION my_pub ADD TABLE")); + } + #[test] fn test_partitioned_child_inherits_bigint_from_parent() { // pg_dump generates FK constraints only for parent tables, not child partitions. diff --git a/pgdog/src/cli.rs b/pgdog/src/cli.rs index 11212c43c..9c4bef216 100644 --- a/pgdog/src/cli.rs +++ b/pgdog/src/cli.rs @@ -1,16 +1,19 @@ use std::ops::Deref; use std::path::PathBuf; +use std::time::Duration; use clap::{Parser, Subcommand}; use std::fs::read_to_string; use thiserror::Error; +use tokio::time::sleep; use tokio::{select, signal::ctrl_c}; -use tracing::{error, info}; +use tracing::{info, warn}; +use crate::backend::databases::databases; +use crate::backend::replication::orchestrator::Orchestrator; use crate::backend::schema::sync::config::ShardConfig; -use crate::backend::schema::sync::pg_dump::{PgDump, SyncState}; -use crate::backend::{databases::databases, replication::logical::Publisher}; -use crate::config::{config, Config, Users}; +use crate::backend::schema::sync::pg_dump::SyncState; +use crate::config::{Config, Users}; use crate::frontend::router::cli::RouterCli; /// PgDog is a PostgreSQL pooler, proxy, load balancer and query router. @@ -100,6 +103,10 @@ pub enum Commands { /// Name of the replication slot to create/use. #[arg(long)] replication_slot: Option, + + /// Don't perform pre-data schema sync. + #[arg(long)] + skip_schema_sync: bool, }, /// Copy schema from source to destination cluster. @@ -132,6 +139,32 @@ pub enum Commands { cutover: bool, }, + /// For testing purposes only. + /// + /// Performs the entire schema sync, data sync and replication flow + /// with cutover trigger. 
+ /// + /// Use for internal testing only. To do this in production, + /// use the admin database RESHARD command. + /// + ReplicateAndCutover { + /// Source database name. + #[arg(long)] + from_database: String, + + /// Destination database name. + #[arg(long)] + to_database: String, + + /// Publication name. + #[arg(long)] + publication: String, + + /// Replication slot name. + #[arg(long)] + replication_slot: Option, + }, + /// Perform cluster configuration steps /// required for sharded operations. Setup { @@ -226,52 +259,94 @@ pub fn config_check( } } +/// FOR TESTING PURPOSES ONLY. +pub async fn replicate_and_cutover(commands: Commands) -> Result<(), Box> { + if let Commands::ReplicateAndCutover { + from_database, + to_database, + publication, + replication_slot, + } = commands + { + let mut orchestrator = Orchestrator::new( + &from_database, + &to_database, + &publication, + replication_slot.clone(), + )?; + + orchestrator.replicate_and_cutover().await?; + } + + Ok(()) +} + pub async fn data_sync(commands: Commands) -> Result<(), Box> { - let (source, destination, publication, replicate_only, sync_only, replication_slot) = - if let Commands::DataSync { - from_database, - to_database, - publication, - replicate_only, - sync_only, - replication_slot, - } = commands - { - let source = databases().schema_owner(&from_database)?; - let dest = databases().schema_owner(&to_database)?; - - ( - source, - dest, - publication, - replicate_only, - sync_only, - replication_slot, - ) - } else { - return Ok(()); - }; - - let mut publication = Publisher::new( - &source, - &publication, - config().config.general.query_parser_engine, - ); - if replicate_only { - if let Err(err) = publication.replicate(&destination, replication_slot).await { - error!("{}", err); + use crate::backend::replication::logical::Error; + + if let Commands::DataSync { + from_database, + to_database, + publication, + replicate_only, + sync_only, + replication_slot, + skip_schema_sync, + } = commands + { 
+ let mut orchestrator = Orchestrator::new( + &from_database, + &to_database, + &publication, + replication_slot.clone(), + )?; + orchestrator.load_schema().await?; + + if !skip_schema_sync { + orchestrator.schema_sync_pre(true).await?; } - } else { - select! { - result = publication.data_sync(&destination, !sync_only, replication_slot) => { - if let Err(err) = result { - error!("{}", err); + + if !replicate_only { + select! { + result = orchestrator.data_sync() => { + result?; + } + + _ = ctrl_c() => { + warn!("abort signal received, waiting 5 seconds and performing cleanup"); + sleep(Duration::from_secs(5)).await; + + orchestrator.cleanup().await?; + + return Err(Error::DataSyncAborted.into()); } } + } + + if !sync_only { + let mut waiter = orchestrator.replicate().await?; + + select! { + result = waiter.wait() => { + result?; + } + + _ = ctrl_c() => { + warn!("abort signal received"); + + orchestrator.request_stop().await; - _ = ctrl_c() => (), + info!("waiting for replication to stop"); + waiter.wait().await?; + orchestrator.cleanup().await?; + + return Err(Error::DataSyncAborted.into()); + } + } } + } else { + return Ok(()); } Ok(()) @@ -279,54 +354,44 @@ pub async fn data_sync(commands: Commands) -> Result<(), Box Result<(), Box> { - let (source, destination, publication, dry_run, ignore_errors, data_sync_complete, cutover) = - if let Commands::SchemaSync { - from_database, - to_database, - publication, - dry_run, - ignore_errors, - data_sync_complete, - cutover, - } = commands - { - let source = databases().schema_owner(&from_database)?; - let dest = databases().schema_owner(&to_database)?; - - ( - source, - dest, - publication, - dry_run, - ignore_errors, - data_sync_complete, - cutover, - ) - } else { + if let Commands::SchemaSync { + from_database, + to_database, + publication, + dry_run, + ignore_errors, + data_sync_complete, + cutover, + } = commands + { + let mut orchestrator = Orchestrator::new(&from_database, &to_database, &publication, None)?; + 
orchestrator.load_schema().await?; + + if dry_run { + let state = if data_sync_complete { + SyncState::PostData + } else if cutover { + SyncState::Cutover + } else { + SyncState::PreData + }; + + let schema = orchestrator.schema()?; + for statement in schema.statements(state)? { + println!("{}", statement.deref()); + } return Ok(()); - }; - - let dump = PgDump::new(&source, &publication); - let output = dump.dump().await?; - let state = if data_sync_complete { - SyncState::PostData - } else if cutover { - SyncState::Cutover - } else { - SyncState::PreData - }; - - if state == SyncState::PreData { - ShardConfig::sync_all(&destination).await?; - } + } - if dry_run { - let queries = output.statements(state)?; - for query in queries { - println!("{}", query.deref()); + if data_sync_complete { + orchestrator.schema_sync_post(ignore_errors).await?; + } else if cutover { + orchestrator.schema_sync_cutover(ignore_errors).await?; + } else { + orchestrator.schema_sync_pre(ignore_errors).await?; } } else { - output.restore(&destination, ignore_errors, state).await?; + return Ok(()); } Ok(()) diff --git a/pgdog/src/main.rs b/pgdog/src/main.rs index b31327147..88056437f 100644 --- a/pgdog/src/main.rs +++ b/pgdog/src/main.rs @@ -153,6 +153,7 @@ async fn pgdog(command: Option) -> Result<(), Box) -> Result<(), Box) -> Result<(), Box String { } } +/// Get a human-readable duration split into days and hh:mm:ss:ms. 
+/// Example: "2d 03:15:42:100" or "00:05:30:250" +pub fn human_duration_display(duration: Duration) -> String { + let total_secs = duration.as_secs(); + let days = total_secs / 86400; + let hours = (total_secs % 86400) / 3600; + let minutes = (total_secs % 3600) / 60; + let seconds = total_secs % 60; + let millis = duration.subsec_millis(); + + if days > 0 { + format!( + "{}d {:02}:{:02}:{:02}:{:03}", + days, hours, minutes, seconds, millis + ) + } else { + format!("{:02}:{:02}:{:02}:{:03}", hours, minutes, seconds, millis) + } +} + // 2000-01-01T00:00:00Z static POSTGRES_EPOCH: i64 = 946684800000000000; @@ -123,6 +143,40 @@ pub fn pgdog_version() -> String { ) } +/// Format a number with commas for readability. +/// Example: 1234567 -> "1,234,567" +pub fn number_human(n: u64) -> String { + let s = n.to_string(); + let mut result = String::new(); + for (i, c) in s.chars().rev().enumerate() { + if i > 0 && i % 3 == 0 { + result.push(','); + } + result.push(c); + } + result.chars().rev().collect() +} + +/// Format a byte count into a human-readable string. +pub fn format_bytes(bytes: u64) -> String { + const KB: u64 = 1024; + const MB: u64 = KB * 1024; + const GB: u64 = MB * 1024; + const TB: u64 = GB * 1024; + + if bytes < KB { + format!("{} B", bytes) + } else if bytes < MB { + format!("{:.2} KB", bytes as f64 / KB as f64) + } else if bytes < GB { + format!("{:.2} MB", bytes as f64 / MB as f64) + } else if bytes < TB { + format!("{:.2} GB", bytes as f64 / GB as f64) + } else { + format!("{:.2} TB", bytes as f64 / TB as f64) + } +} + /// Get user and database parameters. 
pub fn user_database_from_params(params: &Parameters) -> (&str, &str) { let user = params.get_default("user", "postgres"); @@ -201,6 +255,83 @@ mod test { assert!(node_id().is_err()); } + #[test] + fn test_format_bytes() { + assert_eq!(format_bytes(0), "0 B"); + assert_eq!(format_bytes(1), "1 B"); + assert_eq!(format_bytes(512), "512 B"); + assert_eq!(format_bytes(1024), "1.00 KB"); + assert_eq!(format_bytes(1536), "1.50 KB"); + assert_eq!(format_bytes(1048576), "1.00 MB"); + assert_eq!(format_bytes(1572864), "1.50 MB"); + assert_eq!(format_bytes(1073741824), "1.00 GB"); + assert_eq!(format_bytes(1610612736), "1.50 GB"); + assert_eq!(format_bytes(1099511627776), "1.00 TB"); + } + + #[test] + fn test_number_human() { + assert_eq!(number_human(0), "0"); + assert_eq!(number_human(1), "1"); + assert_eq!(number_human(12), "12"); + assert_eq!(number_human(123), "123"); + assert_eq!(number_human(1234), "1,234"); + assert_eq!(number_human(12345), "12,345"); + assert_eq!(number_human(123456), "123,456"); + assert_eq!(number_human(1234567), "1,234,567"); + assert_eq!(number_human(1234567890), "1,234,567,890"); + } + + #[test] + fn test_human_duration_display() { + // Zero duration + assert_eq!( + human_duration_display(Duration::from_millis(0)), + "00:00:00:000" + ); + + // Just milliseconds + assert_eq!( + human_duration_display(Duration::from_millis(500)), + "00:00:00:500" + ); + + // Seconds and milliseconds + assert_eq!( + human_duration_display(Duration::from_millis(5500)), + "00:00:05:500" + ); + + // Minutes, seconds, milliseconds + assert_eq!( + human_duration_display(Duration::from_millis(65500)), + "00:01:05:500" + ); + + // Hours, minutes, seconds, milliseconds + assert_eq!( + human_duration_display(Duration::from_millis(3665500)), + "01:01:05:500" + ); + + // Days + assert_eq!( + human_duration_display( + Duration::from_secs(86400 + 3600 + 60 + 1) + Duration::from_millis(123) + ), + "1d 01:01:01:123" + ); + + // Multiple days + assert_eq!( + 
human_duration_display( + Duration::from_secs(2 * 86400 + 12 * 3600 + 30 * 60 + 45) + + Duration::from_millis(999) + ), + "2d 12:30:45:999" + ); + } + // These should run in separate processes (if using nextest). #[test] fn test_node_id_set() {