diff --git a/CHANGES.md b/CHANGES.md index 126413c..3de60cd 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -6,7 +6,7 @@ Unreleased - Add column comment to get_columns method (#253), thanks to @dotan-mor - Fix autogenerate with ON UPDATE / DELETE (#258, #262), thanks to @idumitrescu-dn - Improve support for table/column comments (via SQLA 2.0.36) - +- Add nested transaction support (#267), thanks to @mfmarche # Version 2.0.2 January 10, 2023 diff --git a/cockroach_helper.sh b/cockroach_helper.sh new file mode 100755 index 0000000..eb150a1 --- /dev/null +++ b/cockroach_helper.sh @@ -0,0 +1,22 @@ +#!/bin/bash +COCKROACHDB=cockroach-v23.1.13.linux-amd64 +CROACHDB=~/.cache/$COCKROACHDB/cockroach + +quit_cockroachdb() { + OLDPIDNS=$(ps -o pidns -C cockroach | awk 'NR==2 {print $0}') + if [ -n "$OLDPIDNS" ]; then + pkill --ns $$ $OLDPIDNS + fi + return 0 +} + +[ -n "$HOST" ] || HOST=localhost +mkdir -p $(dirname $CROACHDB) +[[ -f "$CROACHDB" ]] || wget -qO- https://binaries.cockroachdb.com/$COCKROACHDB.tgz | tar xvz --directory ~/.cache +if [ $1 == "start" ]; then + quit_cockroachdb + $CROACHDB start-single-node --background --insecure --store=type=mem,size=10% --log-dir /tmp/ --listen-addr=$HOST:26257 --http-addr=$HOST:26301 + #$CROACHDB sql --host=$HOST:26257 --insecure -e "set sql_safe_updates=false; drop database if exists apibuilder; create database if not exists apibuilder; create user if not exists apibuilder; grant all on database apibuilder to apibuilder;" +else + quit_cockroachdb +fi diff --git a/pyproject.toml b/pyproject.toml index 4379329..47d0b5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,6 @@ +[tool.black] +line-length = 100 + [tool.pytest.ini_options] addopts = "--tb native -v -r sfxX --maxfail=250 -p warnings -p logging --strict-markers" markers = [ diff --git a/sqlalchemy_cockroachdb/transaction.py b/sqlalchemy_cockroachdb/transaction.py index 95ef429..3ccc4bd 100644 --- a/sqlalchemy_cockroachdb/transaction.py +++ b/sqlalchemy_cockroachdb/transaction.py @@ -8,7 +8,16 @@ from .base import savepoint_state -def run_transaction(transactor, callback, max_retries=None, max_backoff=0): +class ChainTransaction: + def __init__(self, transactions=None): + self.results = [] + self.transactions = transactions or [] + + def add_result(self, result): + self.results.append(result) + + +def run_transaction(transactor, callback, max_retries=None, max_backoff=0, **kwargs): """Run a transaction with retries. ``callback()`` will be called with one argument to execute the @@ -26,15 +35,18 @@ def run_transaction(transactor, callback, max_retries=None, max_backoff=0): transaction should be retried before giving up. ``max_backoff`` is an optional integer that specifies the capped number of seconds for the exponential back-off. + ``inject_error`` forces retry loop to run via SET inject_retry_errors_enabled = 'true' + ``use_cockroach_restart``, default true, utilizes the special cockroach_restart protocol, + as outlined in: https://www.cockroachlabs.com/blog/nested-transactions-in-cockroachdb-20-1/ """ if isinstance(transactor, (sqlalchemy.engine.Connection, sqlalchemy.orm.Session)): - return _txn_retry_loop(transactor, callback, max_retries, max_backoff) + return _txn_retry_loop(transactor, callback, max_retries, max_backoff, **kwargs) elif isinstance(transactor, sqlalchemy.engine.Engine): with transactor.connect() as connection: - return _txn_retry_loop(connection, callback, max_retries, max_backoff) + return _txn_retry_loop(connection, callback, max_retries, max_backoff, **kwargs) elif isinstance(transactor, sqlalchemy.orm.sessionmaker): session = transactor() - return _txn_retry_loop(session, callback, max_retries, max_backoff) + return _txn_retry_loop(session, callback, max_retries, max_backoff, **kwargs) else: raise TypeError("don't know how to run a transaction on %s", type(transactor)) @@ -46,27 +58,32 @@ class _NestedTransaction: loop to be rewritten by the dialect. """ - def __init__(self, conn): + def __init__(self, conn, use_cockroach_restart=True): self.conn = conn + self.use_cockroach_restart = use_cockroach_restart def __enter__(self): try: - savepoint_state.cockroach_restart = True + if self.use_cockroach_restart: + savepoint_state.cockroach_restart = True self.txn = self.conn.begin_nested() - if isinstance(self.conn, sqlalchemy.orm.Session): + if self.use_cockroach_restart and isinstance(self.conn, sqlalchemy.orm.Session): # Sessions are lazy and don't execute the savepoint # query until you ask for the connection. self.conn.connection() finally: - savepoint_state.cockroach_restart = False + if self.use_cockroach_restart: + savepoint_state.cockroach_restart = False return self def __exit__(self, typ, value, tb): try: - savepoint_state.cockroach_restart = True + if self.use_cockroach_restart: + savepoint_state.cockroach_restart = True self.txn.__exit__(typ, value, tb) finally: - savepoint_state.cockroach_restart = False + if self.use_cockroach_restart: + savepoint_state.cockroach_restart = False def retry_exponential_backoff(retry_count: int, max_backoff: int = 0) -> None: @@ -81,45 +98,62 @@ def retry_exponential_backoff(retry_count: int, max_backoff: int = 0) -> None: :return: None """ - sleep_secs = uniform(0, min(max_backoff, 0.1 * (2 ** retry_count))) + sleep_secs = uniform(0, min(max_backoff, 0.1 * (2**retry_count))) sleep(sleep_secs) -def _txn_retry_loop(conn, callback, max_retries, max_backoff): - """Inner transaction retry loop. - - ``conn`` may be either a Connection or a Session, but they both - have compatible ``begin()`` and ``begin_nested()`` methods. - """ +def run_in_nested_transaction( + conn, callback, max_retries, max_backoff, inject_error=False, **kwargs +): if isinstance(conn, sqlalchemy.orm.Session): dbapi_name = conn.bind.driver else: dbapi_name = conn.engine.driver retry_count = 0 - with conn.begin(): - while True: - try: - with _NestedTransaction(conn): - ret = callback(conn) - return ret - except sqlalchemy.exc.DatabaseError as e: - if max_retries is not None and retry_count >= max_retries: - raise - do_retry = False - if dbapi_name == "psycopg2": - import psycopg2 - import psycopg2.errorcodes - if isinstance(e.orig, psycopg2.OperationalError): - if e.orig.pgcode == psycopg2.errorcodes.SERIALIZATION_FAILURE: - do_retry = True - else: - import psycopg - if isinstance(e.orig, psycopg.errors.SerializationFailure): - do_retry = True - if do_retry: - retry_count += 1 - if max_backoff > 0: - retry_exponential_backoff(retry_count, max_backoff) - continue + while True: + if inject_error and retry_count == 0: + conn.execute(sqlalchemy.text("SET inject_retry_errors_enabled = 'true'")) + elif inject_error: + conn.execute(sqlalchemy.text("SET inject_retry_errors_enabled = 'false'")) + try: + with _NestedTransaction(conn, **kwargs): + return callback(conn) + except sqlalchemy.exc.DatabaseError as e: + if max_retries is not None and retry_count >= max_retries: raise + do_retry = False + if dbapi_name == "psycopg2": + import psycopg2 + import psycopg2.errorcodes + + if isinstance(e.orig, psycopg2.OperationalError): + if e.orig.pgcode == psycopg2.errorcodes.SERIALIZATION_FAILURE: + do_retry = True + else: + import psycopg + + if isinstance(e.orig, psycopg.errors.SerializationFailure): + do_retry = True + if do_retry: + retry_count += 1 + if max_backoff > 0: + retry_exponential_backoff(retry_count, max_backoff) + continue + raise + + +def _txn_retry_loop(conn, callback, max_retries, max_backoff, **kwargs): + """Inner transaction retry loop. + + ``conn`` may be either a Connection or a Session, but they both + have compatible ``begin()`` and ``begin_nested()`` methods. + """ + with conn.begin(): + result = run_in_nested_transaction(conn, callback, max_retries, max_backoff, **kwargs) + if isinstance(result, ChainTransaction): + for transaction in result.transactions: + result.add_result( + run_in_nested_transaction(conn, transaction, max_retries, max_backoff, **kwargs) + ) + return result diff --git a/test/test_run_transaction_core.py b/test/test_run_transaction_core.py index da14e6b..c1431b9 100644 --- a/test/test_run_transaction_core.py +++ b/test/test_run_transaction_core.py @@ -3,8 +3,11 @@ from sqlalchemy.testing import fixtures from sqlalchemy.types import Integer import threading +from sqlalchemy.orm import sessionmaker, scoped_session + from sqlalchemy_cockroachdb import run_transaction +from sqlalchemy_cockroachdb.transaction import ChainTransaction meta = MetaData() @@ -25,7 +28,9 @@ def setup_method(self, method): ) def teardown_method(self, method): - meta.drop_all(testing.db) + session = scoped_session(sessionmaker(bind=testing.db)) + session.query(account_table).delete() + session.commit() def get_balances(self, conn): """Returns the balances of the two accounts as a list.""" @@ -134,3 +139,33 @@ def txn_body(conn): with testing.db.connect() as conn: rs = run_transaction(conn, txn_body) assert rs[0] == (1, 100) + + def test_run_transaction_retry_with_nested(self): + def txn_body(conn): + rs = conn.execute(text("select acct, balance from account where acct = 1")) + conn.execute(text("select crdb_internal.force_retry('1s')")) + return [r for r in rs] + + with testing.db.connect() as conn: + rs = run_transaction(conn, txn_body, use_cockroach_restart=False) + assert rs[0] == (1, 100) + + def test_run_chained_transaction(self): + def txn_body(conn): + # first transaction inserts + conn.execute(account_table.insert(), [dict(acct=99, balance=100)]) + conn.execute(text("select crdb_internal.force_retry('1s')")) + + def _get_val(s): + rs = s.execute(text("select acct, balance from account where acct = 99")) + return [r for r in rs] + + # chain the get into a separate nested transaction, so that the value + # in the previous nested transaction is flushed and available + return ChainTransaction([lambda s: _get_val(s), lambda s: _get_val(s)]) + + with testing.db.connect() as conn: + rs = run_transaction(conn, txn_body, use_cockroach_restart=False) + assert len(rs.results) == 2 + assert rs.results[0][0] == (99, 100) + assert rs.results[1][0] == (99, 100)