diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_sign_in_state.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_sign_in_state.py index ddd22d90..bd16530a 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_sign_in_state.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_sign_in_state.py @@ -7,35 +7,27 @@ from typing import Optional +from pydantic import BaseModel + from microsoft_agents.activity import Activity from ...storage._type_aliases import JSON from ...storage import StoreItem -class _SignInState(StoreItem): +class _SignInState(BaseModel, StoreItem): """Store item for sign-in state, including tokens and continuation activity. Used to cache tokens and keep track of activities during single and multi-turn sign-in flows. """ - def __init__( - self, - active_handler_id: str, - continuation_activity: Optional[Activity] = None, - ) -> None: - self.active_handler_id = active_handler_id - self.continuation_activity = continuation_activity + active_handler_id: str + continuation_activity: Optional[Activity] = None def store_item_to_json(self) -> JSON: - return { - "active_handler_id": self.active_handler_id, - "continuation_activity": self.continuation_activity, - } + return self.model_dump(mode="json", exclude_unset=True, by_alias=True) @staticmethod def from_json_to_store_item(json_data: JSON) -> _SignInState: - return _SignInState( - json_data["active_handler_id"], json_data.get("continuation_activity") - ) + return _SignInState.model_validate(json_data) diff --git a/tests/hosting_core/app/_oauth/_common.py b/tests/hosting_core/app/_oauth/_common.py index c2a6d2f0..4d3fea34 100644 --- a/tests/hosting_core/app/_oauth/_common.py +++ b/tests/hosting_core/app/_oauth/_common.py @@ -29,9 +29,11 @@ def create_testing_TurnContext( turn_context = mocker.Mock() if not activity: - turn_context.activity.channel_id = channel_id - turn_context.activity.from_property.id = user_id - turn_context.activity.type = ActivityTypes.message + turn_context.activity = Activity( + type=ActivityTypes.message, + channel_id=channel_id, + from_property={"id": user_id}, + ) else: turn_context.activity = activity turn_context.adapter.USER_TOKEN_CLIENT_KEY = "__user_token_client" diff --git a/tests/hosting_core/app/_oauth/test_sign_in_state.py b/tests/hosting_core/app/_oauth/test_sign_in_state.py new file mode 100644 index 00000000..d3bce6ec --- /dev/null +++ b/tests/hosting_core/app/_oauth/test_sign_in_state.py @@ -0,0 +1,43 @@ +import json + +from microsoft_agents.activity import Activity +from microsoft_agents.hosting.core.app.oauth._sign_in_state import _SignInState + + +def test_sign_in_state_serialization_deserialization(tmp_path): + original_state = _SignInState( + active_handler_id="handler_123", + continuation_activity=Activity( + type="message", + id="activity_456", + timestamp="2024-01-01T12:00:00Z", + service_url="https://service.url", + channel_id="channel_789", + from_property={"id": "user_1"}, + conversation={"id": "conv_1"}, + recipient={"id": "bot_1"}, + text="Hello, World!", + ), + ) + + # Serialize to JSON + json_data = original_state.store_item_to_json() + + # Deserialize back to _SignInState + deserialized_state = _SignInState.from_json_to_store_item(json_data) + + # Assert equality + assert deserialized_state.active_handler_id == original_state.active_handler_id + assert ( + deserialized_state.continuation_activity == original_state.continuation_activity + ) + + with open(tmp_path / "sign_in_state.json", "w") as f: + json.dump(json_data, f) + + with open(tmp_path / "sign_in_state.json", "r") as f: + loaded_json_data = json.load(f) + + loaded_state = _SignInState.from_json_to_store_item(loaded_json_data) + assert loaded_state.active_handler_id == original_state.active_handler_id + assert loaded_state.continuation_activity == original_state.continuation_activity