Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ maintainers = [{name = "Vadim Kozyrevskiy", email = "vadikko2@mail.ru"}]
name = "python-cqrs"
readme = "README.md"
requires-python = ">=3.10"
version = "4.6.4"
version = "4.6.5"

[project.optional-dependencies]
aiobreaker = ["aiobreaker>=0.3.0"]
Expand Down
13 changes: 7 additions & 6 deletions src/cqrs/saga/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import typing
from dataclass_wizard import asdict, fromdict

# Type variable for from_dict classmethod return type
_T = typing.TypeVar("_T", bound="SagaContext")
Expand Down Expand Up @@ -29,7 +30,7 @@ def to_dict(self) -> dict[str, typing.Any]:
Returns:
Dictionary representation of the context.
"""
return dataclasses.asdict(self)
return asdict(self)

@classmethod
def from_dict(cls: type[_T], data: dict[str, typing.Any]) -> _T:
Expand All @@ -42,11 +43,11 @@ def from_dict(cls: type[_T], data: dict[str, typing.Any]) -> _T:
Returns:
Instance of the context class.
"""
# Get field names from dataclass
field_names = {f.name for f in dataclasses.fields(cls)}
# Filter data to only include known fields
filtered_data = {k: v for k, v in data.items() if k in field_names}
return cls(**filtered_data)
# # Get field names from dataclass
# field_names = {f.name for f in dataclasses.fields(cls)}
# # Filter data to only include known fields
# filtered_data = {k: v for k, v in data.items() if k in field_names}
return fromdict(cls, data)

def model_dump(self) -> dict[str, typing.Any]:
"""
Expand Down
2 changes: 2 additions & 0 deletions src/cqrs/saga/saga.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ async def __aiter__(

yield step_result

# Update context one final time before marking as completed
await self._state_manager.update_context(self._context)
await self._state_manager.update_status(SagaStatus.COMPLETED)

except Exception as e:
Expand Down
8 changes: 5 additions & 3 deletions tests/unit/test_saga/test_saga_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,10 @@ class TestSaga(Saga[OrderContext]):

saga = TestSaga()

# Should raise TypeError when required fields are missing
with pytest.raises(TypeError):
# Should raise MissingFields when required fields are missing
from dataclass_wizard.errors import MissingFields

with pytest.raises(MissingFields):
await recover_saga(saga, saga_id, OrderContext, saga_container, storage)

assert not reserve_step.act_called
Expand Down Expand Up @@ -485,7 +487,7 @@ class TestSaga(Saga[OrderContext]):
# Verify context was updated in storage
status, updated_context, _ = await storage.load_saga_state(saga_id)
assert status == SagaStatus.COMPLETED
assert updated_context["order_id"] == "123"
assert updated_context["orderId"] == "123"


async def test_recover_saga_with_mock_storage_exception(
Expand Down