diff --git a/pyproject.toml b/pyproject.toml index 9de1a16..fd63210 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ maintainers = [{name = "Vadim Kozyrevskiy", email = "vadikko2@mail.ru"}] name = "python-cqrs" readme = "README.md" requires-python = ">=3.10" -version = "4.6.4" +version = "4.6.5" [project.optional-dependencies] aiobreaker = ["aiobreaker>=0.3.0"] diff --git a/src/cqrs/saga/models.py b/src/cqrs/saga/models.py index a41e8f6..01335a9 100644 --- a/src/cqrs/saga/models.py +++ b/src/cqrs/saga/models.py @@ -1,5 +1,6 @@ import dataclasses import typing +from dataclass_wizard import asdict, fromdict # Type variable for from_dict classmethod return type _T = typing.TypeVar("_T", bound="SagaContext") @@ -29,7 +30,7 @@ def to_dict(self) -> dict[str, typing.Any]: Returns: Dictionary representation of the context. """ - return dataclasses.asdict(self) + return asdict(self) @classmethod def from_dict(cls: type[_T], data: dict[str, typing.Any]) -> _T: @@ -42,11 +43,11 @@ def from_dict(cls: type[_T], data: dict[str, typing.Any]) -> _T: Returns: Instance of the context class. """ - # Get field names from dataclass - field_names = {f.name for f in dataclasses.fields(cls)} - # Filter data to only include known fields - filtered_data = {k: v for k, v in data.items() if k in field_names} - return cls(**filtered_data) + # # Get field names from dataclass + # field_names = {f.name for f in dataclasses.fields(cls)} + # # Filter data to only include known fields + # filtered_data = {k: v for k, v in data.items() if k in field_names} + return fromdict(cls, data) def model_dump(self) -> dict[str, typing.Any]: """ diff --git a/src/cqrs/saga/saga.py b/src/cqrs/saga/saga.py index f2e0c41..006efe4 100644 --- a/src/cqrs/saga/saga.py +++ b/src/cqrs/saga/saga.py @@ -302,6 +302,8 @@ async def __aiter__( yield step_result + # Update context one final time before marking as completed + await self._state_manager.update_context(self._context) await self._state_manager.update_status(SagaStatus.COMPLETED) except Exception as e: diff --git a/tests/unit/test_saga/test_saga_recovery.py b/tests/unit/test_saga/test_saga_recovery.py index b33b48a..07f889c 100644 --- a/tests/unit/test_saga/test_saga_recovery.py +++ b/tests/unit/test_saga/test_saga_recovery.py @@ -272,8 +272,10 @@ class TestSaga(Saga[OrderContext]): saga = TestSaga() - # Should raise TypeError when required fields are missing - with pytest.raises(TypeError): + # Should raise MissingFields when required fields are missing + from dataclass_wizard.errors import MissingFields + + with pytest.raises(MissingFields): await recover_saga(saga, saga_id, OrderContext, saga_container, storage) assert not reserve_step.act_called @@ -485,7 +487,7 @@ class TestSaga(Saga[OrderContext]): # Verify context was updated in storage status, updated_context, _ = await storage.load_saga_state(saga_id) assert status == SagaStatus.COMPLETED - assert updated_context["order_id"] == "123" + assert updated_context["orderId"] == "123" async def test_recover_saga_with_mock_storage_exception(