diff --git a/README.md b/README.md index f8574c5..4db91ad 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,10 @@

- +> [!WARNING] +> **Breaking Changes in v5.0.0** +> +> Starting with version 5.0.0, Pydantic support will become optional. The default implementations of `Request`, `Response`, `DomainEvent`, and `NotificationEvent` will be migrated to dataclasses-based implementations. ## Overview diff --git a/examples/saga_recovery_scheduler.py b/examples/saga_recovery_scheduler.py new file mode 100644 index 0000000..55d2867 --- /dev/null +++ b/examples/saga_recovery_scheduler.py @@ -0,0 +1,661 @@ +""" +Example: Saga Recovery Scheduler (while + sleep) + +This example demonstrates how to run a simple recovery scheduler that periodically +scans for stuck or failed sagas and recovers them. The scheduler uses a plain +while loop with asyncio.sleep, suitable for a dedicated worker process or +a background task. + +PROBLEM: Recovering Failed Sagas in Production +============================================= + +When sagas run in production, processes can crash, time out, or be restarted. +Incomplete sagas (RUNNING, COMPENSATING, FAILED) must be recovered so that: +- Forward execution can complete +- Compensation can finish +- The system reaches eventual consistency + +A recovery job must: +1. Find sagas that need recovery (not currently being executed) +2. Avoid picking the same saga twice (e.g. limit by recovery_attempts) +3. Run periodically without blocking the main application + +SOLUTION: While Loop + get_sagas_for_recovery +============================================= + +Use get_sagas_for_recovery(limit=..., max_recovery_attempts=..., stale_after_seconds=...) +to select only "stale" sagas (updated_at older than threshold), then call +recover_saga for each. On recovery failure, recover_saga calls +increment_recovery_attempts under the hood so the saga can be retried or +excluded later; callers only need to call recover_saga. + +================================================================================ +HOW TO RUN THIS EXAMPLE +================================================================================ + +Run the example: + python examples/saga_recovery_scheduler.py + +The example will: +- Create an in-memory storage and one interrupted saga (simulated crash) +- Run the recovery scheduler loop for a few iterations +- Recover the interrupted saga on the first iteration +- Show that subsequent iterations find no sagas to recover + +================================================================================ +WHAT THIS EXAMPLE DEMONSTRATES +================================================================================ + +1. Recovery scheduler loop: + - while True with asyncio.sleep(interval_seconds) + - get_sagas_for_recovery(limit, max_recovery_attempts, stale_after_seconds) + - Per-saga recover_saga() only; increment_recovery_attempts is done inside recover_saga on failure + +2. Staleness filter (stale_after_seconds): + - Only sagas not updated recently are considered (avoids recovering + sagas that are currently being executed by another worker) + +3. 
Max recovery attempts: + - Sagas that fail recovery too many times are excluded from selection + - After increment_recovery_attempts, they can be retried until max is reached + +================================================================================ +REQUIREMENTS +================================================================================ + +Make sure you have installed: + - cqrs (this package) + - pydantic (for context models) + +This example declares its own domain model (OrderContext), step handlers, +services, saga (OrderSaga), and container; it does not depend on other examples. + +================================================================================ +""" + +import asyncio +import dataclasses +import datetime +import logging +import typing +import uuid + +from cqrs import container as cqrs_container +from cqrs.events.event import Event +from cqrs.response import Response +from cqrs.saga.models import SagaContext +from cqrs.saga.recovery import recover_saga +from cqrs.saga.saga import Saga +from cqrs.saga.step import SagaStepHandler, SagaStepResult +from cqrs.saga.storage.enums import SagaStatus, SagaStepStatus +from cqrs.saga.storage.memory import MemorySagaStorage +from cqrs.saga.storage.protocol import ISagaStorage + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Domain Models +# ============================================================================ + + +@dataclasses.dataclass +class OrderContext(SagaContext): + """Shared context passed between all saga steps.""" + + order_id: str + user_id: str + items: list[str] + total_amount: float + shipping_address: str + + inventory_reservation_id: str | None = None + payment_id: str | None = None + shipment_id: str | None = None + + +# ============================================================================ +# Step Responses +# ============================================================================ + + +class ReserveInventoryResponse(Response): + """Response from inventory reservation step.""" + + reservation_id: str + items_reserved: list[str] + + +class ProcessPaymentResponse(Response): + """Response from payment processing step.""" + + payment_id: str + amount_charged: float + transaction_id: str + + +class ShipOrderResponse(Response): + """Response from shipping step.""" + + shipment_id: str + tracking_number: str + estimated_delivery: str + + +# ============================================================================ +# Domain Events (minimal for step handlers) +# ============================================================================ + + +class InventoryReservedEvent(Event, frozen=True): + """Event emitted when inventory is reserved.""" + + order_id: str + reservation_id: str + items: list[str] + + +class PaymentProcessedEvent(Event, frozen=True): + """Event emitted when payment is processed.""" + + order_id: str + payment_id: str + amount: float + + +class OrderShippedEvent(Event, frozen=True): + """Event emitted when order is shipped.""" + + order_id: str + shipment_id: str + tracking_number: str + + +# ============================================================================ +# Mock Services +# ============================================================================ + + +class InventoryService: + """Mock inventory service for reserving and releasing items.""" + + def __init__(self) -> None: + self._reservations: dict[str, list[str]] = {} + 
self._available_items: dict[str, int] = { + "item_1": 10, + "item_2": 5, + "item_3": 8, + } + + async def reserve_items(self, order_id: str, items: list[str]) -> str: + reservation_id = f"reservation_{order_id}" + reserved_items = [] + for item_id in items: + if item_id not in self._available_items: + raise ValueError(f"Item {item_id} not found") + if self._available_items[item_id] <= 0: + raise ValueError(f"Insufficient inventory for {item_id}") + self._available_items[item_id] -= 1 + reserved_items.append(item_id) + self._reservations[reservation_id] = reserved_items + logger.info(" ✓ Reserved items %s for order %s", reserved_items, order_id) + return reservation_id + + async def release_items(self, reservation_id: str) -> None: + if reservation_id not in self._reservations: + return + items = self._reservations[reservation_id] + for item_id in items: + self._available_items[item_id] += 1 + del self._reservations[reservation_id] + logger.info(" ↻ Released items %s from reservation %s", items, reservation_id) + + +class PaymentService: + """Mock payment service for processing payments and refunds.""" + + def __init__(self) -> None: + self._payments: dict[str, float] = {} + self._transaction_counter = 0 + + async def charge(self, order_id: str, amount: float) -> tuple[str, str]: + if amount <= 0: + raise ValueError("Payment amount must be positive") + self._transaction_counter += 1 + payment_id = f"payment_{order_id}" + transaction_id = f"txn_{self._transaction_counter:06d}" + self._payments[payment_id] = amount + logger.info( + " ✓ Charged $%.2f for order %s (transaction: %s)", + amount, + order_id, + transaction_id, + ) + return payment_id, transaction_id + + async def refund(self, payment_id: str) -> None: + if payment_id not in self._payments: + return + amount = self._payments[payment_id] + del self._payments[payment_id] + logger.info(" ↻ Refunded $%.2f for payment %s", amount, payment_id) + + +class ShippingService: + """Mock shipping service for creating shipments.""" + + def __init__(self) -> None: + self._shipments: dict[str, str] = {} + self._tracking_counter = 0 + + async def create_shipment( + self, + order_id: str, + items: list[str], + address: str, + ) -> tuple[str, str]: + if not address: + raise ValueError("Shipping address is required") + self._tracking_counter += 1 + shipment_id = f"shipment_{order_id}" + tracking_number = f"TRACK{self._tracking_counter:08d}" + self._shipments[shipment_id] = tracking_number + logger.info( + " ✓ Created shipment %s for order %s (tracking: %s)", + shipment_id, + order_id, + tracking_number, + ) + return shipment_id, tracking_number + + async def cancel_shipment(self, shipment_id: str) -> None: + if shipment_id not in self._shipments: + return + tracking_number = self._shipments[shipment_id] + del self._shipments[shipment_id] + logger.info( + " ↻ Cancelled shipment %s (tracking: %s)", + shipment_id, + tracking_number, + ) + + +# ============================================================================ +# Saga Step Handlers +# ============================================================================ + + +class ReserveInventoryStep( + SagaStepHandler[OrderContext, ReserveInventoryResponse], +): + """Step 1: Reserve inventory items for the order.""" + + def __init__(self, inventory_service: InventoryService) -> None: + self._inventory_service = inventory_service + self._events: list[Event] = [] + + @property + def events(self) -> list[Event]: + return self._events.copy() + + async def act( + self, + context: OrderContext, + ) -> 
SagaStepResult[OrderContext, ReserveInventoryResponse]: + reservation_id = await self._inventory_service.reserve_items( + order_id=context.order_id, + items=context.items, + ) + context.inventory_reservation_id = reservation_id + self._events.append( + InventoryReservedEvent( + order_id=context.order_id, + reservation_id=reservation_id, + items=context.items, + ), + ) + response = ReserveInventoryResponse( + reservation_id=reservation_id, + items_reserved=context.items, + ) + return self._generate_step_result(response) + + async def compensate(self, context: OrderContext) -> None: + if context.inventory_reservation_id: + await self._inventory_service.release_items( + context.inventory_reservation_id, + ) + + +class ProcessPaymentStep( + SagaStepHandler[OrderContext, ProcessPaymentResponse], +): + """Step 2: Process payment for the order.""" + + def __init__(self, payment_service: PaymentService) -> None: + self._payment_service = payment_service + self._events: list[Event] = [] + + @property + def events(self) -> list[Event]: + return self._events.copy() + + async def act( + self, + context: OrderContext, + ) -> SagaStepResult[OrderContext, ProcessPaymentResponse]: + payment_id, transaction_id = await self._payment_service.charge( + order_id=context.order_id, + amount=context.total_amount, + ) + context.payment_id = payment_id + self._events.append( + PaymentProcessedEvent( + order_id=context.order_id, + payment_id=payment_id, + amount=context.total_amount, + ), + ) + response = ProcessPaymentResponse( + payment_id=payment_id, + amount_charged=context.total_amount, + transaction_id=transaction_id, + ) + return self._generate_step_result(response) + + async def compensate(self, context: OrderContext) -> None: + if context.payment_id: + await self._payment_service.refund(context.payment_id) + + +class ShipOrderStep(SagaStepHandler[OrderContext, ShipOrderResponse]): + """Step 3: Create shipment for the order.""" + + def __init__(self, shipping_service: ShippingService) -> None: + self._shipping_service = shipping_service + self._events: list[Event] = [] + + @property + def events(self) -> list[Event]: + return self._events.copy() + + async def act( + self, + context: OrderContext, + ) -> SagaStepResult[OrderContext, ShipOrderResponse]: + shipment_id, tracking_number = await self._shipping_service.create_shipment( + order_id=context.order_id, + items=context.items, + address=context.shipping_address, + ) + context.shipment_id = shipment_id + self._events.append( + OrderShippedEvent( + order_id=context.order_id, + shipment_id=shipment_id, + tracking_number=tracking_number, + ), + ) + response = ShipOrderResponse( + shipment_id=shipment_id, + tracking_number=tracking_number, + estimated_delivery="2024-12-25", + ) + return self._generate_step_result(response) + + async def compensate(self, context: OrderContext) -> None: + if context.shipment_id: + await self._shipping_service.cancel_shipment(context.shipment_id) + + +# ============================================================================ +# Saga Definition +# ============================================================================ + + +class OrderSaga(Saga[OrderContext]): + """Order processing saga with three steps.""" + + steps = [ + ReserveInventoryStep, + ProcessPaymentStep, + ShipOrderStep, + ] + + +# ============================================================================ +# Container +# ============================================================================ + + +class 
SimpleContainer(cqrs_container.Container[typing.Any]): + """Simple container for resolving step handlers.""" + + def __init__( + self, + inventory_service: InventoryService, + payment_service: PaymentService, + shipping_service: ShippingService, + ) -> None: + self._services = { + InventoryService: inventory_service, + PaymentService: payment_service, + ShippingService: shipping_service, + } + self._external_container: typing.Any = None + + @property + def external_container(self) -> typing.Any: + return self._external_container + + def attach_external_container(self, container: typing.Any) -> None: + self._external_container = container + + async def resolve(self, type_: type) -> typing.Any: + if type_ in self._services: + return self._services[type_] + if type_ == ReserveInventoryStep: + return ReserveInventoryStep(self._services[InventoryService]) + if type_ == ProcessPaymentStep: + return ProcessPaymentStep(self._services[PaymentService]) + if type_ == ShipOrderStep: + return ShipOrderStep(self._services[ShippingService]) + raise ValueError(f"Unknown type: {type_}") + + +# ============================================================================ +# Scheduler configuration +# ============================================================================ + +RECOVERY_INTERVAL_SECONDS = 2 +RECOVERY_BATCH_LIMIT = 10 +MAX_RECOVERY_ATTEMPTS = 5 +STALE_AFTER_SECONDS = 60 + + +# ============================================================================ +# Recovery scheduler +# ============================================================================ + + +def make_container() -> SimpleContainer: + """Create a fresh container with services (e.g. after process restart).""" + return SimpleContainer( + inventory_service=InventoryService(), + payment_service=PaymentService(), + shipping_service=ShippingService(), + ) + + +async def run_recovery_iteration( + storage: ISagaStorage, + saga: OrderSaga, + context_builder: typing.Type[OrderContext], +) -> int: + """ + Run one recovery iteration: fetch stale sagas, recover each. + + recover_saga() increments recovery_attempts on failure under the hood; + the caller only calls recover_saga(). + + Returns the number of sagas processed (recovered or failed). + """ + ids = await storage.get_sagas_for_recovery( + limit=RECOVERY_BATCH_LIMIT, + max_recovery_attempts=MAX_RECOVERY_ATTEMPTS, + stale_after_seconds=STALE_AFTER_SECONDS, + ) + if not ids: + return 0 + + container = make_container() + processed = 0 + for saga_id in ids: + try: + logger.info("Recovering saga %s...", saga_id) + await recover_saga(saga, saga_id, context_builder, container, storage) + logger.info("Saga %s recovered successfully.", saga_id) + processed += 1 + except RuntimeError as e: + if "recovered in" in str(e) and "state" in str(e): + logger.info("Saga %s recovery completed compensation: %s", saga_id, e) + processed += 1 + else: + logger.exception("Saga %s recovery failed: %s", saga_id, e) + processed += 1 + except Exception as e: + logger.exception("Saga %s recovery failed: %s", saga_id, e) + processed += 1 + return processed + + +async def recovery_loop( + storage: ISagaStorage, + *, + interval_seconds: float = RECOVERY_INTERVAL_SECONDS, + max_iterations: int | None = None, +) -> None: + """ + Run the recovery scheduler loop. + + Args: + storage: Saga storage (e.g. MemorySagaStorage or SqlAlchemySagaStorage). + interval_seconds: Sleep duration between iterations. + max_iterations: If set, stop after this many iterations (for demo). + None = run until cancelled. 
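+
+    Note:
+        Illustrative usage only (this demo instead awaits the loop directly
+        with max_iterations=3): in a long-running service the loop is
+        typically started as a background task and cancelled on shutdown::
+
+            task = asyncio.create_task(recovery_loop(storage))
+            try:
+                ...  # serve the application
+            finally:
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass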
+ """ + saga = OrderSaga() + iteration = 0 + while True: + iteration += 1 + logger.info("Recovery iteration %s", iteration) + try: + processed = await run_recovery_iteration( + storage, + saga, + OrderContext, + ) + if processed > 0: + logger.info("Processed %s saga(s) this iteration.", processed) + else: + logger.debug("No sagas to recover.") + except asyncio.CancelledError: + logger.info("Recovery loop cancelled.") + raise + except Exception as e: + logger.exception("Recovery iteration failed: %s", e) + + if max_iterations is not None and iteration >= max_iterations: + logger.info("Reached max_iterations=%s, stopping.", max_iterations) + break + await asyncio.sleep(interval_seconds) + + +# ============================================================================ +# Demo: create one interrupted saga, then run scheduler +# ============================================================================ + + +async def create_interrupted_saga(storage: MemorySagaStorage) -> uuid.UUID: + """ + Create one saga in RUNNING state (simulating crash after first step). + Returns without recovering. + """ + saga_id = uuid.uuid4() + context = OrderContext( + order_id="order_scheduler_demo", + user_id="user_1", + items=["item_1"], + total_amount=99.99, + shipping_address="123 Main St", + ) + + await storage.create_saga( + saga_id=saga_id, + name="order_saga", + context=context.to_dict(), + ) + await storage.update_status(saga_id, SagaStatus.RUNNING) + await storage.log_step( + saga_id, + "ReserveInventoryStep", + "act", + SagaStepStatus.STARTED, + ) + await storage.log_step( + saga_id, + "ReserveInventoryStep", + "act", + SagaStepStatus.COMPLETED, + ) + ctx_dict = context.to_dict() + ctx_dict["inventory_reservation_id"] = "reservation_order_scheduler_demo" + await storage.update_context(saga_id, ctx_dict) + + logger.info("Created interrupted saga %s (RUNNING, one step done).", saga_id) + return saga_id + + +async def main() -> None: + """Run the recovery scheduler example.""" + print("\n" + "=" * 70) + print("SAGA RECOVERY SCHEDULER EXAMPLE") + print("=" * 70) + print("\nThis example demonstrates:") + print(" 1. A simple while-loop recovery scheduler with asyncio.sleep") + print( + " 2. get_sagas_for_recovery(limit, max_recovery_attempts, stale_after_seconds)", + ) + print( + " 3. 
recover_saga() per saga (increment_recovery_attempts on failure is internal)", + ) + + storage = MemorySagaStorage() + + saga_id = await create_interrupted_saga(storage) + storage._sagas[saga_id]["updated_at"] = datetime.datetime.now( + datetime.timezone.utc, + ) - datetime.timedelta(seconds=STALE_AFTER_SECONDS + 10) + + print("\nRunning recovery loop for 3 iterations (interval=2s)...") + print(" Iteration 1 should recover the interrupted saga.") + print(" Iteration 2 and 3 should find no sagas.\n") + + await recovery_loop( + storage, + interval_seconds=RECOVERY_INTERVAL_SECONDS, + max_iterations=3, + ) + + status, context_data, _ = await storage.load_saga_state(saga_id) + print("\n" + "-" * 70) + print(f"Final state of saga {saga_id}:") + print(f" Status: {status}") + print("-" * 70) + print("\nEXAMPLE COMPLETED") + print("=" * 70) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index fd63210..61b1944 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ maintainers = [{name = "Vadim Kozyrevskiy", email = "vadikko2@mail.ru"}] name = "python-cqrs" readme = "README.md" requires-python = ">=3.10" -version = "4.6.5" +version = "4.7.0" [project.optional-dependencies] aiobreaker = ["aiobreaker>=0.3.0"] diff --git a/src/cqrs/saga/recovery.py b/src/cqrs/saga/recovery.py index 9e0f28f..5d80dc4 100644 --- a/src/cqrs/saga/recovery.py +++ b/src/cqrs/saga/recovery.py @@ -29,6 +29,11 @@ async def recover_saga( Already completed steps will be skipped. If the saga was in a compensating state, compensation will resume. + On recovery failure (exception during resume), the storage's + increment_recovery_attempts is called automatically so the saga can be + retried or excluded by get_sagas_for_recovery(max_recovery_attempts=...). + Callers do not need to call increment_recovery_attempts themselves. + Args: saga: The saga orchestrator instance. saga_id: The ID of the saga to recover. @@ -109,11 +114,12 @@ async def recover_saga( ) # Re-raise to allow callers to handle this case raise - # For other RuntimeErrors, log and re-raise + # For other RuntimeErrors, recovery failed: increment attempts and re-raise logger.error(f"Saga {saga_id} recovery ended with error: {e}") + await storage.increment_recovery_attempts(saga_id, new_status=SagaStatus.FAILED) raise except Exception as e: logger.error(f"Saga {saga_id} recovery ended with error: {e}") - # The transaction handles exception and runs compensation, so the saga state - # should be updated to FAILED (or COMPENSATED) in storage. + # Recovery failed: increment attempts so saga can be retried or excluded later + await storage.increment_recovery_attempts(saga_id, new_status=SagaStatus.FAILED) raise diff --git a/src/cqrs/saga/saga.py b/src/cqrs/saga/saga.py index 006efe4..acf0401 100644 --- a/src/cqrs/saga/saga.py +++ b/src/cqrs/saga/saga.py @@ -142,9 +142,14 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: types.TracebackType | None, ) -> bool: - # If an exception occurred, compensate all completed steps - # Only compensate if not already compensated in __aiter__ - if exc_val is not None and not self._compensated: + # If an exception occurred, compensate all completed steps. + # Do not compensate on GeneratorExit: consumer stopped iteration intentionally + # (e.g. to resume later), which is not a failure. 
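+        # In practice GeneratorExit reaches __aexit__ when the enclosing async
+        # generator is closed early (e.g. the consumer stops iterating before
+        # the saga finishes), so it signals a pause rather than an error.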
+ if ( + exc_val is not None + and exc_type is not GeneratorExit + and not self._compensated + ): self._error = exc_val await self._compensate() return False # Don't suppress the exception diff --git a/src/cqrs/saga/storage/memory.py b/src/cqrs/saga/storage/memory.py index acb4599..adc0c40 100644 --- a/src/cqrs/saga/storage/memory.py +++ b/src/cqrs/saga/storage/memory.py @@ -37,6 +37,7 @@ async def create_saga( "created_at": now, "updated_at": now, "version": 1, + "recovery_attempts": 0, } self._logs[saga_id] = [] @@ -115,3 +116,40 @@ async def get_step_history( return [] # Sort by timestamp return sorted(self._logs[saga_id], key=lambda x: x.timestamp) + + async def get_sagas_for_recovery( + self, + limit: int, + max_recovery_attempts: int = 5, + stale_after_seconds: int | None = None, + ) -> list[uuid.UUID]: + recoverable = (SagaStatus.RUNNING, SagaStatus.COMPENSATING, SagaStatus.FAILED) + now = datetime.datetime.now(datetime.timezone.utc) + threshold = ( + (now - datetime.timedelta(seconds=stale_after_seconds)) + if stale_after_seconds is not None + else None + ) + candidates = [ + sid + for sid, data in self._sagas.items() + if data["status"] in recoverable + and data.get("recovery_attempts", 0) < max_recovery_attempts + and (threshold is None or data["updated_at"] < threshold) + ] + candidates.sort(key=lambda sid: self._sagas[sid]["updated_at"]) + return candidates[:limit] + + async def increment_recovery_attempts( + self, + saga_id: uuid.UUID, + new_status: SagaStatus | None = None, + ) -> None: + if saga_id not in self._sagas: + raise ValueError(f"Saga {saga_id} not found") + data = self._sagas[saga_id] + data["recovery_attempts"] = data.get("recovery_attempts", 0) + 1 + data["updated_at"] = datetime.datetime.now(datetime.timezone.utc) + data["version"] += 1 + if new_status is not None: + data["status"] = new_status diff --git a/src/cqrs/saga/storage/protocol.py b/src/cqrs/saga/storage/protocol.py index 2211fca..74c6157 100644 --- a/src/cqrs/saga/storage/protocol.py +++ b/src/cqrs/saga/storage/protocol.py @@ -70,3 +70,43 @@ async def get_step_history( saga_id: uuid.UUID, ) -> list[SagaLogEntry]: """Get step execution history.""" + + @abc.abstractmethod + async def get_sagas_for_recovery( + self, + limit: int, + max_recovery_attempts: int = 5, + stale_after_seconds: int | None = None, + ) -> list[uuid.UUID]: + """Return saga IDs that need recovery. + + Args: + limit: Maximum number of saga IDs to return. + max_recovery_attempts: Only include sagas with recovery_attempts + strictly less than this value. Default 5. + stale_after_seconds: If set, only include sagas whose updated_at + is older than (now_utc - stale_after_seconds). Use this to + avoid picking sagas that are currently being executed (recently + updated). None means no staleness filter (backward compatible). + + Returns: + List of saga IDs (RUNNING, COMPENSATING, or FAILED), ordered by + updated_at ascending, with recovery_attempts < max_recovery_attempts, + and optionally updated_at older than the staleness threshold. + """ + + @abc.abstractmethod + async def increment_recovery_attempts( + self, + saga_id: uuid.UUID, + new_status: SagaStatus | None = None, + ) -> None: + """Atomically increment recovery attempts after a failed recovery. + + Updates recovery_attempts += 1, updated_at = now(), and optionally + status. Also increments version for optimistic locking. + + Args: + saga_id: The saga to update. + new_status: If provided, set saga status to this value (e.g. FAILED). 
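+
+        Note:
+            recover_saga() already calls this method when resuming a saga
+            raises, so schedulers that go through recover_saga() never need to
+            call it directly. A custom runner that resumes sagas without
+            recover_saga() could use it like this (illustrative sketch;
+            ``resume_saga`` is a hypothetical helper, ``storage`` implements
+            this protocol)::
+
+                try:
+                    await resume_saga(saga_id)
+                except Exception:
+                    await storage.increment_recovery_attempts(
+                        saga_id,
+                        new_status=SagaStatus.FAILED,
+                    )
+                    raise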
+ """ diff --git a/src/cqrs/saga/storage/sqlalchemy.py b/src/cqrs/saga/storage/sqlalchemy.py index 3ab5379..a4da774 100644 --- a/src/cqrs/saga/storage/sqlalchemy.py +++ b/src/cqrs/saga/storage/sqlalchemy.py @@ -78,6 +78,13 @@ class SagaExecutionModel(Base): onupdate=func.now(), comment="Last update timestamp", ) + recovery_attempts = sqlalchemy.Column( + sqlalchemy.Integer, + nullable=False, + default=0, + server_default=sqlalchemy.text("0"), + comment="Number of recovery attempts", + ) class SagaLogModel(Base): @@ -143,6 +150,7 @@ async def create_saga( status=SagaStatus.PENDING, context=context, version=1, + recovery_attempts=0, ) session.add(execution) await session.commit() @@ -303,3 +311,57 @@ async def get_step_history( ) for row in rows ] + + async def get_sagas_for_recovery( + self, + limit: int, + max_recovery_attempts: int = 5, + stale_after_seconds: int | None = None, + ) -> list[uuid.UUID]: + recoverable = ( + SagaStatus.RUNNING, + SagaStatus.COMPENSATING, + SagaStatus.FAILED, + ) + async with self.session_factory() as session: + stmt = ( + sqlalchemy.select(SagaExecutionModel.id) + .where(SagaExecutionModel.status.in_(recoverable)) + .where(SagaExecutionModel.recovery_attempts < max_recovery_attempts) + ) + if stale_after_seconds is not None: + threshold = datetime.datetime.now( + datetime.timezone.utc, + ) - datetime.timedelta( + seconds=stale_after_seconds, + ) + stmt = stmt.where(SagaExecutionModel.updated_at < threshold) + stmt = stmt.order_by(SagaExecutionModel.updated_at.asc()).limit(limit) + result = await session.execute(stmt) + rows = result.scalars().all() + return [typing.cast(uuid.UUID, row) for row in rows] + + async def increment_recovery_attempts( + self, + saga_id: uuid.UUID, + new_status: SagaStatus | None = None, + ) -> None: + async with self.session_factory() as session: + try: + values: dict[str, typing.Any] = { + "recovery_attempts": SagaExecutionModel.recovery_attempts + 1, + "version": SagaExecutionModel.version + 1, + } + if new_status is not None: + values["status"] = new_status + result = await session.execute( + sqlalchemy.update(SagaExecutionModel) + .where(SagaExecutionModel.id == saga_id) + .values(**values), + ) + if result.rowcount == 0: # type: ignore[attr-defined] + raise ValueError(f"Saga {saga_id} not found") + await session.commit() + except SQLAlchemyError: + await session.rollback() + raise diff --git a/tests/integration/test_saga_storage_memory.py b/tests/integration/test_saga_storage_memory.py index 652fa79..2aadc44 100644 --- a/tests/integration/test_saga_storage_memory.py +++ b/tests/integration/test_saga_storage_memory.py @@ -1,5 +1,6 @@ """Integration tests for MemorySagaStorage.""" +import datetime import uuid import pytest @@ -159,3 +160,219 @@ async def test_compensation_scenario( assert history[3].action == "compensate" assert history[2].details == "Payment refunded" assert history[3].details == "Inventory released" + + +class TestRecoveryMemory: + """Integration tests for get_sagas_for_recovery and increment_recovery_attempts (Memory).""" + + # --- get_sagas_for_recovery: positive --- + + async def test_get_sagas_for_recovery_returns_recoverable_sagas( + self, + storage: MemorySagaStorage, + test_context: dict[str, str], + ) -> None: + """Positive: returns RUNNING, COMPENSATING, FAILED sagas only.""" + id1, id2, id3 = uuid.uuid4(), uuid.uuid4(), uuid.uuid4() + for sid in (id1, id2, id3): + await storage.create_saga(saga_id=sid, name="saga", context=test_context) + await storage.update_status(id1, SagaStatus.RUNNING) + await 
storage.update_status(id2, SagaStatus.COMPENSATING) + await storage.update_status(id3, SagaStatus.FAILED) + + ids = await storage.get_sagas_for_recovery(limit=10) + assert set(ids) == {id1, id2, id3} + assert len(ids) == 3 + + async def test_get_sagas_for_recovery_respects_limit( + self, + storage: MemorySagaStorage, + test_context: dict[str, str], + ) -> None: + """Positive: returns at most `limit` saga IDs.""" + for i in range(5): + sid = uuid.uuid4() + await storage.create_saga(saga_id=sid, name="saga", context=test_context) + await storage.update_status(sid, SagaStatus.RUNNING) + + ids = await storage.get_sagas_for_recovery(limit=2) + assert len(ids) == 2 + + async def test_get_sagas_for_recovery_respects_max_recovery_attempts( + self, + storage: MemorySagaStorage, + test_context: dict[str, str], + ) -> None: + """Positive: only returns sagas with recovery_attempts < max_recovery_attempts.""" + id_low = uuid.uuid4() + id_high = uuid.uuid4() + await storage.create_saga(saga_id=id_low, name="saga", context=test_context) + await storage.create_saga(saga_id=id_high, name="saga", context=test_context) + await storage.update_status(id_low, SagaStatus.RUNNING) + await storage.update_status(id_high, SagaStatus.RUNNING) + # id_high: simulate 5 failed recovery attempts (default max is 5) + for _ in range(5): + await storage.increment_recovery_attempts(id_high) + + ids = await storage.get_sagas_for_recovery(limit=10, max_recovery_attempts=5) + assert id_low in ids + assert id_high not in ids + + async def test_get_sagas_for_recovery_ordered_by_updated_at( + self, + storage: MemorySagaStorage, + test_context: dict[str, str], + ) -> None: + """Positive: result ordered by updated_at ascending (oldest first).""" + id1, id2, id3 = uuid.uuid4(), uuid.uuid4(), uuid.uuid4() + for sid in (id1, id2, id3): + await storage.create_saga(saga_id=sid, name="saga", context=test_context) + await storage.update_status(sid, SagaStatus.RUNNING) + # touch id2 so its updated_at is latest + await storage.update_context(id2, {**test_context, "touched": True}) + + ids = await storage.get_sagas_for_recovery(limit=10) + assert len(ids) == 3 + # id2 was updated last, so should be last in list (oldest first) + assert ids[-1] == id2 + + async def test_get_sagas_for_recovery_stale_after_excludes_recently_updated( + self, + storage: MemorySagaStorage, + test_context: dict[str, str], + ) -> None: + """Positive: with stale_after_seconds, recently updated sagas are excluded.""" + id_recent = uuid.uuid4() + await storage.create_saga(saga_id=id_recent, name="saga", context=test_context) + await storage.update_status(id_recent, SagaStatus.RUNNING) + # No manual change to updated_at: it was just updated + ids = await storage.get_sagas_for_recovery( + limit=10, + stale_after_seconds=60, + ) + assert id_recent not in ids + + async def test_get_sagas_for_recovery_stale_after_includes_old_updated( + self, + storage: MemorySagaStorage, + test_context: dict[str, str], + ) -> None: + """Positive: with stale_after_seconds, sagas with old updated_at are included.""" + id_old = uuid.uuid4() + await storage.create_saga(saga_id=id_old, name="saga", context=test_context) + await storage.update_status(id_old, SagaStatus.RUNNING) + storage._sagas[id_old]["updated_at"] = datetime.datetime.now( + datetime.timezone.utc, + ) - datetime.timedelta(seconds=120) + ids = await storage.get_sagas_for_recovery( + limit=10, + stale_after_seconds=60, + ) + assert id_old in ids + + async def test_get_sagas_for_recovery_without_stale_after_unchanged_behavior( + 
self, + storage: MemorySagaStorage, + test_context: dict[str, str], + ) -> None: + """Backward compat: without stale_after_seconds, recently updated sagas are included.""" + sid = uuid.uuid4() + await storage.create_saga(saga_id=sid, name="saga", context=test_context) + await storage.update_status(sid, SagaStatus.RUNNING) + ids = await storage.get_sagas_for_recovery(limit=10) + assert sid in ids + + # --- get_sagas_for_recovery: negative --- + + async def test_get_sagas_for_recovery_empty_when_none_recoverable( + self, + storage: MemorySagaStorage, + test_context: dict[str, str], + ) -> None: + """Negative: returns empty list when no recoverable sagas.""" + sid = uuid.uuid4() + await storage.create_saga(saga_id=sid, name="saga", context=test_context) + # PENDING and COMPLETED are not recoverable + await storage.update_status(sid, SagaStatus.COMPLETED) + + ids = await storage.get_sagas_for_recovery(limit=10) + assert ids == [] + + async def test_get_sagas_for_recovery_excludes_pending_and_completed( + self, + storage: MemorySagaStorage, + test_context: dict[str, str], + ) -> None: + """Negative: PENDING and COMPLETED sagas are not returned.""" + id_pending = uuid.uuid4() + id_completed = uuid.uuid4() + await storage.create_saga(saga_id=id_pending, name="saga", context=test_context) + await storage.create_saga( + saga_id=id_completed, + name="saga", + context=test_context, + ) + await storage.update_status(id_completed, SagaStatus.COMPLETED) + + ids = await storage.get_sagas_for_recovery(limit=10) + assert id_pending not in ids + assert id_completed not in ids + + # --- increment_recovery_attempts: positive --- + + async def test_increment_recovery_attempts_increments_counter( + self, + storage: MemorySagaStorage, + saga_id: uuid.UUID, + test_context: dict[str, str], + ) -> None: + """Positive: recovery_attempts increases by 1 each call.""" + await storage.create_saga(saga_id=saga_id, name="saga", context=test_context) + await storage.update_status(saga_id, SagaStatus.RUNNING) + + await storage.increment_recovery_attempts(saga_id) + _, ctx, ver = await storage.load_saga_state(saga_id) + assert storage._sagas[saga_id]["recovery_attempts"] == 1 + + await storage.increment_recovery_attempts(saga_id) + assert storage._sagas[saga_id]["recovery_attempts"] == 2 + + async def test_increment_recovery_attempts_updates_updated_at( + self, + storage: MemorySagaStorage, + saga_id: uuid.UUID, + test_context: dict[str, str], + ) -> None: + """Positive: updated_at is set to now.""" + await storage.create_saga(saga_id=saga_id, name="saga", context=test_context) + await storage.update_status(saga_id, SagaStatus.RUNNING) + before = storage._sagas[saga_id]["updated_at"] + + await storage.increment_recovery_attempts(saga_id) + after = storage._sagas[saga_id]["updated_at"] + assert after >= before + + async def test_increment_recovery_attempts_with_new_status( + self, + storage: MemorySagaStorage, + saga_id: uuid.UUID, + test_context: dict[str, str], + ) -> None: + """Positive: optional new_status updates saga status.""" + await storage.create_saga(saga_id=saga_id, name="saga", context=test_context) + await storage.update_status(saga_id, SagaStatus.RUNNING) + + await storage.increment_recovery_attempts(saga_id, new_status=SagaStatus.FAILED) + status, _, _ = await storage.load_saga_state(saga_id) + assert status == SagaStatus.FAILED + + # --- increment_recovery_attempts: negative --- + + async def test_increment_recovery_attempts_raises_when_saga_not_found( + self, + storage: MemorySagaStorage, + ) -> None: + 
"""Negative: raises ValueError when saga_id does not exist.""" + unknown_id = uuid.uuid4() + with pytest.raises(ValueError, match="not found"): + await storage.increment_recovery_attempts(unknown_id) diff --git a/tests/integration/test_saga_storage_sqlalchemy.py b/tests/integration/test_saga_storage_sqlalchemy.py index a4d92c6..124a6ec 100644 --- a/tests/integration/test_saga_storage_sqlalchemy.py +++ b/tests/integration/test_saga_storage_sqlalchemy.py @@ -1,13 +1,20 @@ """Integration tests for SqlAlchemySagaStorage.""" +import asyncio import uuid +from collections.abc import AsyncGenerator import pytest +from sqlalchemy import delete from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from cqrs.dispatcher.exceptions import SagaConcurrencyError from cqrs.saga.storage.enums import SagaStatus, SagaStepStatus -from cqrs.saga.storage.sqlalchemy import SqlAlchemySagaStorage +from cqrs.saga.storage.sqlalchemy import ( + SagaExecutionModel, + SagaLogModel, + SqlAlchemySagaStorage, +) # Fixtures init_saga_orm and saga_session_factory are imported from tests/integration/fixtures.py @@ -266,3 +273,203 @@ async def test_optimistic_locking( _, final_context, final_version = await storage.load_saga_state(saga_id) assert final_context == new_context assert final_version == 2 + + +class TestRecoverySqlAlchemy: + """Integration tests for get_sagas_for_recovery and increment_recovery_attempts (SqlAlchemy).""" + + @pytest.fixture(autouse=True) + async def _clean_saga_tables( + self, + saga_session_factory: async_sessionmaker[AsyncSession], + ) -> AsyncGenerator[None, None]: + """Clear saga tables before each test so get_sagas_for_recovery sees only this test's data.""" + async with saga_session_factory() as session: + await session.execute(delete(SagaLogModel)) + await session.execute(delete(SagaExecutionModel)) + await session.commit() + yield + + # --- get_sagas_for_recovery: positive --- + + async def test_get_sagas_for_recovery_returns_recoverable_sagas( + self, + storage: SqlAlchemySagaStorage, + test_context: dict[str, str], + ) -> None: + """Positive: returns RUNNING, COMPENSATING, FAILED sagas only.""" + id1, id2, id3 = uuid.uuid4(), uuid.uuid4(), uuid.uuid4() + for sid in (id1, id2, id3): + await storage.create_saga(saga_id=sid, name="saga", context=test_context) + await storage.update_status(id1, SagaStatus.RUNNING) + await storage.update_status(id2, SagaStatus.COMPENSATING) + await storage.update_status(id3, SagaStatus.FAILED) + + ids = await storage.get_sagas_for_recovery(limit=10) + assert set(ids) == {id1, id2, id3} + assert len(ids) == 3 + + async def test_get_sagas_for_recovery_respects_limit( + self, + storage: SqlAlchemySagaStorage, + test_context: dict[str, str], + ) -> None: + """Positive: returns at most `limit` saga IDs.""" + for _ in range(5): + sid = uuid.uuid4() + await storage.create_saga(saga_id=sid, name="saga", context=test_context) + await storage.update_status(sid, SagaStatus.RUNNING) + + ids = await storage.get_sagas_for_recovery(limit=2) + assert len(ids) == 2 + + async def test_get_sagas_for_recovery_respects_max_recovery_attempts( + self, + storage: SqlAlchemySagaStorage, + test_context: dict[str, str], + ) -> None: + """Positive: only returns sagas with recovery_attempts < max_recovery_attempts.""" + id_low = uuid.uuid4() + id_high = uuid.uuid4() + await storage.create_saga(saga_id=id_low, name="saga", context=test_context) + await storage.create_saga(saga_id=id_high, name="saga", context=test_context) + await storage.update_status(id_low, 
SagaStatus.RUNNING) + await storage.update_status(id_high, SagaStatus.RUNNING) + for _ in range(5): + await storage.increment_recovery_attempts(id_high) + + ids = await storage.get_sagas_for_recovery(limit=10, max_recovery_attempts=5) + assert id_low in ids + assert id_high not in ids + + async def test_get_sagas_for_recovery_ordered_by_updated_at( + self, + storage: SqlAlchemySagaStorage, + test_context: dict[str, str], + ) -> None: + """Positive: result ordered by updated_at ascending (oldest first).""" + id1, id2, id3 = uuid.uuid4(), uuid.uuid4(), uuid.uuid4() + for sid in (id1, id2, id3): + await storage.create_saga(saga_id=sid, name="saga", context=test_context) + await storage.update_status(sid, SagaStatus.RUNNING) + # Ensure id2 has a strictly later updated_at (DB may use second precision). + await asyncio.sleep(1.0) + await storage.update_context(id2, {**test_context, "touched": True}) + + ids = await storage.get_sagas_for_recovery(limit=10) + assert len(ids) == 3 + assert ids[-1] == id2 + + async def test_get_sagas_for_recovery_stale_after_excludes_recently_updated( + self, + storage: SqlAlchemySagaStorage, + test_context: dict[str, str], + ) -> None: + """Positive: with stale_after_seconds, recently updated sagas are excluded.""" + id_recent = uuid.uuid4() + await storage.create_saga(saga_id=id_recent, name="saga", context=test_context) + await storage.update_status(id_recent, SagaStatus.RUNNING) + ids = await storage.get_sagas_for_recovery( + limit=10, + stale_after_seconds=999999, + ) + assert id_recent not in ids + + async def test_get_sagas_for_recovery_without_stale_after_unchanged_behavior( + self, + storage: SqlAlchemySagaStorage, + test_context: dict[str, str], + ) -> None: + """Backward compat: without stale_after_seconds, recently updated sagas are included.""" + sid = uuid.uuid4() + await storage.create_saga(saga_id=sid, name="saga", context=test_context) + await storage.update_status(sid, SagaStatus.RUNNING) + ids = await storage.get_sagas_for_recovery(limit=10) + assert sid in ids + + # --- get_sagas_for_recovery: negative --- + + async def test_get_sagas_for_recovery_empty_when_none_recoverable( + self, + storage: SqlAlchemySagaStorage, + test_context: dict[str, str], + ) -> None: + """Negative: returns empty list when no recoverable sagas.""" + sid = uuid.uuid4() + await storage.create_saga(saga_id=sid, name="saga", context=test_context) + await storage.update_status(sid, SagaStatus.COMPLETED) + + ids = await storage.get_sagas_for_recovery(limit=10) + assert ids == [] + + async def test_get_sagas_for_recovery_excludes_pending_and_completed( + self, + storage: SqlAlchemySagaStorage, + test_context: dict[str, str], + ) -> None: + """Negative: PENDING and COMPLETED sagas are not returned.""" + id_pending = uuid.uuid4() + id_completed = uuid.uuid4() + await storage.create_saga(saga_id=id_pending, name="saga", context=test_context) + await storage.create_saga( + saga_id=id_completed, + name="saga", + context=test_context, + ) + await storage.update_status(id_completed, SagaStatus.COMPLETED) + + ids = await storage.get_sagas_for_recovery(limit=10) + assert id_pending not in ids + assert id_completed not in ids + + # --- increment_recovery_attempts: positive --- + + async def test_increment_recovery_attempts_increments_counter( + self, + storage: SqlAlchemySagaStorage, + saga_id: uuid.UUID, + test_context: dict[str, str], + ) -> None: + """Positive: recovery_attempts increases; saga drops out after max_recovery_attempts.""" + await storage.create_saga(saga_id=saga_id, 
name="saga", context=test_context) + await storage.update_status(saga_id, SagaStatus.RUNNING) + + ids_before = await storage.get_sagas_for_recovery( + limit=10, + max_recovery_attempts=5, + ) + assert saga_id in ids_before + + for _ in range(5): + await storage.increment_recovery_attempts(saga_id) + + ids_after = await storage.get_sagas_for_recovery( + limit=10, + max_recovery_attempts=5, + ) + assert saga_id not in ids_after + + async def test_increment_recovery_attempts_with_new_status( + self, + storage: SqlAlchemySagaStorage, + saga_id: uuid.UUID, + test_context: dict[str, str], + ) -> None: + """Positive: optional new_status updates saga status.""" + await storage.create_saga(saga_id=saga_id, name="saga", context=test_context) + await storage.update_status(saga_id, SagaStatus.RUNNING) + + await storage.increment_recovery_attempts(saga_id, new_status=SagaStatus.FAILED) + status, _, _ = await storage.load_saga_state(saga_id) + assert status == SagaStatus.FAILED + + # --- increment_recovery_attempts: negative --- + + async def test_increment_recovery_attempts_raises_when_saga_not_found( + self, + storage: SqlAlchemySagaStorage, + ) -> None: + """Negative: raises ValueError when saga_id does not exist.""" + unknown_id = uuid.uuid4() + with pytest.raises(ValueError, match="not found"): + await storage.increment_recovery_attempts(unknown_id)