From e969eb640833169c2722c112a721334ea400d822 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Tue, 12 Aug 2025 13:52:57 -0700 Subject: [PATCH 01/32] Changed default cloud for ConnectionSettings to be PROD --- .../agents/copilotstudio/client/connection_settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libraries/microsoft-agents-copilotstudio-client/microsoft/agents/copilotstudio/client/connection_settings.py b/libraries/microsoft-agents-copilotstudio-client/microsoft/agents/copilotstudio/client/connection_settings.py index 6e8bd61d..6b8e6bc7 100644 --- a/libraries/microsoft-agents-copilotstudio-client/microsoft/agents/copilotstudio/client/connection_settings.py +++ b/libraries/microsoft-agents-copilotstudio-client/microsoft/agents/copilotstudio/client/connection_settings.py @@ -27,6 +27,6 @@ def __init__( if not self.agent_identifier: raise ValueError("Agent Identifier must be provided") - self.cloud = cloud or PowerPlatformCloud.UNKNOWN + self.cloud = cloud or PowerPlatformCloud.PROD self.copilot_agent_type = copilot_agent_type or AgentType.PUBLISHED self.custom_power_platform_cloud = custom_power_platform_cloud From 2c7cbb297b26d9fc20a20717ef26c1096e64b6ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Tue, 12 Aug 2025 17:07:13 -0700 Subject: [PATCH 02/32] New draft implementations for classes to rework authorization --- .../hosting/core/app/oauth/sign_in_context.py | 208 ++++++++++++++++++ .../hosting/core/app/oauth/sign_in_storage.py | 72 ++++++ 2 files changed, 280 insertions(+) create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_context.py create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_storage.py diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_context.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_context.py new file mode 100644 index 00000000..7174b189 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_context.py @@ -0,0 +1,208 @@ +import logging +from typing import Optional, Callable + +from .sign_in_storage import SignInStorage, SignInHandlerState, SignInHandlerStateStatus, FlowState + +logger = logging.getLogger(__name__) + +ms_agents_logger = logging.getLogger("microsoft.agents") +handler_formatter = logging.StreamHandler() +console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)")) +ms_agents_logger.addHandler(console_handler) +ms_agents_logger.setLevel(logging.INFO) + + + +class SignInContext: + + logger = logging.getLogger(f"{__name__}.SignInContext") # robrandao: TODO get logger with config + + def __init__(self, + storage: SignInStorage, + auth_handlers: AuthHandlers, + context: TurnContext, + handler_id: str = "", + is_started_from_route: bool = True): + + if not is_started_from_route and not handler_id: + raise ValueError("handler_id must be provided when is_started_from_route is False.") + + if not hasattr(context, "activity"): # robrandao: TODO -> see extra condition in JS code + raise ValueError("context must have an activity property.") + + # robrandao: TODO -> is this necessary here, or can we do this outside? Can we make the storage outside too? + # if self.is_started_from_route: + + # robrandao: TODO type signature + def on_success(self, handler: Callable) -> None: + self.__on_success_handler = handler + + def on_failure(self, handler: Callable) -> None: + self.__on_failure_handler = handler + + async def get_token(self) -> Optional[TokenResponse]: + + if not await self.load_handler(): + return TokenResponse() + + self.logger.info("Getting token from user token service.") + return self.__auth_handler.flow.get_token(self.context) + + async def exchange_token(self, scopes: list[str]) -> TokenResponse: + if not await self.load_handler(): + return TokenResponse() + + self.logger.info("Exchanging token from user token service.") + token_response = await self.__auth_handler.flow.get_token(self.context) + if self.is_exchangeable(token_response.token): + return await self.handle_obo(token_response.token, scopes) + return token_response + + async def sign_out(self) -> None: + if not await self.load_handler(): + return + + self.logger.info("Signing out from the authorization flow.") + if self.is_started_from_route: + await self.storage.handler_delete(self.handler.id) + return self.__auth_handler.flow.sign_out(self.context) + + async def get_token(self) -> Optional[TokenResponse]: + if not await self.load_handler(): + return TokenResponse() + + self.logger.debug("Processing authorization flow.") + self.logger.debug(f"Uses Storage state: {self.is_started_from_route}") + self.logger.debug("Current sign-in state:", self.handler) + + token_response = await self.handler.status( + {} # robrandao: TODO + ) + + self.logger.debug("OAuth flow result: %s", { token: token_response.get(token), state: self.handler}) + return token_response + + DEFAULT_STATES: dict[SignInHandlerStateStatus, FlowState] = { + "begin": lambda: { id: self.handler_id, status: status} + } + + def __set_status(status: SignInHandlerStateStatus) -> None: + + # robrandao: TODO - type + state_builder: dict[str, Callable] = { + SignInHandlerStateStatus.BEGIN: (lambda: { + "id": self.handler_id, + "status": SignInHandlerStateStatus.BEGIN + }), + SignInHandlerStateStatus.CONTINUE: (lambda self: { + **self.handler, + "status": SignInHandlerStateStatus.CONTINUE, + "state": self.flow_state + "continuation_activity": self.context.activity + }), + SignInHandlerStatus.SUCCESS: (lambda: { + **self.handler, + "status": SignInHandlerStateStatus.SUCCESS, + "state": None + }), + SignInHandlerStateStatus.FAILURE: (lambda: { + **self.handler, + "status": SignInHandlerStateStatus.FAILURE, + "state": self.flow_state + }) + } + + self.__handler = state_builder[status]() + return self.__handler # robrandao: TODO ??? + + async def __load_handler(self) -> bool: + if self.is_started_from_route: + if self.handler_id: + self.__handler = await self.storage.handler_get(self.handler_id) + else: + self.__handler = await self.storage.handler_active() + + if not self.__handler: + # robrandao: TODO renaming? + self.__handler = self.__set_status(SignInHandlerStateStatus.NOT_STARTED, None) + + if not self.handler.id: + return False + + self.__auth_handler = self.get_auth_handler_or_throw(self.handler.id) + + # robrandao: TODO + if not self.is_started_from_route and self.flow_state.flow_started: + self.set_status(SignInHandlerStateStatus.SUCCESS) + self.logger.debug("OAuth flow success, using existing state.") + return True + + if self.handler.status == SignInHandlerStateStatus.BEGIN: + self.logger.debug("No active flow state, starting a new OAuth flow.") + await self.__auth_handler.flow.sign_out(self.context) + else: + await self.__auth_handler.flow.set_flow_state(self.context, self.handler.state or FlowState()) # robrandao: TODO + + return True + + async def begin(self) -> None: + self.logger.debug("Beginning OAuth flow.") + await self.__auth_handler.flow.begin_flow(self.context) + self.logger.debug("OAuth flow started, waiting on continuation...") + self.__set_status(SignInHandlerStateStatus.CONTINUE) + if self.is_started_from_route: + await self.storage.handler_set(self.handler) + + async def continue(self) -> Optional[TokenResponse]: + + self.logger.debug("Continuing OAuth flow.") + + token_response = await self.__auth_handler.flow.continue_flow(self.context) + if token_response.token: + self.set_status(SignInHandlerStateStatus.SUCCESS) + self.logger.debug("OAuth flow success.") + if self.is_started_from_route: + await self.storage.handler_set(self.handler) + if self.__on_success_handler: + await self.__on_success_handler() + + else: + await self.failure() + + return token_response + + async def success(self) -> Optional[TokenResponse]: + token_response = await self.__auth_handler.flow.get_token(self.context) + # robrandao: TODO -> JS always strips() the token? + if self.is_started_from_route and token_response.token: + self.logger.debug("OAuth flow success, retrieving token.") + return token_response + else: + self.logger.debug("OAuth flow token not available, waiting on continuation...") + return self.continue() + + async def failure(self) -> None: + self.__set_status(SignInHandlerStateStatus.FAILURE) + + # TODO + + async def __is_exchangeable(self, token: Optional[str]) -> bool: + if not token or not isinstance(token, str): # robrandao: TODO ??? + return False + + payload = JwtToken.parse(token).payload # robrandao: TODO + return payload.aud.index("api://") == 0 + + async def __handle_obo(self, token: str, scopes: list[str]) -> TokenResponse: + msal_token_provider = MsalTokenProvider() + + auth_config = self.context.adapter.auth_config + if self.__auth_handler.cnx_prefix: + auth_config = load_auth_config_from_env(self.__auth_handler.cnx_prefix) + + new_token = await msal_token_provider.get_on_behalf_of_token(auth_config, scopes, token) + return TokenResponse(token) + + async def __get_auth_handler_or_throw(handler_id: str) -> AuthHandler: + # robrandao: TODO + pass \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_storage.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_storage.py new file mode 100644 index 00000000..ae64443e --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_storage.py @@ -0,0 +1,72 @@ +from enum import Enum +from typing import Optional + +from pydantic import BaseModel + +from microsoft.agents.activity import Activity +from microsoft.agents.hosting.core import StoreItem + +class FlowState: + pass + +class SignInHandlerStateStatus(Enum): + NOT_STARTED = "not_started" + CONTINUE = "in_progress" + COMPLETED = "completed" + FAILURE = "failure" + + +class SignInHandlerState(BaseModel, StoreItem): + id: str + status: SignInHandlerStateStatus + continuation_activity: Activity + + def store_item_to_json(self): # todo + return super().store_item_to_json() + + @staticmethod + def from_json_to_store_item(json_data): # todo + return super().from_json_to_store_item(json_data) + + +class SignInStorage: + + def __init__(self, context: TurnContext, storage: Storage, handlers: Optional[list[AuthHandler]] = None): + + if (not context.activity or + not context.activity.channel_id or + not context.activity.from_property or + not context.activity.from_property.id): + + raise ValueError("context.activity -> channel_id and from.id must be set.") + + channel_id = context.activity.channel_id + user_id = context.activity.from_property.id + + self.__base_key = f"auth/{channel_id}/{user_id}" + self.__handlers = handlers or [] + self.__handler_keys = list(map(create_key, self.__handlers)) + self.__storage = storage + + def create_key(self, id: str) -> str: + if not self.__base_key: + raise AttributeError # robrandao: TODO + return f"{self.__base_key}/${id}" + + async def active(self) -> Optional[SignInHandlerState]: + # batched reads would make this more efficient + for handler_key in self.__handler_keys: + state = (await self.__storage.read([handler_key], SignInHandlerState)).get(handler_key) + if state and state.status == SignInHandlerStateStatus.IN_PROGRESS: + return state + + async def get(self, id: str) -> Optional[SignInHandlerState]: + key = self.create_key(id) + data = await self.__storage.read([key], SignInHandlerState) + return data.get(key) # robrandao: TODO -> verify contract + + async def set(self, value: SignInHandlerState) -> None: + await self.__storage.write({ self.create_key(value.id): value}) + + async def delete(self, id: str) -> None: + await self.__storage.delete([self.create_key(id)]) \ No newline at end of file From a5473783b698169a2507dc50fecb6f8fea9f68c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Thu, 14 Aug 2025 15:48:59 -0700 Subject: [PATCH 03/32] Revising oauth refactor with new proposed architecture --- .../hosting/core/app/agent_application.py | 119 ++++++--- .../{sign_in_context.py => auth_context.py} | 22 +- .../hosting/core/app/oauth/auth_flow.py | 197 ++++++++++++++ .../hosting/core/app/oauth/auth_handler.py | 39 +++ .../hosting/core/app/oauth/authorization.py | 248 ++++++------------ .../core/app/oauth/flow_storage_client.py | 57 ++++ .../agents/hosting/core/app/oauth/models.py | 57 ++++ .../hosting/core/app/oauth/sign_in_storage.py | 72 ----- 8 files changed, 537 insertions(+), 274 deletions(-) rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/{sign_in_context.py => auth_context.py} (93%) create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py delete mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_storage.py diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py index 142f311c..d7f77d0c 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py @@ -42,8 +42,7 @@ from .route import Route, RouteHandler from .state import TurnState from ..channel_service_adapter import ChannelServiceAdapter -from .oauth import Authorization, SignInState -from .typing_indicator import TypingIndicator +from .oauth import Authorization, SignInState, FlowResponse, FlowStateTag logger = logging.getLogger(__name__) @@ -591,6 +590,87 @@ def turn_state_factory(self, func: Callable[[TurnContext], Awaitable[StateT]]): logger.debug(f"Setting custom turn state factory: {func.__name__}") self._turn_state_factory = func return func + + async def _handle_flow_response(self, context: TurnContext, flow_response: FlowResponse) -> None: + + flow_state: FlowState = flow_response.flow_state + in_flow_activity = flow_response.in_flow_activity + + if in_flow_activity: + context.send_activity(in_flow_activity) + + if flow_state.tag == FlowStateTag.BEGIN: + # Create the OAuth card + o_card: Attachment = CardFactory.oauth_card( + OAuthCard( + text=self.messages_configuration.get("card_title", "Sign in"), + connection_name=self.abs_oauth_connection_name, + buttons=[ + CardAction( + title=self.messages_configuration.get("button_text", "Sign in"), + type=ActionTypes.signin, + value=signing_resource.sign_in_link, + channel_data=None, + ) + ], + token_exchange_resource=signing_resource.token_exchange_resource, + token_post_resource=signing_resource.token_post_resource, + ) + ) + + # Send the card to the user + await context.send_activity(MessageFactory.attachment(o_card)) + elif flow_state.tag == FlowStateTag.FAILURE: + if flow_state.reached_max_retries(): + await context.send_activity( + MessageFactory.text( + self.messages_configuration.get( + "max_retries_reached_messages", + "Sign-in failed. Please try again later.", + ) + ) + ) + elif flow_state.is_expired(): + await context.send_activity( + MessageFactory.text( + self.messages_configuration.get( + "session_expired_messages", + "Sign-in session expired. Please try again.", + ) + )) + else: + logger.warning("Sign-in flow failed for unknown reasons.") + await context.send_activity("Sign-in failed. Please try again.") + + async def _on_turn_auth_intercept(self, context: TurnContext, turn_state) -> bool: + + prev_flow_state = self._auth.get_active_flow_state(context) + if self._auth and prev_flow_state: + + logger.debug("Sign-in flow is active for context: %s", context.activity.id) + + flow_response: FlowResponse = await self._auth.begin_or_continue_flow( + context, turn_state, prev_flow_state.handler_id + ) + + self._handle_flow_response(flow_response) + + new_flow_state: FlowState = flow_response.flow_state + token_response: TokenResponse = new_flow_state.token_response + saved_activity: Activity = new_flow_state.continuation_activity.model_copy() + + if token_response and token_response.token: + new_context = copy(context) + new_context.activity = saved_activity + logger.info( + "Resending continuation activity %s", saved_activity.text + ) + await self.on_turn(new_context) + turn_state.delete_value(Authorization.SIGN_IN_STATE_KEY) + await turn_state.save(context) + return True + + return False async def on_turn(self, context: TurnContext): logger.debug( @@ -599,6 +679,7 @@ async def on_turn(self, context: TurnContext): await self._start_long_running_call(context, self._on_turn) async def _on_turn(self, context: TurnContext): + # robrandao: TODO try: await self._start_typing(context) @@ -607,32 +688,8 @@ async def _on_turn(self, context: TurnContext): logger.debug("Initializing turn state") turn_state = await self._initialize_state(context) - sign_in_state = turn_state.get_value( - Authorization.SIGN_IN_STATE_KEY, target_cls=SignInState - ) - logger.debug( - f"Sign-in state: {sign_in_state} for context: {context.activity.id}" - ) - - if self._auth and sign_in_state and not sign_in_state.completed: - flow_state = self._auth.get_flow_state(sign_in_state.handler_id) - logger.debug("Flow state: %s", flow_state) - if flow_state.flow_started: - logger.debug("Continuing sign-in flow") - token_response = await self._auth.begin_or_continue_flow( - context, turn_state, sign_in_state.handler_id - ) - saved_activity = sign_in_state.continuation_activity.model_copy() - if token_response and token_response.token: - new_context = copy(context) - new_context.activity = saved_activity - logger.info( - "Resending continuation activity %s", saved_activity.text - ) - await self.on_turn(new_context) - turn_state.delete_value(Authorization.SIGN_IN_STATE_KEY) - await turn_state.save(context) - return + if self._on_turn_auth_intercept(context): + return logger.debug("Running before turn middleware") if not await self._run_before_turn_middleware(context, turn_state): @@ -753,12 +810,14 @@ async def _on_activity(self, context: TurnContext, state: StateT): else: sign_in_complete = False for auth_handler_id in route.auth_handlers: - token_response = await self._auth.begin_or_continue_flow( + flow_response: FlowResponse = await self._auth.begin_or_continue_flow( context, state, auth_handler_id ) - sign_in_complete = token_response and token_response.token + self._handle_flow_response(context, flow_response.in_flow_activity) + sign_in_complete = flow_response.flow_state.tag == FlowStateTag.COMPLETE if not sign_in_complete: break + if sign_in_complete: await route.handler(context, state) return diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_context.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_context.py similarity index 93% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_context.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_context.py index 7174b189..566d1190 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_context.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_context.py @@ -11,9 +11,7 @@ ms_agents_logger.addHandler(console_handler) ms_agents_logger.setLevel(logging.INFO) - - -class SignInContext: +class AuthContext: logger = logging.getLogger(f"{__name__}.SignInContext") # robrandao: TODO get logger with config @@ -30,6 +28,12 @@ def __init__(self, if not hasattr(context, "activity"): # robrandao: TODO -> see extra condition in JS code raise ValueError("context must have an activity property.") + self.__storage = AuthStateStorage(context, storage) + + sign_in_storage = SignInStorage( + context, self.__storage, self.__auth_handlers + ) # robrandao: TODO + # robrandao: TODO -> is this necessary here, or can we do this outside? Can we make the storage outside too? # if self.is_started_from_route: @@ -88,6 +92,15 @@ async def get_token(self) -> Optional[TokenResponse]: def __set_status(status: SignInHandlerStateStatus) -> None: + if status == FlowProgression.BEGIN: + self.flow_state = FlowState(id) + elif status == FlowProgression.CONTINUE: + pass + elif status == FlowProgression.SUCCESS: + pass + elif status == FlowProgression.FAILURE: + pass + # robrandao: TODO - type state_builder: dict[str, Callable] = { SignInHandlerStateStatus.BEGIN: (lambda: { @@ -205,4 +218,5 @@ async def __handle_obo(self, token: str, scopes: list[str]) -> TokenResponse: async def __get_auth_handler_or_throw(handler_id: str) -> AuthHandler: # robrandao: TODO - pass \ No newline at end of file + pass + diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py new file mode 100644 index 00000000..27cb7e70 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from __future__ import annotations + +import logging + +from enum import Enum +from datetime import datetime +from typing import Dict, Optional + +from microsoft.agents.hosting.core.connector.client import UserTokenClient +from microsoft.agents.activity import ( + ActionTypes, + ActivityTypes, + CardAction, + Attachment, + OAuthCard, + TokenExchangeState, + TokenResponse, + Activity, +) +from microsoft.agents.activity import ( + TurnContextProtocol as TurnContext, +) +from microsoft.agents.hosting.core.storage import StoreItem, Storage +from pydantic import BaseModel, PositiveInt + +from .message_factory import MessageFactory +from .card_factory import CardFactory +from .models import FlowResponse, FlowState, FlowStateTag + +logger = logging.getLogger(__name__) + + +class AuthFlow: + """ + Manages the OAuth flow. + """ + + def __init__( + self, + flow_state: FlowState = None, + abs_oauth_connection_name: str = None, + user_token_client: Optional[UserTokenClient] = None, + **kwargs + ): + if not abs_oauth_connection_name: + raise ValueError( + "OAuthFlow.__init__: abs_oauth_connection_name required." + ) + + self.flow_state = flow_state or FlowState() + self.__abs_oauth_connection_name = abs_oauth_connection_name + self.__user_token_client = user_token_client + + async def __initialize_token_client(self, context: TurnContext) -> None: + # robrandao: TODO is this safe + # use cached value later + self.__user_token_client = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) + + async def __get_ids_or_raise(self, context: TurnContext) -> TokenResponse: + if ( + not not context.activity.channel_id or + not context.activity.from_property or + not context.activity.from_property.id + ): + raise ValueError("User ID or Channel ID is not set in the activity.") + + return context.activity.channel_id, context.activity.from_property.id + + async def __get_user_token(self, context: TurnContext, magic_code=None) -> TokenResponse: + channel_id, from_id = self.__get_ids_or_raise(context) + await self.__initialize_token_client(context) + + return await self.user_token_client.user_token.get_token( + user_id=from_id, + connection_name=self.__abs_oauth_connection_name, + channel_id=channel_id, + magic_code=magic_code + ) + + async def get_user_token(self, context: TurnContext) -> TokenResponse: + return self.__get_user_token(context) + + async def sign_out(self, context: TurnContext) -> None: + channel_id, from_id = self.__get_ids_or_raise(context) + await self.__initialize_token_client(context) + + return await self.__user_token_client.user_token.get_token( + user_id=from_id, + connection_name=self.__abs_oauth_connection_name, + channel_id=channel_id + ) + + async def __use_attempt(self, context: TurnContext) -> None: + if self.flow_state.attempts_remaining <= 0: + self.flow_state.flow_state_tag = FlowStateTag.FAILURE + + async def __failed_attempt(self, context: TurnContext) -> None: + pass + + async def begin_flow(self, context: TurnContext) -> FlowResponse: + + # init flow state + + token_response = self.get_user_token(context) + if token_response and token_response.token: + pass + + token_exchange_state = TokenExchangeState( + connection_name=self.__abs_oauth_connection_name, + conversation=context.activity.get_conversation_reference(), + relates_to=context.activity.relates_to, + ms_app_id=context.turn_state.get(context.adapter.AGENT_IDENTITY_KEY).claims["aud"] # robrandao: TODO + ) + + sign_in_resource = await self.__user_token_client.agent_sign_in.get_sign_in_resource(state=token_exchange_sate.get_encoded_state()) + + return FlowResponse(flow_state=self.flow_state) + + async def __continue_from_message(self, context: TurnContext) -> None: + + magic_code = activity.text + if magic_code and magic_code.isdigit() and len(magic_code) == 6: + result = self.__get_user_token(context, magic_code) + + if result and result.token: + return result + else: + return InvalidCodeError + else: + return InvalidCodeFormatError + + async def __continue_from_invoke_verify_state(self, context: TurnContext) -> None: + token_verify_sate = context.activity.value + magic_code = token_verify_state.get("state") + result = self.__get_user_token(context, magic_code) + if result and result.token: + pass + return None + + async def __continue_from_invoke_token_exchange(self, context: TurnContext) -> None: + await self.__initialize_token_client(context) + channel_id, from_id = self.__get_ids_or_raise(context) + + token_exchange_request = context.activity.value + token_exchange_id = token_exchange_request.get("id") + + return await self.__user_token_client.user_token.exchange_token( + user_id=context.activity.from_property.id, + connection_name=self.__abs_oauth_connection_name, + channel_id=channel_id, + body=token_exchange_request + ) + + + async def continue_flow(self, context: TurnContext) -> FlowResponse: + if self.flow_state.is_expired() or self.flow_state.reached_max_retries(): + self.flow_state.flow_state_tag = FlowStateTag.FAILURE + return FlowResponse(flow_state=self.flow_state) + + continue_flow_activity = context.activity + + if continue_flow_activity.type == ActivityTypes.message: + token_response, flow_error = continue_flow_from_message() + elif continue_flow_activity.type == ActivityTypes.invoke and continue_flow_activity.name == "signin/verifyState": + token_response, flow_error = continue_flow_from_invoke_verify_state() + elif continue_flow_activity.type == ActivityTypes.invoke and continue_flow_activity.name == "signin/tokenExchange": + token_response, flow_error = continue_flow_from_invoke_token_exchange() + else: + pass + + if flow_error != FlowError.NONE and token_response and token_response.token: + pass + elif flow_error == FlowError.NONE: + flow_error = + + pass + + + async def continue_flow(self, context: TurnContext) -> FlowResponse: + pass + + async def begin_or_continue_flow(self, context: TurnContext) -> FlowResponse: + + tag = self.flow_state.flow_state_tag + + if tag == FlowStateTag.CONTINUE: + self.continue_flow(context) + else: + self.begin_flow(context) + + if tag == FlowStateTag.BEGIN: + pass + elif tag == FlowStateTag.CONTINUE: + pass \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py new file mode 100644 index 00000000..c6d6697e --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py @@ -0,0 +1,39 @@ +class AuthHandler: + """ + Interface defining an authorization handler for OAuth flows. + """ + + def __init__( + self, + name: str = None, + title: str = None, + text: str = None, + abs_oauth_connection_name: str = None, + obo_connection_name: str = None, + **kwargs, + ): + """ + Initializes a new instance of AuthHandler. + + Args: + name: The name of the OAuth connection. + auto: Whether to automatically start the OAuth flow. + title: Title for the OAuth card. + text: Text for the OAuth button. + """ + self.name = name or kwargs.get("NAME") + self.title = title or kwargs.get("TITLE") + self.text = text or kwargs.get("TEXT") + self.abs_oauth_connection_name = abs_oauth_connection_name or kwargs.get( + "AZUREBOTOAUTHCONNECTIONNAME" + ) + self.obo_connection_name = obo_connection_name or kwargs.get( + "OBOCONNECTIONNAME" + ) + self.flow: OAuthFlow = None + logger.debug( + f"AuthHandler initialized: name={self.name}, title={self.title}, text={self.text} abs_connection_name={self.abs_oauth_connection_name} obo_connection_name={self.obo_connection_name}" + ) + +# Type alias for authorization handlers dictionary +AuthorizationHandlers = Dict[str, AuthHandler] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index b49ee253..b9eff281 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -5,6 +5,7 @@ import logging import jwt from typing import Dict, Optional, Callable, Awaitable +from collections.abc import Iterable from microsoft.agents.hosting.core.authorization import ( Connections, @@ -19,76 +20,21 @@ from ...app.state.turn_state import TurnState from ...oauth_flow import OAuthFlow, FlowState from ...state.user_state import UserState +from .sign_in_context import SignInContext +from .auth_handler import AuthHandler, AuthorizationHandlers +from .sign_in_state import SignInState -logger = logging.getLogger(__name__) - - -class SignInState(StoreItem, BaseModel): - """ - Interface defining the sign-in state for OAuth flows. - """ - - continuation_activity: Optional[Activity] = None - handler_id: Optional[str] = None - completed: Optional[bool] = False - - def store_item_to_json(self) -> dict: - return self.model_dump(exclude_unset=True) - - @staticmethod - def from_json_to_store_item(json_data: dict) -> "StoreItem": - return SignInState.model_validate(json_data) - - -class AuthHandler: - """ - Interface defining an authorization handler for OAuth flows. - """ - - def __init__( - self, - name: str = None, - title: str = None, - text: str = None, - abs_oauth_connection_name: str = None, - obo_connection_name: str = None, - **kwargs, - ): - """ - Initializes a new instance of AuthHandler. - - Args: - name: The name of the OAuth connection. - auto: Whether to automatically start the OAuth flow. - title: Title for the OAuth card. - text: Text for the OAuth button. - """ - self.name = name or kwargs.get("NAME") - self.title = title or kwargs.get("TITLE") - self.text = text or kwargs.get("TEXT") - self.abs_oauth_connection_name = abs_oauth_connection_name or kwargs.get( - "AZUREBOTOAUTHCONNECTIONNAME" - ) - self.obo_connection_name = obo_connection_name or kwargs.get( - "OBOCONNECTIONNAME" - ) - self.flow: OAuthFlow = None - logger.debug( - f"AuthHandler initialized: name={self.name}, title={self.title}, text={self.text} abs_connection_name={self.abs_oauth_connection_name} obo_connection_name={self.obo_connection_name}" - ) +from .storage import AuthStateStorage - -# Type alias for authorization handlers dictionary -AuthorizationHandlers = Dict[str, AuthHandler] +logger = logging.getLogger(__name__) class Authorization: """ Class responsible for managing authorization and OAuth flows. + Handles multiple OAuth providers and manages the complete authentication lifecycle. """ - SIGN_IN_STATE_KEY = f"{UserState.__name__}.__SIGNIN_STATE_" - def __init__( self, storage: Storage, @@ -107,11 +53,15 @@ def __init__( Raises: ValueError: If storage is None or no auth handlers are provided. """ - if storage is None: - logger.error("Storage is required for Authorization") + if not storage: raise ValueError("Storage is required for Authorization") + if not auth_handlers: + raise ValueError("At least one AuthHandler must be provided") user_state = UserState(storage) + + self.__auth_storage = AuthStateStorage(storage, ) + self._connection_manager = connection_manager auth_configuration: Dict = kwargs.get("AGENTAPPLICATION", {}).get( @@ -164,18 +114,14 @@ async def get_token( Gets the token for a specific auth handler. Args: - context: The context object for the current turn. - auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. + context: The context object for the current turn. + auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. Returns: The token response from the OAuth provider. """ - auth_handler = self.resolver_handler(auth_handler_id) - if auth_handler.flow is None: - logger.error("OAuth flow is not configured for the auth handler") - raise ValueError("OAuth flow is not configured for the auth handler") - - return await auth_handler.flow.get_user_token(context) + flow = OAuthFlow(context, auth_handler_id) + return await flow.get_token() async def exchange_token( self, @@ -194,17 +140,26 @@ async def exchange_token( Returns: The token response from the OAuth provider. """ - auth_handler = self.resolver_handler(auth_handler_id) - if not auth_handler.flow: - logger.error("OAuth flow is not configured for the auth handler") - raise ValueError("OAuth flow is not configured for the auth handler") + flow = OAuthFlow(context, auth_handler_id) - token_response = await auth_handler.flow.get_user_token(context) + token_response = await flow.get_user_token(context) - if self._is_exchangeable(token_response.token if token_response else None): - return await self._handle_obo(token_response.token, scopes, auth_handler_id) + if self.__is_exchangeable(token_response.token if token_response else None): + pass - return token_response + return await flow.exchange_token(scopes) + + # auth_handler = self.resolver_handler(auth_handler_id) + # if not auth_handler.flow: + # logger.error("OAuth flow is not configured for the auth handler") + # raise ValueError("OAuth flow is not configured for the auth handler") + + # token_response = await auth_handler.flow.get_user_token(context) + + # if self._is_exchangeable(token_response.token if token_response else None): + # return await self._handle_obo(token_response.token, scopes, auth_handler_id) + + # return token_response def _is_exchangeable(self, token: Optional[str]) -> bool: """ @@ -265,31 +220,13 @@ async def _handle_obo( scopes=scopes, # Expiration can be set based on the token provider's response ) - def get_flow_state(self, auth_handler_id: Optional[str] = None) -> FlowState: - """ - Gets the current state of the OAuth flow. - - Args: - auth_handler_id: Optional ID of the auth handler to check, defaults to first handler. - - Returns: - The flow state object. - """ - flow = self.resolver_handler(auth_handler_id).flow - if flow is None: - # Return a default FlowState if no flow is configured - return FlowState() - - # Return flow state if available - return flow.flow_state or FlowState() - async def begin_or_continue_flow( self, context: TurnContext, turn_state: TurnState, auth_handler_id: str, sec_route: bool = True, - ) -> TokenResponse: + ) -> AuthFlowResponse: """ Begins or continues an OAuth flow. @@ -301,65 +238,35 @@ async def begin_or_continue_flow( Returns: The token response from the OAuth provider. """ - auth_handler = self.resolver_handler(auth_handler_id) - # Get or initialize sign-in state - sign_in_state = turn_state.get_value( - self.SIGN_IN_STATE_KEY, target_cls=SignInState - ) - if sign_in_state is None: - sign_in_state = SignInState( - continuation_activity=None, handler_id=None, completed=False - ) + # robrandao: TODO -> is_started_from_route and sec_route - flow = auth_handler.flow - if flow is None: - logger.error("OAuth flow is not configured for the auth handler") - raise ValueError("OAuth flow is not configured for the auth handler") + flow_storage_client = FlowStorageClient(context, self.__storage) + flow = self.__create_flow(context, auth_handler_id) - logger.info( - "Beginning or continuing OAuth flow for handler: %s", auth_handler_id + + sign_in_context: SignInContext = self.__create_sign_in_context( + context, auth_handler_id, 42 ) - token_response = await flow.get_user_token(context) - if token_response and token_response.token: - logger.debug("Token obtained successfully") - return token_response - - # Get the current flow state - flow_state = await flow._get_flow_state(context) - - if not flow_state.flow_started: - logger.info("Starting new OAuth flow for handler: %s", auth_handler_id) - token_response = await flow.begin_flow(context) - if sec_route: - sign_in_state.continuation_activity = context.activity - sign_in_state.handler_id = auth_handler_id - turn_state.set_value(self.SIGN_IN_STATE_KEY, sign_in_state) - else: - logger.info( - "Continuing existing OAuth flow for handler: %s", auth_handler_id + if self.__sign_in_success_handler: + sign_in_context.on_success( + lambda: self.__sign_in_success_handler( + context, turn_state, sign_in_context.handler.id + ) ) - token_response = await flow.continue_flow(context) - # Check if sign-in was successful and call handler if configured - if token_response and token_response.token: - if self._sign_in_handler: - logger.info("Sign-in successful, calling sign-in handler") - await self._sign_in_handler(context, turn_state, auth_handler_id) - if sec_route: - turn_state.delete_value(self.SIGN_IN_STATE_KEY) - else: - if self._sign_in_failed_handler: - logger.warning( - "Sign-in failed, calling sign-in failed handler", - stack_info=True, - ) - await self._sign_in_failed_handler( - context, turn_state, auth_handler_id - ) - - await turn_state.save(context) - return token_response - - def resolver_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: + if self.__sign_in_failure_handler: + sign_in_context.on_failure( + lambda err: self.__sign_in_failure_handler( + context, turn_state, sign_in_context.handler.id, err + ) + ) + + async for activity in auth_handler.begin_or_continue_flow(): + pass + + token_response = await sign_in_context.get_token() + return BeginOrContinueFlowResponse(token_response, sign_in_context.handler) + + def resolve_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: """ Resolves the auth handler to use based on the provided ID. @@ -376,8 +283,21 @@ def resolver_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler return self._auth_handlers[auth_handler_id] # Return the first handler if no ID specified - first_key = next(iter(self._auth_handlers)) - return self._auth_handlers[first_key] + return next(iter(self._auth_handlers.values)) + + async def __sign_out( + self, + context: TurnContext, + state: TurnState, + auth_handler_ids: Iterable[str] = None, + ) -> None: + flow_storage_client = FlowStorageClient(context, self.__storage) + for auth_handler_id in auth_handler_ids: + auth_handler = self.resolver_handler(auth_handler_id) + flow_state = flow_storage_client.read(auth_handler.flow_id) + if flow_state: + logger.info(f"Signing out from handler: {auth_handler_id}") + await auth_handler.flow.sign_out(context) async def sign_out( self, @@ -394,18 +314,10 @@ async def sign_out( state: The state object for the current turn. auth_handler_id: Optional ID of the auth handler to use for sign out. """ - if auth_handler_id is None: - # Sign out from all handlers - for handler_key, auth_handler in self._auth_handlers.items(): - if auth_handler.flow: - logger.info(f"Signing out from handler: {handler_key}") - await auth_handler.flow.sign_out(context) + if auth_handler_id: + self.__sign_out(context, state, [auth_handler_id]) else: - # Sign out from specific handler - auth_handler = self.resolver_handler(auth_handler_id) - if auth_handler.flow: - logger.info(f"Signing out from handler: {auth_handler_id}") - await auth_handler.flow.sign_out(context) + self.__sign_out(context, state, self._auth_handlers.keys()) def on_sign_in_success( self, @@ -417,7 +329,7 @@ def on_sign_in_success( Args: handler: The handler function to call on successful sign-in. """ - self._sign_in_handler = handler + self.__sign_in_handler = handler def on_sign_in_failure( self, @@ -428,4 +340,4 @@ def on_sign_in_failure( Args: handler: The handler function to call on sign-in failure. """ - self._sign_in_failed_handler = handler + self.__sign_in_failure = handler \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py new file mode 100644 index 00000000..3d16c7a7 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py @@ -0,0 +1,57 @@ +from enum import Enum +from typing import Optional + +from pydantic import BaseModel + +from microsoft.agents.activity import Activity +from microsoft.agents.hosting.core import ( + TurnContext, + Storage, + StoreItem +) + +# robrandao: TODO -> context.activity.from_property +class FlowStorageClient: + """ + Wrapper around storage that manages sign-in state specific to each user and channel. + + Uses the activity's channel_id and from.id to create a key prefix for storage operations. + """ + + def __init__( + self, + context: TurnContext, + storage: Storage + ): + + if ( + not context.activity + or not context.activity.channel_id + or not context.activity.from_property + or not context.activity.from_property.id + ): + + raise ValueError("context.activity -> channel_id and from.id must be set.") + + channel_id = context.activity.channel_id + user_id = context.activity.from_property.id + + self.__base_key = f"auth/{channel_id}/{user_id}" + self.__storage = storage + + def __key(self, id: str) -> str: + """Creates a storage key for a specific sign-in handler.""" + return f"{self.__base_key}/${id}" + + async def read(self, auth_handler_id: str) -> Optional[FlowState]: + key: str = self.__key(auth_handler_id) + data = await self.__storage.read([key], FlowState) + return data.get(key) # robrandao: TODO -> verify contract + + async def write(self, value: FlowState) -> None: + key: str = self.__key(value.id) + await self.__storage.write({key: value}) + + async def delete(self, auth_handler_id: str) -> None: + key: str = self.__key(auth_handler_id) + await self.__storage.delete([key]) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py new file mode 100644 index 00000000..8f919a14 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py @@ -0,0 +1,57 @@ +import datetime +from enum import Enum +from typing import Optional + +from pydantic import BaseModel +from pydantic.types import PositiveInt + +from microsoft.agents.activity import Activity +from microsoft.agents.hosting.core import StoreItem + + +class FlowStateTag(Enum): + BEGIN = "begin" + CONTINUE = "continue" + NOT_STARTED = "not_started" + FAILURE = "failure" + COMPLETE = "complete" + +class FlowState(BaseModel, StoreItem): + + flow_started: bool = False + user_token: str = "" + flow_expires: float = 0 + abs_oauth_connection_name: Optional[str] = None + continuation_activity: Optional[Activity] = None + attempts_remaining: PositiveInt = 3 + tag: FlowStateTag = FlowStateTag.INACTIVE + + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) + + if self.is_expired() or self.reached_max_retries(): + self.tag = FlowStateTag.FAILURE + + def store_item_to_json(self) -> dict: + return self.model_dump() + + @staticmethod + def from_json_to_store_item(json_data: dict) -> "FlowState": + return FlowState.model_validate(json_data) + + def is_expired(self) -> bool: + return datetime.now().timestamp() >= self.flow_expires + + def reached_max_retries(self) -> bool: + return self.attempts_remaining <= 0 + + # @staticmethod + # def generate_begin_state(): + # pass + +class FlowResponse(BaseModel): + + flow_data: FlowData + in_flow_activity: Activity + token_response: Optional[TokenResponse] = None diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_storage.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_storage.py deleted file mode 100644 index ae64443e..00000000 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/sign_in_storage.py +++ /dev/null @@ -1,72 +0,0 @@ -from enum import Enum -from typing import Optional - -from pydantic import BaseModel - -from microsoft.agents.activity import Activity -from microsoft.agents.hosting.core import StoreItem - -class FlowState: - pass - -class SignInHandlerStateStatus(Enum): - NOT_STARTED = "not_started" - CONTINUE = "in_progress" - COMPLETED = "completed" - FAILURE = "failure" - - -class SignInHandlerState(BaseModel, StoreItem): - id: str - status: SignInHandlerStateStatus - continuation_activity: Activity - - def store_item_to_json(self): # todo - return super().store_item_to_json() - - @staticmethod - def from_json_to_store_item(json_data): # todo - return super().from_json_to_store_item(json_data) - - -class SignInStorage: - - def __init__(self, context: TurnContext, storage: Storage, handlers: Optional[list[AuthHandler]] = None): - - if (not context.activity or - not context.activity.channel_id or - not context.activity.from_property or - not context.activity.from_property.id): - - raise ValueError("context.activity -> channel_id and from.id must be set.") - - channel_id = context.activity.channel_id - user_id = context.activity.from_property.id - - self.__base_key = f"auth/{channel_id}/{user_id}" - self.__handlers = handlers or [] - self.__handler_keys = list(map(create_key, self.__handlers)) - self.__storage = storage - - def create_key(self, id: str) -> str: - if not self.__base_key: - raise AttributeError # robrandao: TODO - return f"{self.__base_key}/${id}" - - async def active(self) -> Optional[SignInHandlerState]: - # batched reads would make this more efficient - for handler_key in self.__handler_keys: - state = (await self.__storage.read([handler_key], SignInHandlerState)).get(handler_key) - if state and state.status == SignInHandlerStateStatus.IN_PROGRESS: - return state - - async def get(self, id: str) -> Optional[SignInHandlerState]: - key = self.create_key(id) - data = await self.__storage.read([key], SignInHandlerState) - return data.get(key) # robrandao: TODO -> verify contract - - async def set(self, value: SignInHandlerState) -> None: - await self.__storage.write({ self.create_key(value.id): value}) - - async def delete(self, id: str) -> None: - await self.__storage.delete([self.create_key(id)]) \ No newline at end of file From 2bab02c45a0b9fb7489408ddbddec95fc7c2d534 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Thu, 14 Aug 2025 16:48:13 -0700 Subject: [PATCH 04/32] Adding code to connect to AgentApp --- .../hosting/core/app/oauth/auth_flow.py | 17 +---- .../hosting/core/app/oauth/authorization.py | 47 ++++++------ .../core/app/oauth/flow_storage_client.py | 10 +-- .../agents/hosting/core/app/oauth/models.py | 8 +- .../agents/hosting/core/oauth_flow.py | 75 +++++++++++++++---- 5 files changed, 90 insertions(+), 67 deletions(-) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py index 27cb7e70..3e84a1e2 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py @@ -154,7 +154,6 @@ async def __continue_from_invoke_token_exchange(self, context: TurnContext) -> N body=token_exchange_request ) - async def continue_flow(self, context: TurnContext) -> FlowResponse: if self.flow_state.is_expired() or self.flow_state.reached_max_retries(): self.flow_state.flow_state_tag = FlowStateTag.FAILURE @@ -179,19 +178,9 @@ async def continue_flow(self, context: TurnContext) -> FlowResponse: pass - async def continue_flow(self, context: TurnContext) -> FlowResponse: - pass - async def begin_or_continue_flow(self, context: TurnContext) -> FlowResponse: - - tag = self.flow_state.flow_state_tag - if tag == FlowStateTag.CONTINUE: - self.continue_flow(context) + if self.flow_state.is_active(): + return await self.continue_flow(context) else: - self.begin_flow(context) - - if tag == FlowStateTag.BEGIN: - pass - elif tag == FlowStateTag.CONTINUE: - pass \ No newline at end of file + return await self.begin_flow(context) \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index b9eff281..454963b2 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -18,13 +18,14 @@ from ...turn_context import TurnContext from ...app.state.turn_state import TurnState -from ...oauth_flow import OAuthFlow, FlowState +from ...oauth_flow import OAuthFlow from ...state.user_state import UserState from .sign_in_context import SignInContext from .auth_handler import AuthHandler, AuthorizationHandlers from .sign_in_state import SignInState -from .storage import AuthStateStorage +from .storage import FlowStorageClient +from .models import FlowResponse, FlowState, FlowStateTag logger = logging.getLogger(__name__) @@ -220,13 +221,22 @@ async def _handle_obo( scopes=scopes, # Expiration can be set based on the token provider's response ) + async def get_active_flow_state(self, context: TurnContext, turn_state: TurnState) -> Optional[FlowResponse]: + flow_storage_client = FlowStorageClient(context, self.__storage) + + for auth_handler_id in self._auth_handlers.keys(): + flow_state = await flow_storage_client.read(auth_handler_id) + if flow_state.is_active(): + return flow_state + return None + async def begin_or_continue_flow( self, context: TurnContext, turn_state: TurnState, auth_handler_id: str, sec_route: bool = True, - ) -> AuthFlowResponse: + ) -> FlowResponse: """ Begins or continues an OAuth flow. @@ -241,30 +251,17 @@ async def begin_or_continue_flow( # robrandao: TODO -> is_started_from_route and sec_route flow_storage_client = FlowStorageClient(context, self.__storage) - flow = self.__create_flow(context, auth_handler_id) - - - sign_in_context: SignInContext = self.__create_sign_in_context( - context, auth_handler_id, 42 - ) - if self.__sign_in_success_handler: - sign_in_context.on_success( - lambda: self.__sign_in_success_handler( - context, turn_state, sign_in_context.handler.id - ) - ) - if self.__sign_in_failure_handler: - sign_in_context.on_failure( - lambda err: self.__sign_in_failure_handler( - context, turn_state, sign_in_context.handler.id, err - ) - ) + flow = OAuthFlow(context, auth_handler_id) - async for activity in auth_handler.begin_or_continue_flow(): - pass + flow_response: FlowResponse = flow.begin_or_continue_flow(context) + flow_state: FlowState = flow_response.flow_state - token_response = await sign_in_context.get_token() - return BeginOrContinueFlowResponse(token_response, sign_in_context.handler) + if flow_state.tag == FlowStateTag.COMPLETE: + self.__on_sign_in_success_handler(context, turn_state, flow_state.handler.id) + elif flow_state.tag == FlowStateTag.FAILURE: + self.__on_sign_in_failure_handler(context, turn_state, flow_state.handler.id, err) + + return flow_response def resolve_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: """ diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py index 3d16c7a7..07cf8adb 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py @@ -1,14 +1,8 @@ -from enum import Enum from typing import Optional -from pydantic import BaseModel +from microsoft.agents.hosting.core import TurnContext, Storage -from microsoft.agents.activity import Activity -from microsoft.agents.hosting.core import ( - TurnContext, - Storage, - StoreItem -) +from .models import FlowState # robrandao: TODO -> context.activity.from_property class FlowStorageClient: diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py index 8f919a14..77aa8f70 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py @@ -8,7 +8,6 @@ from microsoft.agents.activity import Activity from microsoft.agents.hosting.core import StoreItem - class FlowStateTag(Enum): BEGIN = "begin" CONTINUE = "continue" @@ -46,10 +45,9 @@ def is_expired(self) -> bool: def reached_max_retries(self) -> bool: return self.attempts_remaining <= 0 - # @staticmethod - # def generate_begin_state(): - # pass - + def is_active(self) -> bool: + return not self.is_expired() and not self.reached_max_retries() and self.tag in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE] + class FlowResponse(BaseModel): flow_data: FlowData diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth_flow.py index c9deeda7..29c1b247 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth_flow.py @@ -3,6 +3,9 @@ from __future__ import annotations +import logging + +from enum import Enum from datetime import datetime from typing import Dict, Optional @@ -21,18 +24,29 @@ TurnContextProtocol as TurnContext, ) from microsoft.agents.hosting.core.storage import StoreItem, Storage -from pydantic import BaseModel +from pydantic import BaseModel, PositiveInt from .message_factory import MessageFactory from .card_factory import CardFactory +logger = logging.getLogger(__name__) + +# class FlowStatus(Enum): +# IN_ACTIVE = "not_started" +# IN_PROGRESS = "in_progress" +# COMPLETED = "completed" +# ERROR = "error" +# EXPIRED = "expired" + class FlowState(StoreItem, BaseModel): flow_started: bool = False + # flow_status: FLOW_STATUS user_token: str = "" flow_expires: float = 0 abs_oauth_connection_name: Optional[str] = None continuation_activity: Optional[Activity] = None + attempts_remaining: PositiveInt = 1 def store_item_to_json(self) -> dict: return self.model_dump() @@ -63,6 +77,9 @@ def __init__( abs_oauth_connection_name: The OAuth connection name. user_token_client: Optional user token client. messages_configuration: Optional messages configuration for backward compatibility. + + Kwargs: + flow_total_tries: The total number of auth attempts made by the user during a single flow """ if not abs_oauth_connection_name: raise ValueError( @@ -81,6 +98,8 @@ def __init__( self._storage = storage self.flow_state = None + self.initial_attempts_remaining = kwargs.get("initial_attempts_remaining", 3) + async def get_user_token(self, context: TurnContext) -> TokenResponse: """ Retrieves the user token from the user token service. @@ -108,6 +127,20 @@ async def get_user_token(self, context: TurnContext) -> TokenResponse: channel_id=context.activity.channel_id, ) + async def reset_to_initial_flow_state(self, context: TurnContext) -> None: + self.flow_state.flow_started = True + self.flow_state.flow_expires = datetime.now().timestamp() + 30000 + self.flow_state.abs_oauth_connection_name = self.abs_oauth_connection_name + self.flow_state.attempts_remaining = self.initial_attempts_remaining + await self._save_flow_state(context) + + async def reset_to_finished_flow_state(self, context: TurnContext) -> None: + self.flow_state.flow_started = False + self.flow_state.flow_expires = 0 + self.flow_state.attempts_remaining = 0 + self.flow_state.abs_oauth_connection_name = self.abs_oauth_connection_name + await self._save_flow_state(context) + async def begin_flow(self, context: TurnContext) -> TokenResponse: """ Begins the OAuth flow. @@ -138,6 +171,7 @@ async def begin_flow(self, context: TurnContext) -> TokenResponse: # Already have token, return it self.flow_state.flow_started = False self.flow_state.flow_expires = 0 + self.flow_state.attempts_remaining = 0 self.flow_state.abs_oauth_connection_name = self.abs_oauth_connection_name await self._save_flow_state(context) return user_token @@ -180,10 +214,7 @@ async def begin_flow(self, context: TurnContext) -> TokenResponse: await context.send_activity(MessageFactory.attachment(o_card)) # Update flow state - self.flow_state.flow_started = True - self.flow_state.flow_expires = datetime.now().timestamp() + 30000 - self.flow_state.abs_oauth_connection_name = self.abs_oauth_connection_name - await self._save_flow_state(context) + await self.reset_to_initial_flow_state(context) # Return in-progress response return TokenResponse() @@ -200,11 +231,14 @@ async def continue_flow(self, context: TurnContext) -> TokenResponse: """ await self._initialize_token_client(context) - if ( - self.flow_state - and self.flow_state.flow_expires != 0 - and datetime.now().timestamp() > self.flow_state.flow_expires + if self.flow_state and ( + ( + self.flow_state.flow_expires != 0 + and datetime.now().timestamp() > self.flow_state.flow_expires + ) + or (self.flow_state.attempts_remaining <= 0) ): + # self.flow_state = False await context.send_activity( MessageFactory.text( self.messages_configuration.get( @@ -219,6 +253,11 @@ async def continue_flow(self, context: TurnContext) -> TokenResponse: # Handle message type activities (typically when the user enters a code) if cont_flow_activity.type == ActivityTypes.message: + self.flow_state.attempts_remaining -= 1 + logger.info( + f"Attempts remaining in this flow: {self.flow_state.attempts_remaining}" + ) + magic_code = cont_flow_activity.text # Validate magic code format (6 digits) @@ -231,12 +270,7 @@ async def continue_flow(self, context: TurnContext) -> TokenResponse: ) if result and result.token: - self.flow_state.flow_started = False - self.flow_state.flow_expires = 0 - self.flow_state.abs_oauth_connection_name = ( - self.abs_oauth_connection_name - ) - await self._save_flow_state(context) + await self.reset_to_finished_flow_state(context) return result else: await context.send_activity( @@ -259,6 +293,11 @@ async def continue_flow(self, context: TurnContext) -> TokenResponse: cont_flow_activity.type == ActivityTypes.invoke and cont_flow_activity.name == "signin/verifyState" ): + self.flow_state.attempts_remaining -= 1 + logger.info( + f"Attempts remaining in this flow: {self.flow_state.attempts_remaining}" + ) + token_verify_state = cont_flow_activity.value magic_code = token_verify_state.get("state") @@ -283,6 +322,11 @@ async def continue_flow(self, context: TurnContext) -> TokenResponse: cont_flow_activity.type == ActivityTypes.invoke and cont_flow_activity.name == "signin/tokenExchange" ): + self.flow_state.attempts_remaining -= 1 + logger.info( + f"Attempts remaining in this flow: {self.flow_state.attempts_remaining}" + ) + token_exchange_request = cont_flow_activity.value # Dedupe checks to prevent duplicate processing @@ -329,6 +373,7 @@ async def sign_out(self, context: TurnContext) -> None: if self.flow_state: self.flow_state.flow_expires = 0 + self.flow_state.attempts_remaining = 0 await self._save_flow_state(context) async def _get_flow_state(self, context: TurnContext) -> FlowState: From 4bd5f45af5ec4a0ebaf3f22695b2df502e1759ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Mon, 18 Aug 2025 15:12:24 -0700 Subject: [PATCH 05/32] Generating new unit tests --- .../core/app/oauth/authorization_test.py | 0 .../app/oauth/flow_storage_client_test.py | 162 ++++++++++++++++++ .../hosting/core/app/oauth/models_test.py | 106 ++++++++++++ 3 files changed, 268 insertions(+) create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization_test.py create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client_test.py create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models_test.py diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization_test.py new file mode 100644 index 00000000..e69de29b diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client_test.py new file mode 100644 index 00000000..e169e753 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client_test.py @@ -0,0 +1,162 @@ +from microsoft.agents.hosting.core import MemoryStorage + +import pytest + +from microsoft.agents.hosting.core import ( + Storage, + FlowStorageClient, + MockStoreItem, + FlowState +) + +class TestFlowStorageClient: + + @pytest.fixture + def turn_context(self, mocker): + context = mocker.Mock() + context.activity.channel_id = "__channel_id" + context.activity.from_property.id = "__user_id" + return context + + @pytest.fixture + def client(self, turn_context, storage): + return FlowStorageClient(turn_context, storage) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mocker, turn_context, storage, channel_id, from_property_id", + [ + ("mocker", "turn_context", "storage", "channel_id", "from_property_id"), + ("mocker", "turn_context", "storage", "teams_id", "Bob"), + ("mocker", "turn_context", "storage", "channel", "Alice"), + ], + indirect=["mocker", "turn_context", "storage"] + ) + async def test_init_base_key(self, mocker, turn_context, storage, channel_id, from_property_id): + context = mocker.Mock() + context.activity.channel_id = channel_id + context.activity.from_property.id = from_property_id + client = FlowStorageClient(context, storage) + assert client.base_key == f"auth/{channel_id}/{from_property_id}/" + + async def test_init_fails_without_from_id(self, mocker, storage): + with pytest.raises(ValueError): + context = mocker.Mock() + context.activity.channel_id = "channel_id" + FlowStorageClient(context, storage) + + async def test_init_fails_without_channel_id(self, mocker, storage): + with pytest.raises(ValueError): + context = mocker.Mock() + context.activity.from_property.id = "from_id" + FlowStorageClient(context, storage) + + @pytest.mark.parametrize( + "client, auth_handler_id, expected", + [ + (client, "handler", "auth/__channel_id/__user_id/handler"), + (client, "auth_handler", "auth/__channel_id/__user_id/auth_handler"), + ] + ) + def test_key(self, client, auth_handler_id, expected): + assert client.key(auth_handler_id) == expected + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mocker, turn_context storage, client, auth_handler_id", + [ + (mocker, turn_context, storage, client, "handler"), + (mocker, turn_context, storage, client, "auth_handler"), + ] + ) + async def test_read(self, mocker, turn_context, storage, client, auth_handler_id): + storage = mocker.AsyncMock() + storage.read.return_value = sentinel.read_response + client = FlowStorageClient(turn_context, storage) + res = await client.read(auth_handler_id) + assert res == storage.read.return_value + assert storage.read.called_once_with([f"auth/__channel_id/__user_id/{auth_handler_id}"], FlowState) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mocker, turn_context storage, client, auth_handler_id", + [ + (mocker, turn_context, storage, client, "handler", "auth/__channel_id/__user_id/handler"), + (mocker, turn_context, storage, client, "auth_handler", "auth/__channel_id/__user_id/auth_handler"), + ] + ) + async def test_write(self, mocker, turn_context, storage, client, auth_handler_id, key, flow_state): + storage = mocker.AsyncMock() + storage.write.return_value = None + client = FlowStorageClient(turn_context, storage) + flow_state = mocker.Mock(spec=FlowState) + flow_state.id = auth_handler_id + await client.write(flow_state) + assert storage.write.called_once_with({ key: flow_state }) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mocker, turn_context storage, client, auth_handler_id", + [ + (mocker, turn_context, storage, client, "handler", "auth/__channel_id/__user_id/handler"), + (mocker, turn_context, storage, client, "auth_handler", "auth/__channel_id/__user_id/auth_handler"), + ] + ) + async def test_delete(self, mocker, turn_context, storage, client, auth_handler_id): + storage = mocker.AsyncMock() + storage.write.return_value = None + client = FlowStorageClient(turn_context, storage) + await client.delete(auth_handler_id) + assert storage.write.called_once_with([auth_handler_id]) + + async def test_integration_with_memory_storage(self, turn_context): + + flow_state_alpha = FlowState(flow_id="handler", flow_started=True) + flow_state_beta = FlowState(flow_id="auth_handler", flow_started=True, user_token="token") + + storage = MemoryStorage({ + "some_data": MockStoreItem({"value": "test"}), + "auth/__channel_id/__user_id/handler": flow_state_alpha, + "auth/__channel_id/__user_id/auth_handler": flow_state_beta, + }) + baseline = MemoryStorage({ + "some_data": MockStoreItem({"value": "test"}), + "auth/__channel_id/__user_id/handler": flow_state_alpha, + "auth/__channel_id/__user_id/auth_handler": flow_state_beta, + }) + + # helpers + async def read_check(*args, **kwargs): + res_storage = await storage.read(*args, **kwargs) + res_baseline = await baseline.read(*args, **kwargs) + assert res_storage == res_baseline + + async def write_both(*args, **kwargs): + await storage.write(*args, **kwargs) + await baseline.write(*args, **kwargs) + + async def delete_both(*args, **kwargs): + await storage.delete(*args, **kwargs) + await baseline.delete(*args, **kwargs) + + client = FlowStorageClient(turn_context, storage) + + new_flow_state_alpha = FlowState(flow_id="handler") + flow_state_chi = FlowState(flow_id="chi") + + await client.write(new_flow_state_alpha) + await client.write(flow_state_chi) + baseline.write({"auth/__channel_id/__user_id/handler": new_flow_state_alpha.copy()}) + baseline.write({"auth/__channel_id/__user_id/chi": flow_state_chi.copy()}) + + write_both({"auth/__channel_id/__user_id/handler": new_flow_state_alpha.copy()}) + write_both({"auth/__channel_id/__user_id/auth_handler": flow_state_beta.copy()}) + write_both({"other_data": MockStoreItem({"value": "more"}).copy()}) + + delete_both(["some_data"]) + + assert read_check(["auth/__channel_id/__user_id/handler"], FlowState) + assert read_check(["auth/__channel_id/__user_id/auth_handler"], FlowState) + assert read_check(["auth/__channel_id/__user_id/chi"], FlowState) + assert read_check(["other_data"], MockStoreItem) + assert read_check(["some_data"], MockStoreItem) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models_test.py new file mode 100644 index 00000000..05d83780 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models_test.py @@ -0,0 +1,106 @@ +import datetime + +import pytest + +from microsoft.agents.hosting.core.app.oauth.models import FlowState, FlowStateTag + +class TestFlowState: + + def test_refresh_to_failure_expired(self): + flow_state = FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + expires_at=datetime.now().timestamp() + ) + flow_state.refresh() + assert flow_state.tag == FlowStateTag.FAILURE + + def test_refresh_to_failure_max_attempts(self): + flow_state = FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=0, + ) + flow_state.refresh() + assert flow_state.tag == FlowStateTag.FAILURE + + def test_refresh_unchanged_continue(self): + flow_state = FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + expires_at=datetime.now().timestamp() + 10000 + ) + prev_tag = flow_state.tag + flow_state.refresh() + assert flow_state.tag == prev_tag + + def test_refresh_unchanged_begin(self): + flow_state = FlowState( + tag=FlowStateTag.BEGIN, + attempts_remaining=10, + expires_at=datetime.now().timestamp() + 30000 + ) + prev_tag = flow_state.tag + flow_state.refresh() + assert flow_state.tag == prev_tag + + @pytest.mark.parametrize( + "flow_state, expected", + [ + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expires_at=datetime.now().timestamp()), + True), + (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expires_at=datetime.now().timestamp()), + True), + (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expires_at=datetime.now().timestamp()-100), + True), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expires_at=datetime.now().timestamp()+1000), + False), + (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=-1, expires_at=datetime.now().timestamp()+1000), + False), + ] + ) + def test_is_expired(self, flow_state, expected): + assert flow_state.is_expired() == expected + + @pytest.mark.parametrize( + "flow_state, expected", + [ + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expires_at=datetime.now().timestamp()), + True), + (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expires_at=datetime.now().timestamp()), + False), + (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expires_at=datetime.now().timestamp()-100), + True), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expires_at=datetime.now().timestamp()-100), + False), + (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=-1, expires_at=datetime.now().timestamp()), + True), + ] + ) + def test_reached_max_attempts(flow_state, expected): + assert flow_state.reached_max_attempts() == expected + + @pytest.mark.parametrize( + "flow_state, expected", + [ + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expires_at=datetime.now().timestamp()), + False), + (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expires_at=datetime.now().timestamp()), + False), + (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expires_at=datetime.now().timestamp()-100), + False), + (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=1, expires_at=datetime.now().timestamp()-100), + False), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=2, expires_at=datetime.now().timestamp()+1000), + True), + (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=0, expires_at=datetime.now().timestamp()+1000), + False), + (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=-1, expires_at=datetime.now().timestamp()+1000), + False), + (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=1, expires_at=datetime.now().timestamp()+1000), + False), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expires_at=datetime.now().timestamp()+1000), + True) + ] + ) + def test_is_active(flow_state, expected): + assert flow_state.is_active() == expected \ No newline at end of file From eb62f27fa97a286444ca7998d3c697b84689de08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Tue, 19 Aug 2025 00:38:44 -0700 Subject: [PATCH 06/32] Filling in test cases for OAuthFlow and models --- .../agents/hosting/core/app/oauth/conftest.py | 3 + .../agents/hosting/core/app/oauth/models.py | 18 +- .../app/oauth/{auth_flow.py => oauth_flow.py} | 141 ++++--- .../hosting/core/app/oauth/oauth_flow_test.py | 397 ++++++++++++++++++ ...rage_client.py => oauth_storage_client.py} | 33 +- ...t_test.py => oauth_storage_client_test.py} | 0 6 files changed, 519 insertions(+), 73 deletions(-) create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/conftest.py rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/{auth_flow.py => oauth_flow.py} (54%) create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_flow_test.py rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/{flow_storage_client.py => oauth_storage_client.py} (56%) rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/{flow_storage_client_test.py => oauth_storage_client_test.py} (100%) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/conftest.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/conftest.py new file mode 100644 index 00000000..aa08436e --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/conftest.py @@ -0,0 +1,3 @@ +def turn_context(): + + context = TurnContext() \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py index 77aa8f70..d1a1ac2d 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py @@ -15,20 +15,23 @@ class FlowStateTag(Enum): FAILURE = "failure" COMPLETE = "complete" +class FlowErrorTag(Enum): + NONE = "none" + MAGIC_FORMAT = "magic_format" + MAGIC_CODE = "magic_code" + class FlowState(BaseModel, StoreItem): + flow_id: NonEmptyString flow_started: bool = False user_token: str = "" - flow_expires: float = 0 + expires: float = 0 abs_oauth_connection_name: Optional[str] = None continuation_activity: Optional[Activity] = None attempts_remaining: PositiveInt = 3 tag: FlowStateTag = FlowStateTag.INACTIVE - def __init__(self, *args, **kwargs): - - super().__init__(*args, **kwargs) - + def refresh(self) -> None: if self.is_expired() or self.reached_max_retries(): self.tag = FlowStateTag.FAILURE @@ -42,7 +45,7 @@ def from_json_to_store_item(json_data: dict) -> "FlowState": def is_expired(self) -> bool: return datetime.now().timestamp() >= self.flow_expires - def reached_max_retries(self) -> bool: + def reached_max_attempts(self) -> bool: return self.attempts_remaining <= 0 def is_active(self) -> bool: @@ -51,5 +54,6 @@ def is_active(self) -> bool: class FlowResponse(BaseModel): flow_data: FlowData - in_flow_activity: Activity + flow_error_tag: FlowErrorTag = FlowErrorTag.NONE token_response: Optional[TokenResponse] = None + sign_in_resource: Optional[SignInResource] = None diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_flow.py similarity index 54% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_flow.py index 3e84a1e2..4e0a403c 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_flow.py @@ -11,14 +11,9 @@ from microsoft.agents.hosting.core.connector.client import UserTokenClient from microsoft.agents.activity import ( - ActionTypes, ActivityTypes, - CardAction, - Attachment, - OAuthCard, TokenExchangeState, TokenResponse, - Activity, ) from microsoft.agents.activity import ( TurnContextProtocol as TurnContext, @@ -28,36 +23,58 @@ from .message_factory import MessageFactory from .card_factory import CardFactory -from .models import FlowResponse, FlowState, FlowStateTag +from .models import FlowResponse, FlowState, FlowStateTag, FlowErrorTag logger = logging.getLogger(__name__) -class AuthFlow: +class OAuthFlow: """ Manages the OAuth flow. + + This class is responsible for managing the entire OAuth flow, including + obtaining user tokens, signing out users, and handling token exchanges. + + Contract with other classes (usage of other classes is enforced in unit tests): + TurnContext.activity.channel_id + TurnContext.activity.from_property.id + + UserTokenClient: user_token.get_token(), user_token.sign_out() """ def __init__( self, + abs_oauth_connection_name: str, + user_token_client: UserTokenClient, flow_state: FlowState = None, - abs_oauth_connection_name: str = None, - user_token_client: Optional[UserTokenClient] = None, **kwargs ): + """ + Arguments: + abs_oauth_connection_name: + + user_token_client: + + flow_state: + """ if not abs_oauth_connection_name: raise ValueError( "OAuthFlow.__init__: abs_oauth_connection_name required." ) + if not user_token_client: + raise ValueError( + "OAuthFlow.__init__: user_token_client required." + ) - self.flow_state = flow_state or FlowState() + flow_state = flow_state or FlowState() # robrandao: TODO + self.flow_state = flow_state.copy() self.__abs_oauth_connection_name = abs_oauth_connection_name self.__user_token_client = user_token_client - async def __initialize_token_client(self, context: TurnContext) -> None: - # robrandao: TODO is this safe - # use cached value later - self.__user_token_client = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) + # async def __initialize_token_client(self, context: TurnContext) -> None: + # # robrandao: TODO is this safe + # # use cached value later + # self.__user_token_client = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) async def __get_ids_or_raise(self, context: TurnContext) -> TokenResponse: if ( @@ -71,9 +88,8 @@ async def __get_ids_or_raise(self, context: TurnContext) -> TokenResponse: async def __get_user_token(self, context: TurnContext, magic_code=None) -> TokenResponse: channel_id, from_id = self.__get_ids_or_raise(context) - await self.__initialize_token_client(context) - return await self.user_token_client.user_token.get_token( + return await self.__user_token_client.user_token.get_token( user_id=from_id, connection_name=self.__abs_oauth_connection_name, channel_id=channel_id, @@ -85,28 +101,33 @@ async def get_user_token(self, context: TurnContext) -> TokenResponse: async def sign_out(self, context: TurnContext) -> None: channel_id, from_id = self.__get_ids_or_raise(context) - await self.__initialize_token_client(context) - return await self.__user_token_client.user_token.get_token( + return await self.__user_token_client.user_token.sign_out( user_id=from_id, connection_name=self.__abs_oauth_connection_name, channel_id=channel_id ) - async def __use_attempt(self, context: TurnContext) -> None: + async def __use_attempt(self) -> None: if self.flow_state.attempts_remaining <= 0: self.flow_state.flow_state_tag = FlowStateTag.FAILURE - - async def __failed_attempt(self, context: TurnContext) -> None: - pass async def begin_flow(self, context: TurnContext) -> FlowResponse: + self.flow_state = FlowState( + id=self.__abs_oauth_connection_name, + channel_id=context.activity.channel_id, + user_id=context.activity.from_property.id + ) + # init flow state token_response = self.get_user_token(context) - if token_response and token_response.token: - pass + if token_response: + return FlowResponse( + flow_state=self.flow_state, + token_response=token_response + ) token_exchange_state = TokenExchangeState( connection_name=self.__abs_oauth_connection_name, @@ -115,71 +136,73 @@ async def begin_flow(self, context: TurnContext) -> FlowResponse: ms_app_id=context.turn_state.get(context.adapter.AGENT_IDENTITY_KEY).claims["aud"] # robrandao: TODO ) - sign_in_resource = await self.__user_token_client.agent_sign_in.get_sign_in_resource(state=token_exchange_sate.get_encoded_state()) + sign_in_resource = await self.__user_token_client.agent_sign_in.get_sign_in_resource( + state=token_exchange_state.get_encoded_state()) - return FlowResponse(flow_state=self.flow_state) + return FlowResponse(flow_state=self.flow_state, sign_in_resource=sign_in_resource) - async def __continue_from_message(self, context: TurnContext) -> None: + async def __continue_from_message(self, context: TurnContext) -> tuple[TokenResponse, FlowErrorTag]: - magic_code = activity.text + magic_code: str = context.activity.text if magic_code and magic_code.isdigit() and len(magic_code) == 6: - result = self.__get_user_token(context, magic_code) + token_response: TokenResponse = await self.__get_user_token(context, magic_code) - if result and result.token: - return result + if token_response: + return token_response, FlowErrorTag.NONE else: - return InvalidCodeError + return token_response, FlowErrorTag.MAGIC_CODE else: - return InvalidCodeFormatError + return TokenResponse(), FlowErrorTag.MAGIC_FORMAT - async def __continue_from_invoke_verify_state(self, context: TurnContext) -> None: - token_verify_sate = context.activity.value - magic_code = token_verify_state.get("state") - result = self.__get_user_token(context, magic_code) - if result and result.token: - pass - return None + async def __continue_from_invoke_verify_state(self, context: TurnContext) -> TokenResponse: + token_verify_state = context.activity.value + magic_code: str = token_verify_state.get("state") + token_response: TokenResponse = await self.__get_user_token(context, magic_code) + return token_response - async def __continue_from_invoke_token_exchange(self, context: TurnContext) -> None: - await self.__initialize_token_client(context) + async def __continue_from_invoke_token_exchange(self, context: TurnContext) -> TokenResponse: channel_id, from_id = self.__get_ids_or_raise(context) - token_exchange_request = context.activity.value - token_exchange_id = token_exchange_request.get("id") - - return await self.__user_token_client.user_token.exchange_token( - user_id=context.activity.from_property.id, + token_response = await self.__user_token_client.user_token.exchange_token( + user_id=from_id, connection_name=self.__abs_oauth_connection_name, channel_id=channel_id, body=token_exchange_request ) + return token_response, FlowErrorTag.NONE async def continue_flow(self, context: TurnContext) -> FlowResponse: - if self.flow_state.is_expired() or self.flow_state.reached_max_retries(): + logger.debug("Continuing auth flow...") + + if not self.flow_state.is_active(): self.flow_state.flow_state_tag = FlowStateTag.FAILURE return FlowResponse(flow_state=self.flow_state) continue_flow_activity = context.activity + flow_error_tag = FlowErrorTag.NONE if continue_flow_activity.type == ActivityTypes.message: - token_response, flow_error = continue_flow_from_message() + token_response, flow_error_tag = await self.__continue_from_message(context) elif continue_flow_activity.type == ActivityTypes.invoke and continue_flow_activity.name == "signin/verifyState": - token_response, flow_error = continue_flow_from_invoke_verify_state() - elif continue_flow_activity.type == ActivityTypes.invoke and continue_flow_activity.name == "signin/tokenExchange": - token_response, flow_error = continue_flow_from_invoke_token_exchange() + token_response = await self.__continue_from_invoke_verify_state(context) + elif continue_flow_activity.flow_error_tag == ActivityTypes.invoke and continue_flow_activity.name == "signin/tokenExchange": + token_response = await self.__continue_from_invoke_token_exchange(context) else: pass - if flow_error != FlowError.NONE and token_response and token_response.token: - pass - elif flow_error == FlowError.NONE: - flow_error = - - pass + if not token_response and flow_error_tag == FlowErrorTag.NONE: + flow_error_tag = FlowErrorTag.UNKNOWN + if flow_error_tag != FlowErrorTag.NONE: + self.__use_attempt() - async def begin_or_continue_flow(self, context: TurnContext) -> FlowResponse: + return FlowResponse( + flow_state=self.flow_state, + flow_error_tag=flow_error_tag, + token_response=token_response + ) + async def begin_or_continue_flow(self, context: TurnContext) -> FlowResponse: if self.flow_state.is_active(): return await self.continue_flow(context) else: diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_flow_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_flow_test.py new file mode 100644 index 00000000..f54e67ee --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_flow_test.py @@ -0,0 +1,397 @@ +import pytest + +from microsoft.agents.activity import ( + ActivityTypes, + TokenResponse +) +from microsoft.agents.hosting.core import AuthFlow + +from microsoft.agents.hosting.core.app.oauth.models import ( + FlowErrorTag, + FlowState, + FlowStateTag, + FlowResponse +) + +class TestAuthFlow: + + @pytest.fixture + def turn_context(self, mocker): + context = mocker.Mock() + context.activity.channel_id = "__channel_id" + context.activity.from_property.id = "__user_id" + return context + + def test_init_no_state(self): + flow = AuthFlow() + assert flow.flow_state == FlowState() + + def test_init_with_state(self): + flow_state = FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + expires_at=datetime.now().timestamp() + 10000 + ) + flow = AuthFlow(flow_state=flow_state) + assert flow.flow_state == flow_state + + @pytest.mark.asyncio + async def test_get_user_token(self, turn_context): + # mock + user_token_client = pytest.Mock() + user_token_client.user_token.get_token = pytest.AsyncMock(return_value="test_token") + + # test + flow = AuthFlow( + abs_oauth_connection_name="test_connection", + user_token_client=user_token_client, + ) + token = await flow.get_user_token(turn_context) + + # verify + assert token == "test_token" + assert user_token_client.user_token.get_token.called_once_with( + user_id="__user_id", + connection_name="test_connection", + channel_id="__channel_id", + magic_code=None + ) + + @pytest.mark.asyncio + async def test_sign_out(self, turn_context): + # mock + user_token_client = pytest.Mock() + user_token_client.user_token.sign_out = pytest.AsyncMock() + + # test + flow = AuthFlow( + abs_oauth_connection_name="connection", + user_token_client=user_token_client, + ) + await flow.sign_out(turn_context) + + # verify + assert user_token_client.user_token.sign_out.called_once_with( + user_id="__user_id", + connection_name="connection", + channel_id="__channel_id", + magic_code=None + ) + + @pytest.mark.asyncio + async def test_begin_flow_easy_case(self): + # mock + user_token_client = pytest.Mock() + user_token_client.user_token.get_token = pytest.AsyncMock(return_value=TokenResponse(token="test_token")) + + # test + flow = AuthFlow( + abs_oauth_connection_name="test_connection", + user_token_client=user_token_client, + ) + response = await flow.begin_flow(turn_context) + + # verify flow_state + flow_state = flow.flow_state + assert flow_state.tag == FlowStateTag.COMPLETE + assert flow_state.token == "test_token" + assert flow_state.flow_started is False # robrandao: TODO? + + # verify FlowResponse + assert response.flow_state == flow_state + assert response.sign_in_resource is None # No sign-in resource in this case + assert response.flow_error_tag == FlowErrorTag.NONE + assert response.token_response == "test_token" + assert user_token_client.user_token.get_token.called_once_with( + user_id="__user_id", + connection_name="test_connection", + channel_id="__channel_id", + # magic_code=None is an implementation detail, and ideally + # shouldn't be part of the test + magic_code=None + ) + + @pytest.mark.asyncio + async def test_begin_flow_long_case(self, mocker, turn_context): + # mock + dummy_sign_in_resource = SignInResource( + sign_in_link="https://example.com/signin", + token_exchange_state=TokenExchangeState(connection_name="test_connection") + ) + user_token_client = mocker.Mock() + user_token_client.user_token.get_token = pytest.AsyncMock(return_value=TokenResponse()) + user_token_client.agent_sign_in.get_sign_in_resource = pytest.AsyncMock(return_value=dummy_sign_in_resource) + + # test + flow = AuthFlow( + abs_oauth_connection_name="test_connection", + user_token_client=user_token_client, + ) + response = await flow.begin_flow(turn_context) + + # verify flow_state + flow_state = flow.flow_state + assert flow_state.tag == FlowStateTag.BEGIN + assert flow_state.token == "" + assert flow_state.flow_started is True + + # verify FlowResponse + assert response.flow_state == flow_state + assert response.sign_in_resource == dummy_sign_in_resource + assert response.flow_error_tag == FlowErrorTag.NONE + assert not response.token_response + # robrandao: TODO more assertions on sign_in_resource + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mocker", "turn_context, flow_state", + [ + ("mocker", "turn_context", FlowState( + tag=FlowStateTag.BEGIN, + token="", + expires=datetime.now().timestamp() - 1, + attempts_remaining=3 + )), + ("mocker", "turn_context", FlowState( + tag=FlowStateTag.CONTINUE, + token="", + expires=datetime.now().timestamp() + 1000, + attempts_remaining=0 + )), + ("mocker", "turn_context", FlowState( + tag=FlowStateTag.FAILED, + token="", + expires=datetime.now().timestamp() + 1000, + attempts_remaining=3 + )), + ("mocker", "turn_context", FlowState( + tag=FlowStateTag.COMPLETED, + token="", + expires=datetime.now().timestamp() + 1000, + attempts_remaining=2 + )), + ], + indirect=["mocker", "turn_context"] + ) + async def test_continue_flow_not_active(self, mocker, turn_context, flow_state): + user_token_client = mocker.Mock() + flow = AuthFlow( + abs_oauth_connection_name="test_connection", + user_token_client=user_token_client, + flow_state=flow_state + ) + flow_response = await flow.continue_flow(turn_context) + assert flow_response.flow_state == flow_state + assert not flow_response.token_response + + @pytest.fixture(params=[ + (FlowStateTag.ACTIVE, "test_token", 2), + (FlowStateTag.BEGIN, "", 1), + ]) + def active_flow_state(self, request): + tag, token, attempts_remaining = request.param + return FlowState( + tag=tag, + token=token, + expires=datetime.now().timestamp() + 1000, + attempts_remaining=attempts_remaining + ) + + async def test_continue_flow_message(self, mocker, turn_context, active_flow_state): + # mock + turn_context.activity.type = ActivityTypes.message + turn_context.activity.text = "magic-message" + user_token_client = mocker.Mock() + user_token_client.user_token.get_token = pytest.AsyncMock(return_value=TokenResponse()) + user_token_client.agent_sign_in.get_sign_in_resource = pytest.AsyncMock(return_value=dummy_sign_in_resource) + + # test + flow = AuthFlow( + abs_oauth_connection_name="test_connection", + user_token_client=mocker.Mock(), + flow_state=active_flow_state + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mocker, turn_context, active_flow_state, magic_code", + [ + ("mocker", "turn_context", "active_flow_state", "magic-message"), + ("mocker", "turn_context", "active_flow_state", ""), + ("mocker", "turn_context", "active_flow_state", "abcdef"), + ("mocker", "turn_context", "active_flow_state", "@#0324"), + ("mocker", "turn_context", "active_flow_state", "231"), + ("mocker", "turn_context", "active_flow_state", None), + ], + indirect=["mocker", "turn_context", "active_flow_state"] + ) + async def test_continue_flow_message_format_error(self, mocker, turn_context, active_flow_state, magic_code): + # mock + turn_context.activity.type = ActivityTypes.message + turn_context.activity.text = magic_code + + # test + flow = AuthFlow( + abs_oauth_connection_name="test_connection", + user_token_client=mocker.Mock(), + flow_state=active_flow_state + ) + flow_response = flow.continue_flow(turn_context) + + # verify + assert active_flow_state.attempts_remaining - 1 == flow_response.flow_state.attempts_remaining + assert not flow_response.token_response + assert flow_response.tag == FlowStateTag.FAILURE + assert flow_response.flow_error_tag == FlowErrorTag.MAGIC_FORMAT + + @pytest.mark.asyncio + async def test_continue_flow_message_magic_code_error(self, mocker, turn_context, active_flow_state): + # mock + turn_context.activity.type = ActivityTypes.message + turn_context.activity.text = "123456" + user_token_client = mocker.Mock() + user_token_client.user_token.get_token = pytest.AsyncMock(return_value=TokenResponse()) + + # test + flow = AuthFlow( + abs_oauth_connection_name="test_connection", + user_token_client=user_token_client, + flow_state=active_flow_state + ) + flow_response = await flow.continue_flow(turn_context) + + # verify + assert active_flow_state.attempts_remaining - 1 == flow_response.flow_state.attempts_remaining + assert not flow_response.token_response + assert flow_response.flow_error_tag == FlowErrorTag.MAGIC_CODE + assert user_token_client.user_token.get_token.called_once_with( + user_id="__user_id", + connection_name="test_connection", + channel_id="__channel_id", + magic_code="123456" + ) + + @pytest.mark.asyncio + async def test_continue_flow_invoke_verify_state(self, mocker, turn_context, active_flow_state): + # mock + turn_context.activity.type = ActivityTypes.message + turn_context.activity.name = "signin/verifyState" + turn_context.activity.value = {"state": "987654"} + user_token_client = mocker.Mock() + user_token_client.user_token.get_token = pytest.AsyncMock(return_value=TokenResponse(token="some-token")) + + # test + flow = AuthFlow( + abs_oauth_connection_name="test_connection", + user_token_client=user_token_client, + flow_state=active_flow_state + ) + flow_response = await flow.continue_flow(turn_context) + + # verify + assert active_flow_state.attempts_remaining == flow_response.flow_state.attempts_remaining + assert flow_response.token_response.token == "some-token" + assert flow_response.flow_state.tag == FlowStateTag.COMPLETE + assert flow_response.flow_error_tag == FlowErrorTag.NONE + assert user_token_client.user_token.get_token.called_once_with( + user_id="__user_id", + connection_name="test_connection", + channel_id="__channel_id", + magic_code="987654" + ) + + async def test_continue_flow_invoke_verify_state_no_token(self, mocker, turn_context, active_flow_state): + # mock + turn_context.activity.type = ActivityTypes.message + turn_context.activity.name = "signin/verifyState" + turn_context.activity.value = {"state": "987654"} + user_token_client = mocker.Mock() + user_token_client.user_token.get_token = pytest.AsyncMock(return_value=TokenResponse()) + + # test + flow = AuthFlow( + abs_oauth_connection_name="test_connection", + user_token_client=user_token_client, + flow_state=active_flow_state + ) + flow_response = await flow.continue_flow(turn_context) + + # verify + assert active_flow_state.attempts_remaining - 1 == flow_response.flow_state.attempts_remaining + assert not flow_response.token_response.token + if active_flow_state.attempts_remaining == 1: + assert flow_response.flow_state.tag == FlowStateTag.FAILURE + else: + assert flow_response.flow_state.tag == FlowStateTag.CONTINUE + assert flow_response.flow_error_tag == FlowErrorTag.UNKNOWN + assert user_token_client.user_token.get_token.called_once_with( + user_id="__user_id", + connection_name="test_connection", + channel_id="__channel_id", + magic_code="987654" + ) + + @pytest.mark.asyncio + async def test_continue_flow_invoke_token_exchange(self, mocker, turn_context, active_flow_state): + # mock + turn_context.activity.type = ActivityTypes.message + turn_context.activity.name = "signin/exchangeState" + turn_context.activity.value = "request_body" + user_token_client = mocker.Mock() + user_token_client.user_token.exchange_token = pytest.AsyncMock(return_value=TokenResponse(token="exchange-token")) + + # test + flow = AuthFlow( + abs_oauth_connection_name="test_connection", + user_token_client=user_token_client, + flow_state=active_flow_state + ) + flow_response = await flow.continue_flow(turn_context) + + # verify + assert active_flow_state.attempts_remaining == flow_response.flow_state.attempts_remaining + assert flow_response.token_response.token == "exchange-token" + assert flow_response.flow_state.tag == FlowStateTag.COMPLETE + assert flow_response.flow_error_tag == FlowErrorTag.NONE + assert user_token_client.user_token.get_token.called_once_with( + user_id="__user_id", + connection_name="test_connection", + channel_id="__channel_id", + body="request_body" + ) + + @pytest.mark.asyncio + async def test_continue_flow_invoke_token_exchange_no_token(self, mocker, turn_context, active_flow_state): + # mock + turn_context.activity.type = ActivityTypes.message + turn_context.activity.name = "signin/exchangeState" + turn_context.activity.value = "request_body" + user_token_client = mocker.Mock() + user_token_client.user_token.exchange_token = pytest.AsyncMock(return_value=TokenResponse()) + + # test + flow = AuthFlow( + abs_oauth_connection_name="test_connection", + user_token_client=user_token_client, + flow_state=active_flow_state + ) + flow_response = await flow.continue_flow(turn_context) + + # verify + assert active_flow_state.attempts_remaining - 1 == flow_response.flow_state.attempts_remaining + assert not flow_response.token_response + if active_flow_state.attempts_remaining == 1: + assert flow_response.flow_state.tag == FlowStateTag.FAILURE + else: + assert flow_response.flow_state.tag == FlowStateTag.CONTINUE + assert flow_response.flow_error_tag == FlowErrorTag.UNKNOWN + assert user_token_client.user_token.get_token.called_once_with( + user_id="__user_id", + connection_name="test_connection", + channel_id="__channel_id", + body="request_body" + ) + + @pytest.mark.asyncio + async def test_begin_or_continue_flow(self): + assert True # robrandao: TODO \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_storage_client.py similarity index 56% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_storage_client.py index 07cf8adb..f210763a 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_storage_client.py @@ -5,11 +5,17 @@ from .models import FlowState # robrandao: TODO -> context.activity.from_property -class FlowStorageClient: +class OAuthStorageClient: """ Wrapper around storage that manages sign-in state specific to each user and channel. Uses the activity's channel_id and from.id to create a key prefix for storage operations. + + Contract with other classes (usage of other classes is enforced in unit tests): + TurnContext.activity.channel_id + TurnContext.activity.from_property.id + + Storage: read(), write(), delete() """ def __init__( @@ -17,6 +23,13 @@ def __init__( context: TurnContext, storage: Storage ): + """ + Parameters + context: The TurnContext for the current conversation. Used to isolate + data across channels and users. This defines the prefix used to + access storage. + storage: The Storage instance used to persist flow state data. + """ if ( not context.activity @@ -24,7 +37,6 @@ def __init__( or not context.activity.from_property or not context.activity.from_property.id ): - raise ValueError("context.activity -> channel_id and from.id must be set.") channel_id = context.activity.channel_id @@ -33,19 +45,26 @@ def __init__( self.__base_key = f"auth/{channel_id}/{user_id}" self.__storage = storage - def __key(self, id: str) -> str: + @property + def base_key(self) -> str: + return self.__base_key + + def key(self, id: str) -> str: """Creates a storage key for a specific sign-in handler.""" return f"{self.__base_key}/${id}" async def read(self, auth_handler_id: str) -> Optional[FlowState]: - key: str = self.__key(auth_handler_id) + """Reads the flow state for a specific authentication handler.""" + key: str = self.key(auth_handler_id) data = await self.__storage.read([key], FlowState) - return data.get(key) # robrandao: TODO -> verify contract + return FlowState.validate(data.get(key)) # robrandao: TODO -> verify contract async def write(self, value: FlowState) -> None: - key: str = self.__key(value.id) + """Saves the flow state for a specific authentication handler.""" + key: str = self.key(value.id) await self.__storage.write({key: value}) async def delete(self, auth_handler_id: str) -> None: - key: str = self.__key(auth_handler_id) + """Deletes the flow state for a specific authentication handler.""" + key: str = self.key(auth_handler_id) await self.__storage.delete([key]) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_storage_client_test.py similarity index 100% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client_test.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_storage_client_test.py From 4949b07debea9fa576ab2bfb33f7f257d28a667b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Tue, 19 Aug 2025 14:28:55 -0700 Subject: [PATCH 07/32] Adding more tests --- .../agents/hosting/core/app/oauth/__init__.py | 18 +- .../hosting/core/app/oauth/auth_context.py | 222 ------------ .../app/oauth/{oauth_flow.py => auth_flow.py} | 3 +- .../hosting/core/app/oauth/auth_handler.py | 1 - .../hosting/core/app/oauth/authorization.py | 138 +++++--- .../core/app/oauth/authorization_test.py | 0 ...orage_client.py => flow_storage_client.py} | 2 +- .../app/oauth/tests/authorization_test.py | 331 ++++++++++++++++++ .../flow_storage_client_test.py} | 45 +-- .../core/app/oauth/{ => tests}/models_test.py | 0 .../app/oauth/{ => tests}/oauth_flow_test.py | 0 11 files changed, 455 insertions(+), 305 deletions(-) delete mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_context.py rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/{oauth_flow.py => auth_flow.py} (99%) delete mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization_test.py rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/{oauth_storage_client.py => flow_storage_client.py} (98%) create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/authorization_test.py rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/{oauth_storage_client_test.py => tests/flow_storage_client_test.py} (77%) rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/{ => tests}/models_test.py (100%) rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/{ => tests}/oauth_flow_test.py (100%) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py index ff280c7f..e5fc952f 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py @@ -2,7 +2,21 @@ Authorization, AuthorizationHandlers, AuthHandler, - SignInState, + FlowState, +) +from .models import ( + FlowState, + FlowStateTag, + FlowStateError, + FlowResponse, ) -__all__ = ["Authorization", "AuthorizationHandlers", "AuthHandler", "SignInState"] +__all__ = [ + "Authorization", + "AuthorizationHandlers", + "AuthHandler", + "FlowState", + "FlowStateTag", + "FlowStateError", + "FlowResponse", +] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_context.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_context.py deleted file mode 100644 index 566d1190..00000000 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_context.py +++ /dev/null @@ -1,222 +0,0 @@ -import logging -from typing import Optional, Callable - -from .sign_in_storage import SignInStorage, SignInHandlerState, SignInHandlerStateStatus, FlowState - -logger = logging.getLogger(__name__) - -ms_agents_logger = logging.getLogger("microsoft.agents") -handler_formatter = logging.StreamHandler() -console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)")) -ms_agents_logger.addHandler(console_handler) -ms_agents_logger.setLevel(logging.INFO) - -class AuthContext: - - logger = logging.getLogger(f"{__name__}.SignInContext") # robrandao: TODO get logger with config - - def __init__(self, - storage: SignInStorage, - auth_handlers: AuthHandlers, - context: TurnContext, - handler_id: str = "", - is_started_from_route: bool = True): - - if not is_started_from_route and not handler_id: - raise ValueError("handler_id must be provided when is_started_from_route is False.") - - if not hasattr(context, "activity"): # robrandao: TODO -> see extra condition in JS code - raise ValueError("context must have an activity property.") - - self.__storage = AuthStateStorage(context, storage) - - sign_in_storage = SignInStorage( - context, self.__storage, self.__auth_handlers - ) # robrandao: TODO - - # robrandao: TODO -> is this necessary here, or can we do this outside? Can we make the storage outside too? - # if self.is_started_from_route: - - # robrandao: TODO type signature - def on_success(self, handler: Callable) -> None: - self.__on_success_handler = handler - - def on_failure(self, handler: Callable) -> None: - self.__on_failure_handler = handler - - async def get_token(self) -> Optional[TokenResponse]: - - if not await self.load_handler(): - return TokenResponse() - - self.logger.info("Getting token from user token service.") - return self.__auth_handler.flow.get_token(self.context) - - async def exchange_token(self, scopes: list[str]) -> TokenResponse: - if not await self.load_handler(): - return TokenResponse() - - self.logger.info("Exchanging token from user token service.") - token_response = await self.__auth_handler.flow.get_token(self.context) - if self.is_exchangeable(token_response.token): - return await self.handle_obo(token_response.token, scopes) - return token_response - - async def sign_out(self) -> None: - if not await self.load_handler(): - return - - self.logger.info("Signing out from the authorization flow.") - if self.is_started_from_route: - await self.storage.handler_delete(self.handler.id) - return self.__auth_handler.flow.sign_out(self.context) - - async def get_token(self) -> Optional[TokenResponse]: - if not await self.load_handler(): - return TokenResponse() - - self.logger.debug("Processing authorization flow.") - self.logger.debug(f"Uses Storage state: {self.is_started_from_route}") - self.logger.debug("Current sign-in state:", self.handler) - - token_response = await self.handler.status( - {} # robrandao: TODO - ) - - self.logger.debug("OAuth flow result: %s", { token: token_response.get(token), state: self.handler}) - return token_response - - DEFAULT_STATES: dict[SignInHandlerStateStatus, FlowState] = { - "begin": lambda: { id: self.handler_id, status: status} - } - - def __set_status(status: SignInHandlerStateStatus) -> None: - - if status == FlowProgression.BEGIN: - self.flow_state = FlowState(id) - elif status == FlowProgression.CONTINUE: - pass - elif status == FlowProgression.SUCCESS: - pass - elif status == FlowProgression.FAILURE: - pass - - # robrandao: TODO - type - state_builder: dict[str, Callable] = { - SignInHandlerStateStatus.BEGIN: (lambda: { - "id": self.handler_id, - "status": SignInHandlerStateStatus.BEGIN - }), - SignInHandlerStateStatus.CONTINUE: (lambda self: { - **self.handler, - "status": SignInHandlerStateStatus.CONTINUE, - "state": self.flow_state - "continuation_activity": self.context.activity - }), - SignInHandlerStatus.SUCCESS: (lambda: { - **self.handler, - "status": SignInHandlerStateStatus.SUCCESS, - "state": None - }), - SignInHandlerStateStatus.FAILURE: (lambda: { - **self.handler, - "status": SignInHandlerStateStatus.FAILURE, - "state": self.flow_state - }) - } - - self.__handler = state_builder[status]() - return self.__handler # robrandao: TODO ??? - - async def __load_handler(self) -> bool: - if self.is_started_from_route: - if self.handler_id: - self.__handler = await self.storage.handler_get(self.handler_id) - else: - self.__handler = await self.storage.handler_active() - - if not self.__handler: - # robrandao: TODO renaming? - self.__handler = self.__set_status(SignInHandlerStateStatus.NOT_STARTED, None) - - if not self.handler.id: - return False - - self.__auth_handler = self.get_auth_handler_or_throw(self.handler.id) - - # robrandao: TODO - if not self.is_started_from_route and self.flow_state.flow_started: - self.set_status(SignInHandlerStateStatus.SUCCESS) - self.logger.debug("OAuth flow success, using existing state.") - return True - - if self.handler.status == SignInHandlerStateStatus.BEGIN: - self.logger.debug("No active flow state, starting a new OAuth flow.") - await self.__auth_handler.flow.sign_out(self.context) - else: - await self.__auth_handler.flow.set_flow_state(self.context, self.handler.state or FlowState()) # robrandao: TODO - - return True - - async def begin(self) -> None: - self.logger.debug("Beginning OAuth flow.") - await self.__auth_handler.flow.begin_flow(self.context) - self.logger.debug("OAuth flow started, waiting on continuation...") - self.__set_status(SignInHandlerStateStatus.CONTINUE) - if self.is_started_from_route: - await self.storage.handler_set(self.handler) - - async def continue(self) -> Optional[TokenResponse]: - - self.logger.debug("Continuing OAuth flow.") - - token_response = await self.__auth_handler.flow.continue_flow(self.context) - if token_response.token: - self.set_status(SignInHandlerStateStatus.SUCCESS) - self.logger.debug("OAuth flow success.") - if self.is_started_from_route: - await self.storage.handler_set(self.handler) - if self.__on_success_handler: - await self.__on_success_handler() - - else: - await self.failure() - - return token_response - - async def success(self) -> Optional[TokenResponse]: - token_response = await self.__auth_handler.flow.get_token(self.context) - # robrandao: TODO -> JS always strips() the token? - if self.is_started_from_route and token_response.token: - self.logger.debug("OAuth flow success, retrieving token.") - return token_response - else: - self.logger.debug("OAuth flow token not available, waiting on continuation...") - return self.continue() - - async def failure(self) -> None: - self.__set_status(SignInHandlerStateStatus.FAILURE) - - # TODO - - async def __is_exchangeable(self, token: Optional[str]) -> bool: - if not token or not isinstance(token, str): # robrandao: TODO ??? - return False - - payload = JwtToken.parse(token).payload # robrandao: TODO - return payload.aud.index("api://") == 0 - - async def __handle_obo(self, token: str, scopes: list[str]) -> TokenResponse: - msal_token_provider = MsalTokenProvider() - - auth_config = self.context.adapter.auth_config - if self.__auth_handler.cnx_prefix: - auth_config = load_auth_config_from_env(self.__auth_handler.cnx_prefix) - - new_token = await msal_token_provider.get_on_behalf_of_token(auth_config, scopes, token) - return TokenResponse(token) - - async def __get_auth_handler_or_throw(handler_id: str) -> AuthHandler: - # robrandao: TODO - pass - diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py similarity index 99% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_flow.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py index 4e0a403c..9085f1a0 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) -class OAuthFlow: +class AuthFlow: """ Manages the OAuth flow. @@ -142,7 +142,6 @@ async def begin_flow(self, context: TurnContext) -> FlowResponse: return FlowResponse(flow_state=self.flow_state, sign_in_resource=sign_in_resource) async def __continue_from_message(self, context: TurnContext) -> tuple[TokenResponse, FlowErrorTag]: - magic_code: str = context.activity.text if magic_code and magic_code.isdigit() and len(magic_code) == 6: token_response: TokenResponse = await self.__get_user_token(context, magic_code) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py index c6d6697e..234b2da6 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py @@ -30,7 +30,6 @@ def __init__( self.obo_connection_name = obo_connection_name or kwargs.get( "OBOCONNECTIONNAME" ) - self.flow: OAuthFlow = None logger.debug( f"AuthHandler initialized: name={self.name}, title={self.title}, text={self.text} abs_connection_name={self.abs_oauth_connection_name} obo_connection_name={self.obo_connection_name}" ) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index 454963b2..54d44942 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -6,6 +6,7 @@ import jwt from typing import Dict, Optional, Callable, Awaitable from collections.abc import Iterable +from contextlib import asynccontextmanager from microsoft.agents.hosting.core.authorization import ( Connections, @@ -14,18 +15,18 @@ from microsoft.agents.hosting.core.storage import Storage from microsoft.agents.activity import TokenResponse, Activity from microsoft.agents.hosting.core.storage import StoreItem +from microsoft.agents.hosting.core.connector.client import UserTokenClient from pydantic import BaseModel from ...turn_context import TurnContext from ...app.state.turn_state import TurnState -from ...oauth_flow import OAuthFlow +# from ...oauth_flow import AuthFlow from ...state.user_state import UserState -from .sign_in_context import SignInContext from .auth_handler import AuthHandler, AuthorizationHandlers -from .sign_in_state import SignInState -from .storage import FlowStorageClient -from .models import FlowResponse, FlowState, FlowStateTag +from .models import FlowResponse, FlowState, FlowStateTag, FlowErrorTag +from .flow_storage_client import FlowStorageClient +from .auth_flow import AuthFlow logger = logging.getLogger(__name__) @@ -61,19 +62,18 @@ def __init__( user_state = UserState(storage) - self.__auth_storage = AuthStateStorage(storage, ) - + self._storage = storage self._connection_manager = connection_manager auth_configuration: Dict = kwargs.get("AGENTAPPLICATION", {}).get( "USERAUTHORIZATION", {} ) - self._auto_signin = ( - auto_signin - if auto_signin is not None - else auth_configuration.get("AUTOSIGNIN", False) - ) + # self._auto_signin = ( + # auto_signin + # if auto_signin is not None + # else auth_configuration.get("AUTOSIGNIN", False) + # ) handlers_config: Dict[str, Dict] = auth_configuration.get("HANDLERS") if not auth_handlers and handlers_config: @@ -92,21 +92,52 @@ def __init__( Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] ] = None - # Configure each auth handler - for auth_handler in self._auth_handlers.values(): - # Create OAuth flow with configuration - messages_config = {} - if auth_handler.title: - messages_config["card_title"] = auth_handler.title - if auth_handler.text: - messages_config["button_text"] = auth_handler.text - - logger.debug(f"Configuring OAuth flow for handler: {auth_handler.name}") - auth_handler.flow = OAuthFlow( - storage=storage, - abs_oauth_connection_name=auth_handler.abs_oauth_connection_name, - messages_configuration=messages_config if messages_config else None, - ) + # # Configure each auth handler + # for auth_handler in self._auth_handlers.values(): + # # Create OAuth flow with configuration + # messages_config = {} + # if auth_handler.title: + # messages_config["card_title"] = auth_handler.title + # if auth_handler.text: + # messages_config["button_text"] = auth_handler.text + + # logger.debug(f"Configuring OAuth flow for handler: {auth_handler.name}") + # auth_handler.flow = AuthFlow( + # storage=storage, + # abs_oauth_connection_name=auth_handler.abs_oauth_connection_name, + # messages_configuration=messages_config if messages_config else None, + # ) + + async def __load_flow(self, context: TurnContext, auth_handler_id: str) -> AuthFlow: + user_token_client: UserTokenClient = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) # robrandao: TODO + auth_handler: AuthHandler = self.resolve_handler(auth_handler_id) + + flow_storage_client = FlowStorageClient(context, self._storage) + flow_state: FlowState = await flow_storage_client.read(auth_handler_id) + + flow = AuthFlow(auth_handler.abs_oauth_connection_name, user_token_client, flow_state) + return flow, flow_storage_client, flow_state + + @asynccontextmanager + async def open_flow(self, context: TurnContext, auth_handler_id: str = "", readonly: bool = False) -> FlowResponse: + """ + Starts the OAuth flow for a specific auth handler. + + Args: + context: The context object for the current turn. + auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. + + Returns: + The flow response from the OAuth provider. + """ + if not context or not auth_handler_id: + raise ValueError("context and auth_handler_id are required") + + flow, flow_storage_client, init_flow_state = self.__load_flow(context, auth_handler_id) + yield flow + + if not readonly and flow.flow_state != init_flow_state: + flow_storage_client.write(flow.flow_state) async def get_token( self, context: TurnContext, auth_handler_id: Optional[str] = None @@ -121,8 +152,8 @@ async def get_token( Returns: The token response from the OAuth provider. """ - flow = OAuthFlow(context, auth_handler_id) - return await flow.get_token() + async with self.open_flow(context, auth_handler_id) as flow: + return await flow.get_user_token(context) async def exchange_token( self, @@ -141,14 +172,13 @@ async def exchange_token( Returns: The token response from the OAuth provider. """ - flow = OAuthFlow(context, auth_handler_id) - - token_response = await flow.get_user_token(context) + async with self.open_flow(context, auth_handler_id) as flow: + token_response = await flow.get_user_token(context) - if self.__is_exchangeable(token_response.token if token_response else None): - pass + if token_response and self.__is_exchangeable(token_response.token): + return await self.__handle_obo(token_response.token, scopes, auth_handler_id) - return await flow.exchange_token(scopes) + return TokenResponse() # auth_handler = self.resolver_handler(auth_handler_id) # if not auth_handler.flow: @@ -162,7 +192,7 @@ async def exchange_token( # return token_response - def _is_exchangeable(self, token: Optional[str]) -> bool: + def __is_exchangeable(self, token: Optional[str]) -> bool: """ Checks if a token is exchangeable (has api:// audience). @@ -184,7 +214,7 @@ def _is_exchangeable(self, token: Optional[str]) -> bool: logger.exception("Failed to decode token to check audience") return False - async def _handle_obo( + async def __handle_obo( self, token: str, scopes: list[str], handler_id: str = None ) -> TokenResponse: """ @@ -221,9 +251,8 @@ async def _handle_obo( scopes=scopes, # Expiration can be set based on the token provider's response ) - async def get_active_flow_state(self, context: TurnContext, turn_state: TurnState) -> Optional[FlowResponse]: + async def get_active_flow_state(self, context: TurnContext, turn_state: TurnState = None) -> Optional[FlowState]: flow_storage_client = FlowStorageClient(context, self.__storage) - for auth_handler_id in self._auth_handlers.keys(): flow_state = await flow_storage_client.read(auth_handler_id) if flow_state.is_active(): @@ -250,16 +279,15 @@ async def begin_or_continue_flow( """ # robrandao: TODO -> is_started_from_route and sec_route - flow_storage_client = FlowStorageClient(context, self.__storage) - flow = OAuthFlow(context, auth_handler_id) - - flow_response: FlowResponse = flow.begin_or_continue_flow(context) + async with self.open_flow(context, auth_handler_id) as flow: + flow_response: FlowResponse = await flow.begin_or_continue_flow(context) + flow_state: FlowState = flow_response.flow_state if flow_state.tag == FlowStateTag.COMPLETE: - self.__on_sign_in_success_handler(context, turn_state, flow_state.handler.id) + self.__sign_in_success_handler(context, turn_state, flow_state.handler.id) elif flow_state.tag == FlowStateTag.FAILURE: - self.__on_sign_in_failure_handler(context, turn_state, flow_state.handler.id, err) + self.__sign_in_failure_handler(context, turn_state, flow_state.handler.id, err) return flow_response @@ -285,21 +313,19 @@ def resolve_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: async def __sign_out( self, context: TurnContext, - state: TurnState, auth_handler_ids: Iterable[str] = None, ) -> None: - flow_storage_client = FlowStorageClient(context, self.__storage) for auth_handler_id in auth_handler_ids: - auth_handler = self.resolver_handler(auth_handler_id) - flow_state = flow_storage_client.read(auth_handler.flow_id) - if flow_state: + flow, flow_storage_client, initial_flow_state = self.__load_flow(context, auth_handler_id) + if initial_flow_state: logger.info(f"Signing out from handler: {auth_handler_id}") - await auth_handler.flow.sign_out(context) + await flow.sign_out(context) + flow_storage_client.delete(auth_handler_id) async def sign_out( self, context: TurnContext, - state: TurnState, + _state: TurnState, auth_handler_id: Optional[str] = None, ) -> None: """ @@ -312,9 +338,9 @@ async def sign_out( auth_handler_id: Optional ID of the auth handler to use for sign out. """ if auth_handler_id: - self.__sign_out(context, state, [auth_handler_id]) + self.__sign_out(context, [auth_handler_id]) else: - self.__sign_out(context, state, self._auth_handlers.keys()) + self.__sign_out(context, self._auth_handlers.keys()) def on_sign_in_success( self, @@ -326,7 +352,7 @@ def on_sign_in_success( Args: handler: The handler function to call on successful sign-in. """ - self.__sign_in_handler = handler + self.__sign_in_success_handler = handler def on_sign_in_failure( self, @@ -337,4 +363,4 @@ def on_sign_in_failure( Args: handler: The handler function to call on sign-in failure. """ - self.__sign_in_failure = handler \ No newline at end of file + self.__sign_in_failure_handler = handler \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization_test.py deleted file mode 100644 index e69de29b..00000000 diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py similarity index 98% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_storage_client.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py index f210763a..d9a4e7ec 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py @@ -5,7 +5,7 @@ from .models import FlowState # robrandao: TODO -> context.activity.from_property -class OAuthStorageClient: +class FlowStorageClient: """ Wrapper around storage that manages sign-in state specific to each user and channel. diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/authorization_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/authorization_test.py new file mode 100644 index 00000000..11555e57 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/authorization_test.py @@ -0,0 +1,331 @@ +# import datetime + +# import pytest +# from pytest import lazy_fixture + +# from microsoft.agents.activity import ( +# TokenResponse +# ) +# from microsoft.agents.hosting.core import ( +# Authorization, +# MemoryStorage, +# FlowStorageClient, +# FlowState, +# FlowErrorTag, +# FlowStateTag, +# FlowResponse +# ) +# from microsoft.agents.hosting.core.storage.storage_test_utils import StorageBaseline +# from tools.testing_authorization import ( + +# ) + +# STORAGE_SAMPLE_DICT = { +# "user_id": "123", +# "session_id": "abc", +# "auth/channel_id/user_id/expired": FlowState( +# id="expired", +# expires=expired_time, +# attempts_remaining=1, +# tag=FlowStateTag.CONTINUE +# ), +# "auth/teams_id/Bob/no_retries": FlowState( +# id="no_retries", +# expires=valid_time, +# attempts_remaining=0, +# tag=FlowStateTag.FAILURE +# ), +# "auth/channel/Alice/begin": FlowState( +# id="begin", +# expired=valid_time, +# attempts_remaining=3, +# tag=FlowStateTag.BEGIN +# ), +# "auth/channel/Alice/continue": FlowState( +# id="continue", +# expires=valid_time, +# attempts_remaining=2 +# tag=FlowStateTag.CONTINUE +# ), +# "auth/channel/Alice/expired_and_retries": FlowState( +# id="expired_and_retries" +# expires=expired_time, +# attempts_remaining=0 +# tag=FlowStateTag.FAILURE +# ), +# "auth/channel/Alice/not_started": FlowState( +# id="not_started", +# tag=FlowStateTag.NOT_STARTED +# ) +# } + +# class TestAuthorization: + +# def build_context(self, mocker, channel_id, from_property_id): +# turn_context = mocker.Mock() +# turn_context.activity.channel_id = channel_id +# turn_context.activity.from_property.id = from_property_id +# return turn_context + +# @pytest.fixture +# def context(self, mocker): +# return self.build_context(mocker, "__channel_id", "__user_id") + +# @pytest.fixture +# def valid_time(self): +# return datetime.datetime.now() + 10000 + +# @pytest.fixture +# def expired_time(self): +# return datetime.datetime.now() + +# @pytest.fixture +# def m_storage(self, mocker): +# return mocker.Mock(spec=MemoryStorage) + +# @pytest.fixture +# def m_connection_manager(self, mocker): +# return mocker.Mock(spec=ConnectionManager) + +# @pytest.fixture +# def auth_handler_ids(self): +# return ["expired", "no_retries", "begin", "continue", "expired_and_retries", "not_started"] + +# @pytest.fixture +# def auth_handlers(self, mocker, auth_handler_ids): +# return { +# auth_handler_id: create_test_auth_handler(f"test-{auth_handler_id}") for auth_handler_id in auth_handler_ids +# } + +# @pytest.fixture +# def storage(self, valid_time, expired_time): +# return MemoryStorage(STORAGE_SAMPLE_DICT) + +# @pytest.fixture +# def connection_manager(self): +# pass + +# @pytest.fixture +# def auth_handlers(self): +# pass + +# @pytest.fixture +# def auth(self, storage, connection_manager, auth_handlers): +# return Authorization( +# storage, +# connection_manager, +# auth_handlers, +# auto_signin=True +# ) + +# @pytest.fixture +# def storage(self, mocker): +# return MemoryStorage({ + +# }) + +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# "auth, context, auth_handler_id", +# [ +# ("auth", lazy_fixture("context"), ""), +# ("auth", None, "handler"), +# ("auth", None, "") +# ("auth", lazy_fixture("context", "missing_handler")) +# ], +# indirect=["auth"] +# ) +# async def test_open_flow_value_error(self, auth, context, auth_handler_id): +# with pytest.raises(ValueError): +# async with auth.open_flow(context, auth_handler_id): +# pass + +# # async def test_open_flow_storage_readonly_storage_access(self, mocker, context, m_storage, m_connection_manager, m_auth_handlers): +# # # setup +# # m_storage.read.return_value = FlowState() +# # auth = Authorization( +# # m_storage, +# # m_connection_manager, +# # m_auth_handlers +# # ) + +# # # code +# # async with auth.open_flow(context, "handler", readonly=True) as flow: +# # actual_init_flow_state = flow.flow_state + +# # # verify +# # assert actual_init_flow_state is m_storage.read.return_value +# # assert not m_storage.write.called +# # assert not m_storage.delete.called + +# # async def test_open_flow_storage_unchanged_not_readonly_storage_access(self, context, m_storage, m_connection_manager, m_auth_handlers): +# # # setup +# # m_storage.read.return_value = FlowState() +# # auth = Authorization( +# # m_storage, +# # m_connection_manager, +# # m_auth_handlers +# # ) + +# # # code +# # async with auth.open_flow(context, "handler", readonly=False) as flow: +# # # if no change is made to the flow state, then storage should not be updated +# # actual_init_flow_state = flow.flow_state + +# # # verify +# # assert actual_init_flow_state is m_storage.read.return_value +# # assert not m_storage.write.called +# # assert not m_storage.delete.called + +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# "mocker, connection_manager, channel_id, from_property_id, auth_handler_id", +# [ +# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel_id", "user_id", "expired"), +# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "teams_id", "Bob", "no_retries"), +# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "begin"), +# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "continue"), +# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "expired_and_retries"), +# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "not_started"), +# ] +# ) +# async def test_open_flow_readonly_run(self, mocker, connection_manager, channel_id, from_property_id, auth_handler_id): +# # setup +# storage = MemoryStorage(STORAGE_SAMPLE_DICT) +# baseline = StorageBaseline(STORAGE_SAMPLE_DICT) +# auth = Authorization( +# storage, +# connection_manager, +# auth_handlers +# ) +# context = self.build_context(mocker, channel_id, from_property_id) +# storage_client = FlowStorageClient(context, storage) +# key = storage_client.key(auth_handler_id) +# expected_init_flow_state = storage.read(key, FlowState) + +# # code +# async with auth.open_flow(context, "handler", readonly=True) as flow: +# actual_init_flow_state = flow.flow_state.copy() +# flow.flow_state.id = "garbage" +# flow.flow_state.tag = FlowStateTag.FAILURE +# flow.flow_state.expires = 0 +# flow.flow_state.attempts_remaining = -1 +# actual_final_flow_state = await storage.read([key], FlowState)[key] + +# # verify +# expected_final_flow_state = baseline.read(key, FlowState) +# assert actual_init_flow_state == expected_init_flow_state +# assert actual_final_flow_state == expected_final_flow_state +# assert await baseline.equals(storage) + +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# "mocker, connection_manager, channel_id, from_property_id, auth_handler_id", +# [ +# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel_id", "user_id", "expired"), +# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "teams_id", "Bob", "no_retries"), +# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "begin"), +# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "continue"), +# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "expired_and_retries"), +# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "not_started"), +# ] +# ) +# async def test_open_flow_storage_run(self, mocker, connection_manager, channel_id, from_property_id, auth_handler_id): +# # setup +# storage = MemoryStorage(STORAGE_SAMPLE_DICT) +# baseline = StorageBaseline(STORAGE_SAMPLE_DICT) +# auth = Authorization( +# storage, +# connection_manager, +# auth_handlers +# ) +# context = self.build_context(mocker, channel_id, from_property_id) +# storage_client = FlowStorageClient(context, storage) +# key = storage_client.key(auth_handler_id) +# expected_init_flow_state = storage.read(key, FlowState) + +# # code +# async with auth.open_flow(context, "handler") as flow: +# actual_init_flow_state = flow.flow_state.copy() +# flow.flow_state.id = "garbage" +# flow.flow_state.tag = FlowStateTag.FAILURE +# flow.flow_state.expires = 0 +# flow.flow_state.attempts_remaining = -1 + +# # verify +# baseline.write({ +# "auth/channel/Alice/continue": flow.flow_state +# }) +# expected_final_flow_state = baseline.read(key, FlowState) +# assert await baseline.equals(storage) +# assert actual_init_flow_state == expected_init_flow_state +# assert flow.flow_state == expected_final_flow_state + +# @pytest.mark.asyncio +# async def test_get_token(self, mocker, m_storage): +# m_storage.read.return_value = FlowState( +# id="auth_handler", +# tag=FlowStateTag.ACTIVE, +# expires=3600, +# attempts_remaining=3 +# ) +# expected = TokenResponse( +# access_token="access_token", +# refresh_token="refresh_token", +# expires_in=3600 +# ) +# mock_flow = mocker.AsyncMock() +# mock_flow.get_user_token.return_value = expected +# mocker.patch.object("OAuthFlow", "get_token", return_value=expected) +# mocker.patch.object("OAuthFlow", "__init__", return_value=mock_flow) + +# assert await auth.get_token("auth_handler") is expected +# assert mock_flow.get_user_token.called_once() + +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# "auth, context, auth_handler_id", +# [ +# (lazy_fixture("auth"), lazy_fixture("context"), "missing-handler"), +# (lazy_fixture("auth"), lazy_fixture("context"), ""), +# (lazy_fixture("auth"), None, "handler") +# ] +# ) +# async def test_get_token_error(self, auth, context, auth_handler_id): +# with pytest.raises(ValueError): +# await auth.get_token(context, auth_handler_id) + +# @pytest.fixture +# def valid_token_response(self): +# return TokenResponse( +# connection_name="connection", +# token="token" +# ) + +# @pytest.fixture +# def invalid_exchange_token(self): +# token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") +# return TokenResponse( +# connection_name="connection" +# token=token +# ) + +# @pytest.mark.asyncio +# @pytest.mark.parametrize +# async def test_exchange_token(self, mocker, auth): + +# mocker.patch.object("OAuthFlow", +# get_user_token=mocker.AsyncMock(return_value=TokenResponse( +# access_token="access_token", +# refresh_token="refresh_token", +# expires_in=3600 +# )) +# ) + + + + + +# @pytest.mark.asyncio +# async def test_exchange_token(self): +# pass \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_storage_client_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/flow_storage_client_test.py similarity index 77% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_storage_client_test.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/flow_storage_client_test.py index e169e753..7cfd8cb7 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_storage_client_test.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/flow_storage_client_test.py @@ -1,9 +1,8 @@ -from microsoft.agents.hosting.core import MemoryStorage - import pytest +from pytest import lazy_fixture from microsoft.agents.hosting.core import ( - Storage, + MemoryStorage FlowStorageClient, MockStoreItem, FlowState @@ -18,25 +17,29 @@ def turn_context(self, mocker): context.activity.from_property.id = "__user_id" return context + @pytest.fixture + def storage(self): + return MemoryStorage() + @pytest.fixture def client(self, turn_context, storage): return FlowStorageClient(turn_context, storage) @pytest.mark.asyncio @pytest.mark.parametrize( - "mocker, turn_context, storage, channel_id, from_property_id", + "mocker, channel_id, from_property_id", [ - ("mocker", "turn_context", "storage", "channel_id", "from_property_id"), - ("mocker", "turn_context", "storage", "teams_id", "Bob"), - ("mocker", "turn_context", "storage", "channel", "Alice"), + ("mocker", "channel_id", "from_property_id"), + ("mocker", "teams_id", "Bob"), + ("mocker", "channel", "Alice"), ], indirect=["mocker", "turn_context", "storage"] ) - async def test_init_base_key(self, mocker, turn_context, storage, channel_id, from_property_id): + async def test_init_base_key(self, mocker, channel_id, from_property_id): context = mocker.Mock() context.activity.channel_id = channel_id context.activity.from_property.id = from_property_id - client = FlowStorageClient(context, storage) + client = FlowStorageClient(context, mocker.Mock()) assert client.base_key == f"auth/{channel_id}/{from_property_id}/" async def test_init_fails_without_from_id(self, mocker, storage): @@ -63,13 +66,13 @@ def test_key(self, client, auth_handler_id, expected): @pytest.mark.asyncio @pytest.mark.parametrize( - "mocker, turn_context storage, client, auth_handler_id", + "mocker, turn_context, auth_handler_id", [ - (mocker, turn_context, storage, client, "handler"), - (mocker, turn_context, storage, client, "auth_handler"), + (mocker, turn_context, "handler"), + (mocker, turn_context, "auth_handler"), ] ) - async def test_read(self, mocker, turn_context, storage, client, auth_handler_id): + async def test_read(self, mocker, turn_context, auth_handler_id): storage = mocker.AsyncMock() storage.read.return_value = sentinel.read_response client = FlowStorageClient(turn_context, storage) @@ -79,13 +82,13 @@ async def test_read(self, mocker, turn_context, storage, client, auth_handler_id @pytest.mark.asyncio @pytest.mark.parametrize( - "mocker, turn_context storage, client, auth_handler_id", + "mocker, turn_context, auth_handler_id", [ - (mocker, turn_context, storage, client, "handler", "auth/__channel_id/__user_id/handler"), - (mocker, turn_context, storage, client, "auth_handler", "auth/__channel_id/__user_id/auth_handler"), + (lazy_fixture("mocker"), lazy_fixture("turn_context"), "handler", "auth/__channel_id/__user_id/handler"), + (lazy_fixture("mocker"), lazy_fixture("turn_context"), "auth_handler", "auth/__channel_id/__user_id/auth_handler"), ] ) - async def test_write(self, mocker, turn_context, storage, client, auth_handler_id, key, flow_state): + async def test_write(self, mocker, turn_context, auth_handler_id, key, flow_state): storage = mocker.AsyncMock() storage.write.return_value = None client = FlowStorageClient(turn_context, storage) @@ -96,13 +99,13 @@ async def test_write(self, mocker, turn_context, storage, client, auth_handler_i @pytest.mark.asyncio @pytest.mark.parametrize( - "mocker, turn_context storage, client, auth_handler_id", + "mocker, turn_context, auth_handler_id", [ - (mocker, turn_context, storage, client, "handler", "auth/__channel_id/__user_id/handler"), - (mocker, turn_context, storage, client, "auth_handler", "auth/__channel_id/__user_id/auth_handler"), + (mocker, turn_context, "handler", "auth/__channel_id/__user_id/handler"), + (mocker, turn_context, "auth_handler", "auth/__channel_id/__user_id/auth_handler"), ] ) - async def test_delete(self, mocker, turn_context, storage, client, auth_handler_id): + async def test_delete(self, mocker, turn_context, auth_handler_id): storage = mocker.AsyncMock() storage.write.return_value = None client = FlowStorageClient(turn_context, storage) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/models_test.py similarity index 100% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models_test.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/models_test.py diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_flow_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/oauth_flow_test.py similarity index 100% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/oauth_flow_test.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/oauth_flow_test.py From d64a880e8596236c805763ae2ae72885cafef1c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Tue, 19 Aug 2025 16:39:28 -0700 Subject: [PATCH 08/32] Fixing test cases and catching bugs --- .../agents/hosting/core/app/oauth/__init__.py | 15 ++- .../hosting/core/app/oauth/auth_flow.py | 10 +- .../hosting/core/app/oauth/auth_handler.py | 7 +- .../hosting/core/app/oauth/authorization.py | 40 +++--- .../core/app/oauth/flow_storage_client.py | 9 +- .../agents/hosting/core/app/oauth/models.py | 24 ++-- .../app/oauth/tests/authorization_test.py | 2 +- .../oauth/tests/flow_storage_client_test.py | 97 +++++++------- .../core/app/oauth/tests/models_test.py | 6 +- .../core/app/oauth/tests/oauth_flow_test.py | 118 +++++++++--------- 10 files changed, 167 insertions(+), 161 deletions(-) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py index e5fc952f..5817c861 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py @@ -1,22 +1,25 @@ from .authorization import ( - Authorization, - AuthorizationHandlers, + Authorization +) +from .auth_handler import ( AuthHandler, - FlowState, + AuthorizationHandlers ) from .models import ( FlowState, FlowStateTag, - FlowStateError, + FlowErrorTag, FlowResponse, ) +from .flow_storage_client import FlowStorageClient __all__ = [ "Authorization", - "AuthorizationHandlers", "AuthHandler", + "AuthorizationHandlers", "FlowState", "FlowStateTag", - "FlowStateError", + "FlowErrorTag", "FlowResponse", + "FlowStorageClient", ] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py index 9085f1a0..92d3733a 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py @@ -21,8 +21,6 @@ from microsoft.agents.hosting.core.storage import StoreItem, Storage from pydantic import BaseModel, PositiveInt -from .message_factory import MessageFactory -from .card_factory import CardFactory from .models import FlowResponse, FlowState, FlowStateTag, FlowErrorTag logger = logging.getLogger(__name__) @@ -76,9 +74,9 @@ def __init__( # # use cached value later # self.__user_token_client = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) - async def __get_ids_or_raise(self, context: TurnContext) -> TokenResponse: + def __get_ids_or_raise(self, context: TurnContext) -> TokenResponse: if ( - not not context.activity.channel_id or + not context.activity.channel_id or not context.activity.from_property or not context.activity.from_property.id ): @@ -97,7 +95,7 @@ async def __get_user_token(self, context: TurnContext, magic_code=None) -> Token ) async def get_user_token(self, context: TurnContext) -> TokenResponse: - return self.__get_user_token(context) + return await self.__get_user_token(context) async def sign_out(self, context: TurnContext) -> None: channel_id, from_id = self.__get_ids_or_raise(context) @@ -122,7 +120,7 @@ async def begin_flow(self, context: TurnContext) -> FlowResponse: # init flow state - token_response = self.get_user_token(context) + token_response = await self.get_user_token(context) if token_response: return FlowResponse( flow_state=self.flow_state, diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py index 234b2da6..203cb08c 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py @@ -1,3 +1,8 @@ +import logging +from typing import Dict + +logger = logging.getLogger(__name__) + class AuthHandler: """ Interface defining an authorization handler for OAuth flows. @@ -34,5 +39,5 @@ def __init__( f"AuthHandler initialized: name={self.name}, title={self.title}, text={self.text} abs_connection_name={self.abs_oauth_connection_name} obo_connection_name={self.obo_connection_name}" ) -# Type alias for authorization handlers dictionary +# # Type alias for authorization handlers dictionary AuthorizationHandlers = Dict[str, AuthHandler] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index 54d44942..e8483c78 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -22,7 +22,7 @@ from ...app.state.turn_state import TurnState # from ...oauth_flow import AuthFlow from ...state.user_state import UserState -from .auth_handler import AuthHandler, AuthorizationHandlers +from .auth_handler import AuthHandler from .models import FlowResponse, FlowState, FlowStateTag, FlowErrorTag from .flow_storage_client import FlowStorageClient @@ -41,7 +41,7 @@ def __init__( self, storage: Storage, connection_manager: Connections, - auth_handlers: AuthorizationHandlers = None, + auth_handlers: dict[str, AuthHandler] = None, auto_signin: bool = None, **kwargs, ): @@ -57,10 +57,10 @@ def __init__( """ if not storage: raise ValueError("Storage is required for Authorization") - if not auth_handlers: - raise ValueError("At least one AuthHandler must be provided") + # if not auth_handlers: + # raise ValueError("At least one AuthHandler must be provided") - user_state = UserState(storage) + # user_state = UserState(storage) self._storage = storage self._connection_manager = connection_manager @@ -108,7 +108,7 @@ def __init__( # messages_configuration=messages_config if messages_config else None, # ) - async def __load_flow(self, context: TurnContext, auth_handler_id: str) -> AuthFlow: + async def _load_flow(self, context: TurnContext, auth_handler_id: str) -> AuthFlow: user_token_client: UserTokenClient = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) # robrandao: TODO auth_handler: AuthHandler = self.resolve_handler(auth_handler_id) @@ -133,7 +133,7 @@ async def open_flow(self, context: TurnContext, auth_handler_id: str = "", reado if not context or not auth_handler_id: raise ValueError("context and auth_handler_id are required") - flow, flow_storage_client, init_flow_state = self.__load_flow(context, auth_handler_id) + flow, flow_storage_client, init_flow_state = self._load_flow(context, auth_handler_id) yield flow if not readonly and flow.flow_state != init_flow_state: @@ -175,8 +175,8 @@ async def exchange_token( async with self.open_flow(context, auth_handler_id) as flow: token_response = await flow.get_user_token(context) - if token_response and self.__is_exchangeable(token_response.token): - return await self.__handle_obo(token_response.token, scopes, auth_handler_id) + if token_response and self._is_exchangeable(token_response.token): + return await self._handle_obo(token_response.token, scopes, auth_handler_id) return TokenResponse() @@ -192,7 +192,7 @@ async def exchange_token( # return token_response - def __is_exchangeable(self, token: Optional[str]) -> bool: + def _is_exchangeable(self, token: Optional[str]) -> bool: """ Checks if a token is exchangeable (has api:// audience). @@ -214,7 +214,7 @@ def __is_exchangeable(self, token: Optional[str]) -> bool: logger.exception("Failed to decode token to check audience") return False - async def __handle_obo( + async def _handle_obo( self, token: str, scopes: list[str], handler_id: str = None ) -> TokenResponse: """ @@ -252,7 +252,7 @@ async def __handle_obo( ) async def get_active_flow_state(self, context: TurnContext, turn_state: TurnState = None) -> Optional[FlowState]: - flow_storage_client = FlowStorageClient(context, self.__storage) + flow_storage_client = FlowStorageClient(context, self._storage) for auth_handler_id in self._auth_handlers.keys(): flow_state = await flow_storage_client.read(auth_handler_id) if flow_state.is_active(): @@ -285,9 +285,9 @@ async def begin_or_continue_flow( flow_state: FlowState = flow_response.flow_state if flow_state.tag == FlowStateTag.COMPLETE: - self.__sign_in_success_handler(context, turn_state, flow_state.handler.id) + self._sign_in_success_handler(context, turn_state, flow_state.handler.id) elif flow_state.tag == FlowStateTag.FAILURE: - self.__sign_in_failure_handler(context, turn_state, flow_state.handler.id, err) + self._sign_in_failure_handler(context, turn_state, flow_state.handler.id, err) return flow_response @@ -310,13 +310,13 @@ def resolve_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: # Return the first handler if no ID specified return next(iter(self._auth_handlers.values)) - async def __sign_out( + async def _sign_out( self, context: TurnContext, auth_handler_ids: Iterable[str] = None, ) -> None: for auth_handler_id in auth_handler_ids: - flow, flow_storage_client, initial_flow_state = self.__load_flow(context, auth_handler_id) + flow, flow_storage_client, initial_flow_state = self._load_flow(context, auth_handler_id) if initial_flow_state: logger.info(f"Signing out from handler: {auth_handler_id}") await flow.sign_out(context) @@ -338,9 +338,9 @@ async def sign_out( auth_handler_id: Optional ID of the auth handler to use for sign out. """ if auth_handler_id: - self.__sign_out(context, [auth_handler_id]) + self._sign_out(context, [auth_handler_id]) else: - self.__sign_out(context, self._auth_handlers.keys()) + self._sign_out(context, self._auth_handlers.keys()) def on_sign_in_success( self, @@ -352,7 +352,7 @@ def on_sign_in_success( Args: handler: The handler function to call on successful sign-in. """ - self.__sign_in_success_handler = handler + self._sign_in_success_handler = handler def on_sign_in_failure( self, @@ -363,4 +363,4 @@ def on_sign_in_failure( Args: handler: The handler function to call on sign-in failure. """ - self.__sign_in_failure_handler = handler \ No newline at end of file + self._sign_in_failure_handler = handler \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py index d9a4e7ec..30d35a9d 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py @@ -1,6 +1,7 @@ from typing import Optional -from microsoft.agents.hosting.core import TurnContext, Storage +from ... import TurnContext +from ...storage import Storage from .models import FlowState @@ -42,7 +43,7 @@ def __init__( channel_id = context.activity.channel_id user_id = context.activity.from_property.id - self.__base_key = f"auth/{channel_id}/{user_id}" + self.__base_key = f"auth/{channel_id}/{user_id}/" self.__storage = storage @property @@ -51,7 +52,7 @@ def base_key(self) -> str: def key(self, id: str) -> str: """Creates a storage key for a specific sign-in handler.""" - return f"{self.__base_key}/${id}" + return f"{self.__base_key}{id}" async def read(self, auth_handler_id: str) -> Optional[FlowState]: """Reads the flow state for a specific authentication handler.""" @@ -61,7 +62,7 @@ async def read(self, auth_handler_id: str) -> Optional[FlowState]: async def write(self, value: FlowState) -> None: """Saves the flow state for a specific authentication handler.""" - key: str = self.key(value.id) + key: str = self.key(value.auth_handler_id) await self.__storage.write({key: value}) async def delete(self, auth_handler_id: str) -> None: diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py index d1a1ac2d..81bf13f1 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py @@ -1,12 +1,12 @@ -import datetime +from datetime import datetime from enum import Enum from typing import Optional from pydantic import BaseModel from pydantic.types import PositiveInt -from microsoft.agents.activity import Activity -from microsoft.agents.hosting.core import StoreItem +from microsoft.agents.activity import Activity, SignInResource, TokenResponse +from microsoft.agents.hosting.core.storage import StoreItem class FlowStateTag(Enum): BEGIN = "begin" @@ -22,17 +22,17 @@ class FlowErrorTag(Enum): class FlowState(BaseModel, StoreItem): - flow_id: NonEmptyString + auth_handler_id: str = "" flow_started: bool = False user_token: str = "" - expires: float = 0 + expires_at: float = 0 abs_oauth_connection_name: Optional[str] = None continuation_activity: Optional[Activity] = None - attempts_remaining: PositiveInt = 3 - tag: FlowStateTag = FlowStateTag.INACTIVE + attempts_remaining: int = 3 + tag: FlowStateTag = FlowStateTag.NOT_STARTED def refresh(self) -> None: - if self.is_expired() or self.reached_max_retries(): + if self.is_expired() or self.reached_max_attempts(): self.tag = FlowStateTag.FAILURE def store_item_to_json(self) -> dict: @@ -43,17 +43,17 @@ def from_json_to_store_item(json_data: dict) -> "FlowState": return FlowState.model_validate(json_data) def is_expired(self) -> bool: - return datetime.now().timestamp() >= self.flow_expires - + return datetime.now().timestamp() >= self.expires_at + def reached_max_attempts(self) -> bool: return self.attempts_remaining <= 0 def is_active(self) -> bool: - return not self.is_expired() and not self.reached_max_retries() and self.tag in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE] + return not self.is_expired() and not self.reached_max_attempts() and self.tag in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE] class FlowResponse(BaseModel): - flow_data: FlowData + flow_state: FlowState = FlowState() flow_error_tag: FlowErrorTag = FlowErrorTag.NONE token_response: Optional[TokenResponse] = None sign_in_resource: Optional[SignInResource] = None diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/authorization_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/authorization_test.py index 11555e57..110454bc 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/authorization_test.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/authorization_test.py @@ -1,7 +1,7 @@ # import datetime # import pytest -# from pytest import lazy_fixture +# from pytest_lazyfixture import lazy_fixture # from microsoft.agents.activity import ( # TokenResponse diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/flow_storage_client_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/flow_storage_client_test.py index 7cfd8cb7..dd1d5fac 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/flow_storage_client_test.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/flow_storage_client_test.py @@ -1,11 +1,11 @@ import pytest -from pytest import lazy_fixture +from unittest.mock import sentinel -from microsoft.agents.hosting.core import ( - MemoryStorage +from microsoft.agents.hosting.core.storage import MemoryStorage +from microsoft.agents.hosting.core.storage.storage_test_utils import MockStoreItem +from microsoft.agents.hosting.core.app.oauth import ( + FlowState, FlowStorageClient, - MockStoreItem, - FlowState ) class TestFlowStorageClient: @@ -27,13 +27,12 @@ def client(self, turn_context, storage): @pytest.mark.asyncio @pytest.mark.parametrize( - "mocker, channel_id, from_property_id", + "channel_id, from_property_id", [ - ("mocker", "channel_id", "from_property_id"), - ("mocker", "teams_id", "Bob"), - ("mocker", "channel", "Alice"), + ("channel_id", "from_property_id"), + ("teams_id", "Bob"), + ("channel", "Alice"), ], - indirect=["mocker", "turn_context", "storage"] ) async def test_init_base_key(self, mocker, channel_id, from_property_id): context = mocker.Mock() @@ -42,23 +41,27 @@ async def test_init_base_key(self, mocker, channel_id, from_property_id): client = FlowStorageClient(context, mocker.Mock()) assert client.base_key == f"auth/{channel_id}/{from_property_id}/" + @pytest.mark.asyncio async def test_init_fails_without_from_id(self, mocker, storage): with pytest.raises(ValueError): context = mocker.Mock() context.activity.channel_id = "channel_id" + context.activity.from_property = mocker.Mock(id=None) FlowStorageClient(context, storage) + @pytest.mark.asyncio async def test_init_fails_without_channel_id(self, mocker, storage): with pytest.raises(ValueError): context = mocker.Mock() + context.activity.channel_id = None context.activity.from_property.id = "from_id" FlowStorageClient(context, storage) @pytest.mark.parametrize( - "client, auth_handler_id, expected", + "auth_handler_id, expected", [ - (client, "handler", "auth/__channel_id/__user_id/handler"), - (client, "auth_handler", "auth/__channel_id/__user_id/auth_handler"), + ("handler", "auth/__channel_id/__user_id/handler"), + ("auth_handler", "auth/__channel_id/__user_id/auth_handler"), ] ) def test_key(self, client, auth_handler_id, expected): @@ -66,56 +69,58 @@ def test_key(self, client, auth_handler_id, expected): @pytest.mark.asyncio @pytest.mark.parametrize( - "mocker, turn_context, auth_handler_id", + "auth_handler_id", [ - (mocker, turn_context, "handler"), - (mocker, turn_context, "auth_handler"), + ("handler",), + ("auth_handler",), ] ) async def test_read(self, mocker, turn_context, auth_handler_id): storage = mocker.AsyncMock() - storage.read.return_value = sentinel.read_response + key = f"auth/__channel_id/__user_id/{auth_handler_id}" + storage.read.return_value = {key: FlowState()} client = FlowStorageClient(turn_context, storage) res = await client.read(auth_handler_id) - assert res == storage.read.return_value - assert storage.read.called_once_with([f"auth/__channel_id/__user_id/{auth_handler_id}"], FlowState) + assert res is storage.read.return_value[key] + storage.read.assert_called_once_with([f"auth/__channel_id/__user_id/{auth_handler_id}"], FlowState) @pytest.mark.asyncio @pytest.mark.parametrize( - "mocker, turn_context, auth_handler_id", + "auth_handler_id, key", [ - (lazy_fixture("mocker"), lazy_fixture("turn_context"), "handler", "auth/__channel_id/__user_id/handler"), - (lazy_fixture("mocker"), lazy_fixture("turn_context"), "auth_handler", "auth/__channel_id/__user_id/auth_handler"), + ("handler", "auth/__channel_id/__user_id/handler"), + ("auth_handler", "auth/__channel_id/__user_id/auth_handler"), ] ) - async def test_write(self, mocker, turn_context, auth_handler_id, key, flow_state): + async def test_write(self, mocker, turn_context, auth_handler_id, key): storage = mocker.AsyncMock() storage.write.return_value = None client = FlowStorageClient(turn_context, storage) flow_state = mocker.Mock(spec=FlowState) - flow_state.id = auth_handler_id + flow_state.auth_handler_id = auth_handler_id await client.write(flow_state) - assert storage.write.called_once_with({ key: flow_state }) + storage.write.assert_called_once_with({ key: flow_state }) @pytest.mark.asyncio @pytest.mark.parametrize( - "mocker, turn_context, auth_handler_id", + "auth_handler_id, key", [ - (mocker, turn_context, "handler", "auth/__channel_id/__user_id/handler"), - (mocker, turn_context, "auth_handler", "auth/__channel_id/__user_id/auth_handler"), + ("handler", "auth/__channel_id/__user_id/handler"), + ("auth_handler", "auth/__channel_id/__user_id/auth_handler"), ] ) - async def test_delete(self, mocker, turn_context, auth_handler_id): + async def test_delete(self, mocker, turn_context, auth_handler_id, key): storage = mocker.AsyncMock() - storage.write.return_value = None + storage.delete.return_value = None client = FlowStorageClient(turn_context, storage) await client.delete(auth_handler_id) - assert storage.write.called_once_with([auth_handler_id]) + storage.delete.assert_called_once_with([client.key(auth_handler_id)]) + @pytest.mark.asyncio async def test_integration_with_memory_storage(self, turn_context): - flow_state_alpha = FlowState(flow_id="handler", flow_started=True) - flow_state_beta = FlowState(flow_id="auth_handler", flow_started=True, user_token="token") + flow_state_alpha = FlowState(auth_handler_id="handler", flow_started=True) + flow_state_beta = FlowState(auth_handler_id="auth_handler", flow_started=True, user_token="token") storage = MemoryStorage({ "some_data": MockStoreItem({"value": "test"}), @@ -144,22 +149,22 @@ async def delete_both(*args, **kwargs): client = FlowStorageClient(turn_context, storage) - new_flow_state_alpha = FlowState(flow_id="handler") - flow_state_chi = FlowState(flow_id="chi") + new_flow_state_alpha = FlowState(auth_handler_id="handler") + flow_state_chi = FlowState(auth_handler_id="chi") await client.write(new_flow_state_alpha) await client.write(flow_state_chi) - baseline.write({"auth/__channel_id/__user_id/handler": new_flow_state_alpha.copy()}) - baseline.write({"auth/__channel_id/__user_id/chi": flow_state_chi.copy()}) + await baseline.write({"auth/__channel_id/__user_id/handler": new_flow_state_alpha.copy()}) + await baseline.write({"auth/__channel_id/__user_id/chi": flow_state_chi.copy()}) - write_both({"auth/__channel_id/__user_id/handler": new_flow_state_alpha.copy()}) - write_both({"auth/__channel_id/__user_id/auth_handler": flow_state_beta.copy()}) - write_both({"other_data": MockStoreItem({"value": "more"}).copy()}) + await write_both({"auth/__channel_id/__user_id/handler": new_flow_state_alpha.copy()}) + await write_both({"auth/__channel_id/__user_id/auth_handler": flow_state_beta.copy()}) + await write_both({"other_data": MockStoreItem({"value": "more"})}) - delete_both(["some_data"]) + await delete_both(["some_data"]) - assert read_check(["auth/__channel_id/__user_id/handler"], FlowState) - assert read_check(["auth/__channel_id/__user_id/auth_handler"], FlowState) - assert read_check(["auth/__channel_id/__user_id/chi"], FlowState) - assert read_check(["other_data"], MockStoreItem) - assert read_check(["some_data"], MockStoreItem) + await read_check(["auth/__channel_id/__user_id/handler"], target_cls=FlowState) + await read_check(["auth/__channel_id/__user_id/auth_handler"], target_cls=FlowState) + await read_check(["auth/__channel_id/__user_id/chi"], target_cls=FlowState) + await read_check(["other_data"], target_cls=MockStoreItem) + await read_check(["some_data"], target_cls=MockStoreItem) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/models_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/models_test.py index 05d83780..e9a03b5f 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/models_test.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/models_test.py @@ -1,4 +1,4 @@ -import datetime +from datetime import datetime import pytest @@ -76,7 +76,7 @@ def test_is_expired(self, flow_state, expected): True), ] ) - def test_reached_max_attempts(flow_state, expected): + def test_reached_max_attempts(self, flow_state, expected): assert flow_state.reached_max_attempts() == expected @pytest.mark.parametrize( @@ -102,5 +102,5 @@ def test_reached_max_attempts(flow_state, expected): True) ] ) - def test_is_active(flow_state, expected): + def test_is_active(self, flow_state, expected): assert flow_state.is_active() == expected \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/oauth_flow_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/oauth_flow_test.py index f54e67ee..140beb0d 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/oauth_flow_test.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/oauth_flow_test.py @@ -1,10 +1,13 @@ +from datetime import datetime + import pytest from microsoft.agents.activity import ( ActivityTypes, - TokenResponse + TokenResponse, + SignInResource ) -from microsoft.agents.hosting.core import AuthFlow +from microsoft.agents.hosting.core.app.oauth.auth_flow import AuthFlow from microsoft.agents.hosting.core.app.oauth.models import ( FlowErrorTag, @@ -22,24 +25,24 @@ def turn_context(self, mocker): context.activity.from_property.id = "__user_id" return context - def test_init_no_state(self): - flow = AuthFlow() + def test_init_no_state(self, mocker, turn_context): + flow = AuthFlow(turn_context, mocker.Mock()) assert flow.flow_state == FlowState() - def test_init_with_state(self): + def test_init_with_state(self, mocker, turn_context): flow_state = FlowState( tag=FlowStateTag.CONTINUE, attempts_remaining=1, expires_at=datetime.now().timestamp() + 10000 ) - flow = AuthFlow(flow_state=flow_state) + flow = AuthFlow(turn_context, mocker.Mock(), flow_state=flow_state) assert flow.flow_state == flow_state @pytest.mark.asyncio - async def test_get_user_token(self, turn_context): + async def test_get_user_token(self, mocker, turn_context): # mock - user_token_client = pytest.Mock() - user_token_client.user_token.get_token = pytest.AsyncMock(return_value="test_token") + user_token_client = mocker.Mock() + user_token_client.user_token.get_token = mocker.AsyncMock(return_value="test_token") # test flow = AuthFlow( @@ -50,7 +53,7 @@ async def test_get_user_token(self, turn_context): # verify assert token == "test_token" - assert user_token_client.user_token.get_token.called_once_with( + user_token_client.user_token.get_token.assert_called_once_with( user_id="__user_id", connection_name="test_connection", channel_id="__channel_id", @@ -58,10 +61,10 @@ async def test_get_user_token(self, turn_context): ) @pytest.mark.asyncio - async def test_sign_out(self, turn_context): + async def test_sign_out(self, mocker, turn_context): # mock - user_token_client = pytest.Mock() - user_token_client.user_token.sign_out = pytest.AsyncMock() + user_token_client = mocker.Mock() + user_token_client.user_token.sign_out = mocker.AsyncMock() # test flow = AuthFlow( @@ -71,18 +74,17 @@ async def test_sign_out(self, turn_context): await flow.sign_out(turn_context) # verify - assert user_token_client.user_token.sign_out.called_once_with( + user_token_client.user_token.sign_out.assert_called_once_with( user_id="__user_id", connection_name="connection", - channel_id="__channel_id", - magic_code=None + channel_id="__channel_id" ) @pytest.mark.asyncio - async def test_begin_flow_easy_case(self): + async def test_begin_flow_easy_case(self, mocker, turn_context): # mock - user_token_client = pytest.Mock() - user_token_client.user_token.get_token = pytest.AsyncMock(return_value=TokenResponse(token="test_token")) + user_token_client = mocker.Mock() + user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse(token="test_token")) # test flow = AuthFlow( @@ -102,7 +104,7 @@ async def test_begin_flow_easy_case(self): assert response.sign_in_resource is None # No sign-in resource in this case assert response.flow_error_tag == FlowErrorTag.NONE assert response.token_response == "test_token" - assert user_token_client.user_token.get_token.called_once_with( + user_token_client.user_token.get_token.assert_called_once_with( user_id="__user_id", connection_name="test_connection", channel_id="__channel_id", @@ -119,8 +121,8 @@ async def test_begin_flow_long_case(self, mocker, turn_context): token_exchange_state=TokenExchangeState(connection_name="test_connection") ) user_token_client = mocker.Mock() - user_token_client.user_token.get_token = pytest.AsyncMock(return_value=TokenResponse()) - user_token_client.agent_sign_in.get_sign_in_resource = pytest.AsyncMock(return_value=dummy_sign_in_resource) + user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) + user_token_client.agent_sign_in.get_sign_in_resource = mocker.AsyncMock(return_value=dummy_sign_in_resource) # test flow = AuthFlow( @@ -144,34 +146,33 @@ async def test_begin_flow_long_case(self, mocker, turn_context): @pytest.mark.asyncio @pytest.mark.parametrize( - "mocker", "turn_context, flow_state", + "flow_state", [ - ("mocker", "turn_context", FlowState( + FlowState( tag=FlowStateTag.BEGIN, token="", - expires=datetime.now().timestamp() - 1, + expires_at=datetime.now().timestamp() - 1, attempts_remaining=3 - )), - ("mocker", "turn_context", FlowState( + ), + FlowState( tag=FlowStateTag.CONTINUE, token="", - expires=datetime.now().timestamp() + 1000, + expires_at=datetime.now().timestamp() + 1000, attempts_remaining=0 - )), - ("mocker", "turn_context", FlowState( - tag=FlowStateTag.FAILED, + ), + FlowState( + tag=FlowStateTag.FAILURE, token="", - expires=datetime.now().timestamp() + 1000, + expires_at=datetime.now().timestamp() + 1000, attempts_remaining=3 - )), - ("mocker", "turn_context", FlowState( - tag=FlowStateTag.COMPLETED, + ), + FlowState( + tag=FlowStateTag.COMPLETE, token="", - expires=datetime.now().timestamp() + 1000, + expires_at=datetime.now().timestamp() + 1000, attempts_remaining=2 - )), + ), ], - indirect=["mocker", "turn_context"] ) async def test_continue_flow_not_active(self, mocker, turn_context, flow_state): user_token_client = mocker.Mock() @@ -185,7 +186,7 @@ async def test_continue_flow_not_active(self, mocker, turn_context, flow_state): assert not flow_response.token_response @pytest.fixture(params=[ - (FlowStateTag.ACTIVE, "test_token", 2), + (FlowStateTag.CONTINUE, "test_token", 2), (FlowStateTag.BEGIN, "", 1), ]) def active_flow_state(self, request): @@ -193,7 +194,7 @@ def active_flow_state(self, request): return FlowState( tag=tag, token=token, - expires=datetime.now().timestamp() + 1000, + expires_at=datetime.now().timestamp() + 1000, attempts_remaining=attempts_remaining ) @@ -202,8 +203,8 @@ async def test_continue_flow_message(self, mocker, turn_context, active_flow_sta turn_context.activity.type = ActivityTypes.message turn_context.activity.text = "magic-message" user_token_client = mocker.Mock() - user_token_client.user_token.get_token = pytest.AsyncMock(return_value=TokenResponse()) - user_token_client.agent_sign_in.get_sign_in_resource = pytest.AsyncMock(return_value=dummy_sign_in_resource) + user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) + user_token_client.agent_sign_in.get_sign_in_resource = mocker.AsyncMock(return_value=dummy_sign_in_resource) # test flow = AuthFlow( @@ -214,16 +215,8 @@ async def test_continue_flow_message(self, mocker, turn_context, active_flow_sta @pytest.mark.asyncio @pytest.mark.parametrize( - "mocker, turn_context, active_flow_state, magic_code", - [ - ("mocker", "turn_context", "active_flow_state", "magic-message"), - ("mocker", "turn_context", "active_flow_state", ""), - ("mocker", "turn_context", "active_flow_state", "abcdef"), - ("mocker", "turn_context", "active_flow_state", "@#0324"), - ("mocker", "turn_context", "active_flow_state", "231"), - ("mocker", "turn_context", "active_flow_state", None), - ], - indirect=["mocker", "turn_context", "active_flow_state"] + "magic_code", + ["magic-message", "", "abcdef", "@#0324", "231", None] ) async def test_continue_flow_message_format_error(self, mocker, turn_context, active_flow_state, magic_code): # mock @@ -250,7 +243,7 @@ async def test_continue_flow_message_magic_code_error(self, mocker, turn_context turn_context.activity.type = ActivityTypes.message turn_context.activity.text = "123456" user_token_client = mocker.Mock() - user_token_client.user_token.get_token = pytest.AsyncMock(return_value=TokenResponse()) + user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) # test flow = AuthFlow( @@ -264,7 +257,7 @@ async def test_continue_flow_message_magic_code_error(self, mocker, turn_context assert active_flow_state.attempts_remaining - 1 == flow_response.flow_state.attempts_remaining assert not flow_response.token_response assert flow_response.flow_error_tag == FlowErrorTag.MAGIC_CODE - assert user_token_client.user_token.get_token.called_once_with( + user_token_client.user_token.get_token.assert_called_once_with( user_id="__user_id", connection_name="test_connection", channel_id="__channel_id", @@ -278,7 +271,7 @@ async def test_continue_flow_invoke_verify_state(self, mocker, turn_context, act turn_context.activity.name = "signin/verifyState" turn_context.activity.value = {"state": "987654"} user_token_client = mocker.Mock() - user_token_client.user_token.get_token = pytest.AsyncMock(return_value=TokenResponse(token="some-token")) + user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse(token="some-token")) # test flow = AuthFlow( @@ -293,20 +286,21 @@ async def test_continue_flow_invoke_verify_state(self, mocker, turn_context, act assert flow_response.token_response.token == "some-token" assert flow_response.flow_state.tag == FlowStateTag.COMPLETE assert flow_response.flow_error_tag == FlowErrorTag.NONE - assert user_token_client.user_token.get_token.called_once_with( + user_token_client.user_token.get_token.assert_called_once_with( user_id="__user_id", connection_name="test_connection", channel_id="__channel_id", magic_code="987654" ) + @pytest.mark.asyncio async def test_continue_flow_invoke_verify_state_no_token(self, mocker, turn_context, active_flow_state): # mock turn_context.activity.type = ActivityTypes.message turn_context.activity.name = "signin/verifyState" turn_context.activity.value = {"state": "987654"} user_token_client = mocker.Mock() - user_token_client.user_token.get_token = pytest.AsyncMock(return_value=TokenResponse()) + user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) # test flow = AuthFlow( @@ -324,7 +318,7 @@ async def test_continue_flow_invoke_verify_state_no_token(self, mocker, turn_con else: assert flow_response.flow_state.tag == FlowStateTag.CONTINUE assert flow_response.flow_error_tag == FlowErrorTag.UNKNOWN - assert user_token_client.user_token.get_token.called_once_with( + user_token_client.user_token.get_token.assert_called_once_with( user_id="__user_id", connection_name="test_connection", channel_id="__channel_id", @@ -338,7 +332,7 @@ async def test_continue_flow_invoke_token_exchange(self, mocker, turn_context, a turn_context.activity.name = "signin/exchangeState" turn_context.activity.value = "request_body" user_token_client = mocker.Mock() - user_token_client.user_token.exchange_token = pytest.AsyncMock(return_value=TokenResponse(token="exchange-token")) + user_token_client.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse(token="exchange-token")) # test flow = AuthFlow( @@ -353,7 +347,7 @@ async def test_continue_flow_invoke_token_exchange(self, mocker, turn_context, a assert flow_response.token_response.token == "exchange-token" assert flow_response.flow_state.tag == FlowStateTag.COMPLETE assert flow_response.flow_error_tag == FlowErrorTag.NONE - assert user_token_client.user_token.get_token.called_once_with( + user_token_client.user_token.get_token.assert_called_once_with( user_id="__user_id", connection_name="test_connection", channel_id="__channel_id", @@ -367,7 +361,7 @@ async def test_continue_flow_invoke_token_exchange_no_token(self, mocker, turn_c turn_context.activity.name = "signin/exchangeState" turn_context.activity.value = "request_body" user_token_client = mocker.Mock() - user_token_client.user_token.exchange_token = pytest.AsyncMock(return_value=TokenResponse()) + user_token_client.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse()) # test flow = AuthFlow( @@ -385,7 +379,7 @@ async def test_continue_flow_invoke_token_exchange_no_token(self, mocker, turn_c else: assert flow_response.flow_state.tag == FlowStateTag.CONTINUE assert flow_response.flow_error_tag == FlowErrorTag.UNKNOWN - assert user_token_client.user_token.get_token.called_once_with( + user_token_client.user_token.get_token.assert_called_once_with( user_id="__user_id", connection_name="test_connection", channel_id="__channel_id", From 10d5c0c554f03fc735d551716ce6ff81c7a6faf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Tue, 19 Aug 2025 17:07:10 -0700 Subject: [PATCH 09/32] Another commit --- .../hosting/core/app/oauth/auth_flow.py | 26 +++++++++------ .../core/app/oauth/flow_storage_client.py | 16 ++++----- .../agents/hosting/core/app/oauth/models.py | 3 +- .../oauth/tests/flow_storage_client_test.py | 18 +++++----- .../core/app/oauth/tests/oauth_flow_test.py | 33 ++++++++++++++----- 5 files changed, 60 insertions(+), 36 deletions(-) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py index 92d3733a..981b04a0 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py @@ -64,8 +64,11 @@ def __init__( "OAuthFlow.__init__: user_token_client required." ) - flow_state = flow_state or FlowState() # robrandao: TODO - self.flow_state = flow_state.copy() + if not flow_state: + self.flow_state = FlowState() + else: + self.flow_state = flow_state.model_copy() + self.__abs_oauth_connection_name = abs_oauth_connection_name self.__user_token_client = user_token_client @@ -112,12 +115,6 @@ async def __use_attempt(self) -> None: async def begin_flow(self, context: TurnContext) -> FlowResponse: - self.flow_state = FlowState( - id=self.__abs_oauth_connection_name, - channel_id=context.activity.channel_id, - user_id=context.activity.from_property.id - ) - # init flow state token_response = await self.get_user_token(context) @@ -126,6 +123,15 @@ async def begin_flow(self, context: TurnContext) -> FlowResponse: flow_state=self.flow_state, token_response=token_response ) + + self.flow_state = FlowState( + flow_id=self.__abs_oauth_connection_name, + channel_id=context.activity.channel_id, + user_id=context.activity.from_property.id, + abs_oauth_connection_name=self.__abs_oauth_connection_name, + tag=FlowStateTag.BEGIN, + expires_at=datetime.now().timestamp() + 60000, # 60 seconds + ) token_exchange_state = TokenExchangeState( connection_name=self.__abs_oauth_connection_name, @@ -185,10 +191,10 @@ async def continue_flow(self, context: TurnContext) -> FlowResponse: elif continue_flow_activity.flow_error_tag == ActivityTypes.invoke and continue_flow_activity.name == "signin/tokenExchange": token_response = await self.__continue_from_invoke_token_exchange(context) else: - pass + raise ValueError("Unknown activity type") if not token_response and flow_error_tag == FlowErrorTag.NONE: - flow_error_tag = FlowErrorTag.UNKNOWN + flow_error_tag = FlowErrorTag.OTHER if flow_error_tag != FlowErrorTag.NONE: self.__use_attempt() diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py index 30d35a9d..27147e0e 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py @@ -50,22 +50,22 @@ def __init__( def base_key(self) -> str: return self.__base_key - def key(self, id: str) -> str: + def key(self, flow_id: str) -> str: """Creates a storage key for a specific sign-in handler.""" - return f"{self.__base_key}{id}" + return f"{self.__base_key}{flow_id}" - async def read(self, auth_handler_id: str) -> Optional[FlowState]: + async def read(self, flow_id: str) -> Optional[FlowState]: """Reads the flow state for a specific authentication handler.""" - key: str = self.key(auth_handler_id) + key: str = self.key(flow_id) data = await self.__storage.read([key], FlowState) - return FlowState.validate(data.get(key)) # robrandao: TODO -> verify contract + return FlowState.model_validate(data.get(key)) # robrandao: TODO -> verify contract async def write(self, value: FlowState) -> None: """Saves the flow state for a specific authentication handler.""" - key: str = self.key(value.auth_handler_id) + key: str = self.key(value.flow_id) await self.__storage.write({key: value}) - async def delete(self, auth_handler_id: str) -> None: + async def delete(self, flow_id: str) -> None: """Deletes the flow state for a specific authentication handler.""" - key: str = self.key(auth_handler_id) + key: str = self.key(flow_id) await self.__storage.delete([key]) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py index 81bf13f1..9b6acfc0 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py @@ -19,10 +19,11 @@ class FlowErrorTag(Enum): NONE = "none" MAGIC_FORMAT = "magic_format" MAGIC_CODE = "magic_code" + OTHER = "OTHER" class FlowState(BaseModel, StoreItem): - auth_handler_id: str = "" + flow_id: str = "" flow_started: bool = False user_token: str = "" expires_at: float = 0 diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/flow_storage_client_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/flow_storage_client_test.py index dd1d5fac..0c2a593b 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/flow_storage_client_test.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/flow_storage_client_test.py @@ -97,7 +97,7 @@ async def test_write(self, mocker, turn_context, auth_handler_id, key): storage.write.return_value = None client = FlowStorageClient(turn_context, storage) flow_state = mocker.Mock(spec=FlowState) - flow_state.auth_handler_id = auth_handler_id + flow_state.flow_id = auth_handler_id await client.write(flow_state) storage.write.assert_called_once_with({ key: flow_state }) @@ -119,8 +119,8 @@ async def test_delete(self, mocker, turn_context, auth_handler_id, key): @pytest.mark.asyncio async def test_integration_with_memory_storage(self, turn_context): - flow_state_alpha = FlowState(auth_handler_id="handler", flow_started=True) - flow_state_beta = FlowState(auth_handler_id="auth_handler", flow_started=True, user_token="token") + flow_state_alpha = FlowState(flow_id="handler", flow_started=True) + flow_state_beta = FlowState(flow_id="auth_handler", flow_started=True, user_token="token") storage = MemoryStorage({ "some_data": MockStoreItem({"value": "test"}), @@ -149,16 +149,16 @@ async def delete_both(*args, **kwargs): client = FlowStorageClient(turn_context, storage) - new_flow_state_alpha = FlowState(auth_handler_id="handler") - flow_state_chi = FlowState(auth_handler_id="chi") + new_flow_state_alpha = FlowState(flow_id="handler") + flow_state_chi = FlowState(flow_id="chi") await client.write(new_flow_state_alpha) await client.write(flow_state_chi) - await baseline.write({"auth/__channel_id/__user_id/handler": new_flow_state_alpha.copy()}) - await baseline.write({"auth/__channel_id/__user_id/chi": flow_state_chi.copy()}) + await baseline.write({"auth/__channel_id/__user_id/handler": new_flow_state_alpha.model_copy()}) + await baseline.write({"auth/__channel_id/__user_id/chi": flow_state_chi.model_copy()}) - await write_both({"auth/__channel_id/__user_id/handler": new_flow_state_alpha.copy()}) - await write_both({"auth/__channel_id/__user_id/auth_handler": flow_state_beta.copy()}) + await write_both({"auth/__channel_id/__user_id/handler": new_flow_state_alpha.model_copy()}) + await write_both({"auth/__channel_id/__user_id/auth_handler": flow_state_beta.model_copy()}) await write_both({"other_data": MockStoreItem({"value": "more"})}) await delete_both(["some_data"]) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/oauth_flow_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/oauth_flow_test.py index 140beb0d..0dc2af22 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/oauth_flow_test.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/oauth_flow_test.py @@ -5,7 +5,9 @@ from microsoft.agents.activity import ( ActivityTypes, TokenResponse, - SignInResource + SignInResource, + TokenExchangeState, + ConversationReference ) from microsoft.agents.hosting.core.app.oauth.auth_flow import AuthFlow @@ -23,6 +25,15 @@ def turn_context(self, mocker): context = mocker.Mock() context.activity.channel_id = "__channel_id" context.activity.from_property.id = "__user_id" + context.adapter.AGENT_IDENTITY_KEY = "__agent_id" + context.activity.relates_to = None + context.activity.get_conversation_reference = mocker.Mock() + context.activity.get_conversation_reference.return_value = mocker.Mock(spec=ConversationReference) + data = mocker.Mock() + data.claims = {"aud": "__app_id"} + context.turn_state = { + "__agent_id": data + } return context def test_init_no_state(self, mocker, turn_context): @@ -90,20 +101,26 @@ async def test_begin_flow_easy_case(self, mocker, turn_context): flow = AuthFlow( abs_oauth_connection_name="test_connection", user_token_client=user_token_client, + flow_state=FlowState( + tag=FlowStateTag.COMPLETE, + user_token="test_token", # robrandao: TODO -> what are all these fields for? + expires_at=datetime.now().timestamp() + 10000, + attempts_remaining=2 + ) ) response = await flow.begin_flow(turn_context) # verify flow_state flow_state = flow.flow_state assert flow_state.tag == FlowStateTag.COMPLETE - assert flow_state.token == "test_token" - assert flow_state.flow_started is False # robrandao: TODO? + assert flow_state.user_token == "test_token" + # assert flow_state.flow_started is False # robrandao: TODO? # verify FlowResponse assert response.flow_state == flow_state assert response.sign_in_resource is None # No sign-in resource in this case assert response.flow_error_tag == FlowErrorTag.NONE - assert response.token_response == "test_token" + assert response.token_response.token == "test_token" user_token_client.user_token.get_token.assert_called_once_with( user_id="__user_id", connection_name="test_connection", @@ -204,7 +221,7 @@ async def test_continue_flow_message(self, mocker, turn_context, active_flow_sta turn_context.activity.text = "magic-message" user_token_client = mocker.Mock() user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) - user_token_client.agent_sign_in.get_sign_in_resource = mocker.AsyncMock(return_value=dummy_sign_in_resource) + user_token_client.agent_sign_in.get_sign_in_resource = mocker.AsyncMock(return_value=None) # test flow = AuthFlow( @@ -316,8 +333,8 @@ async def test_continue_flow_invoke_verify_state_no_token(self, mocker, turn_con if active_flow_state.attempts_remaining == 1: assert flow_response.flow_state.tag == FlowStateTag.FAILURE else: - assert flow_response.flow_state.tag == FlowStateTag.CONTINUE - assert flow_response.flow_error_tag == FlowErrorTag.UNKNOWN + assert flow_response.flow_state.tag == FlowStateTag.OTHER + assert flow_response.flow_error_tag == FlowErrorTag.OTHER user_token_client.user_token.get_token.assert_called_once_with( user_id="__user_id", connection_name="test_connection", @@ -378,7 +395,7 @@ async def test_continue_flow_invoke_token_exchange_no_token(self, mocker, turn_c assert flow_response.flow_state.tag == FlowStateTag.FAILURE else: assert flow_response.flow_state.tag == FlowStateTag.CONTINUE - assert flow_response.flow_error_tag == FlowErrorTag.UNKNOWN + assert flow_response.flow_error_tag == FlowErrorTag.OTHER user_token_client.user_token.get_token.assert_called_once_with( user_id="__user_id", connection_name="test_connection", From 6db92176ee1abfb3b8f3b6893e72a40ad45c4c5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Wed, 20 Aug 2025 15:44:01 -0700 Subject: [PATCH 10/32] Finishing flow storage client and auth flow tests --- .../hosting/core/app/agent_application.py | 25 +- .../hosting/core/app/oauth/auth_flow.py | 192 +++--- .../hosting/core/app/oauth/authorization.py | 88 ++- .../core/app/oauth/flow_storage_client.py | 16 +- .../agents/hosting/core/app/oauth/models.py | 10 +- .../app/oauth/tests/__authorization_test.py | 340 ++++++++++ .../app/oauth/tests/authorization_test.py | 331 ---------- .../hosting/core/app/oauth/tests/conftest.py | 0 .../core/app/oauth/tests/oauth_flow_test.py | 408 ------------ .../agents/hosting/core/app/oauth/utils.py | 10 + .../tests/flow_storage_client_test.py | 0 .../core/app/oauth => }/tests/models_test.py | 0 .../tests/test_auth_flow.py | 402 ++++++++++++ .../tests/test_authorization.py | 591 ++++++++++++------ .../tests/tools/mock_user_token_client.py | 89 +++ .../tests/tools/oauth_test_env.py | 142 +++++ .../tests/tools/testing_adapter.py | 61 +- .../tests/tools/testing_authorization.py | 1 + 18 files changed, 1587 insertions(+), 1119 deletions(-) create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/__authorization_test.py delete mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/authorization_test.py create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/conftest.py delete mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/oauth_flow_test.py create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/utils.py rename libraries/microsoft-agents-hosting-core/{microsoft/agents/hosting/core/app/oauth => }/tests/flow_storage_client_test.py (100%) rename libraries/microsoft-agents-hosting-core/{microsoft/agents/hosting/core/app/oauth => }/tests/models_test.py (100%) create mode 100644 libraries/microsoft-agents-hosting-core/tests/test_auth_flow.py create mode 100644 libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py create mode 100644 libraries/microsoft-agents-hosting-core/tests/tools/oauth_test_env.py diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py index d7f77d0c..6dcf6449 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py @@ -33,8 +33,12 @@ MessageReactionTypes, MessageUpdateTypes, InvokeResponse, + OAuthCard, + Attachment, + CardAction ) +from .. import CardFactory, MessageFactory from .app_error import ApplicationError from .app_options import ApplicationOptions @@ -42,7 +46,14 @@ from .route import Route, RouteHandler from .state import TurnState from ..channel_service_adapter import ChannelServiceAdapter -from .oauth import Authorization, SignInState, FlowResponse, FlowStateTag +from .oauth import ( + Authorization, + FlowResponse, + FlowState, + FlowStateTag, + FlowErrorTag +) +from .typing_indicator import TypingIndicator logger = logging.getLogger(__name__) @@ -597,7 +608,7 @@ async def _handle_flow_response(self, context: TurnContext, flow_response: FlowR in_flow_activity = flow_response.in_flow_activity if in_flow_activity: - context.send_activity(in_flow_activity) + await context.send_activity(in_flow_activity) if flow_state.tag == FlowStateTag.BEGIN: # Create the OAuth card @@ -642,9 +653,9 @@ async def _handle_flow_response(self, context: TurnContext, flow_response: FlowR logger.warning("Sign-in flow failed for unknown reasons.") await context.send_activity("Sign-in failed. Please try again.") - async def _on_turn_auth_intercept(self, context: TurnContext, turn_state) -> bool: + async def _on_turn_auth_intercept(self, context: TurnContext, turn_state: TurnState) -> bool: - prev_flow_state = self._auth.get_active_flow_state(context) + prev_flow_state = await self._auth.get_active_flow_state(context) if self._auth and prev_flow_state: logger.debug("Sign-in flow is active for context: %s", context.activity.id) @@ -653,7 +664,7 @@ async def _on_turn_auth_intercept(self, context: TurnContext, turn_state) -> boo context, turn_state, prev_flow_state.handler_id ) - self._handle_flow_response(flow_response) + await self._handle_flow_response(context, flow_response) new_flow_state: FlowState = flow_response.flow_state token_response: TokenResponse = new_flow_state.token_response @@ -688,7 +699,7 @@ async def _on_turn(self, context: TurnContext): logger.debug("Initializing turn state") turn_state = await self._initialize_state(context) - if self._on_turn_auth_intercept(context): + if await self._on_turn_auth_intercept(context, turn_state): return logger.debug("Running before turn middleware") @@ -813,7 +824,7 @@ async def _on_activity(self, context: TurnContext, state: StateT): flow_response: FlowResponse = await self._auth.begin_or_continue_flow( context, state, auth_handler_id ) - self._handle_flow_response(context, flow_response.in_flow_activity) + await self._handle_flow_response(context, flow_response.in_flow_activity) sign_in_complete = flow_response.flow_state.tag == FlowStateTag.COMPLETE if not sign_in_complete: break diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py index 981b04a0..3115b384 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py @@ -11,21 +11,19 @@ from microsoft.agents.hosting.core.connector.client import UserTokenClient from microsoft.agents.activity import ( + Activity, ActivityTypes, TokenExchangeState, TokenResponse, ) -from microsoft.agents.activity import ( - TurnContextProtocol as TurnContext, -) from microsoft.agents.hosting.core.storage import StoreItem, Storage from pydantic import BaseModel, PositiveInt from .models import FlowResponse, FlowState, FlowStateTag, FlowErrorTag +from .utils import raise_if_empty_or_None logger = logging.getLogger(__name__) - class AuthFlow: """ Manages the OAuth flow. @@ -42,9 +40,8 @@ class AuthFlow: def __init__( self, - abs_oauth_connection_name: str, + flow_state: FlowState, user_token_client: UserTokenClient, - flow_state: FlowState = None, **kwargs ): """ @@ -55,141 +52,150 @@ def __init__( flow_state: """ - if not abs_oauth_connection_name: - raise ValueError( - "OAuthFlow.__init__: abs_oauth_connection_name required." - ) - if not user_token_client: - raise ValueError( - "OAuthFlow.__init__: user_token_client required." - ) + raise_if_empty_or_None( + self.__init__.__name__, + flow_state=flow_state, + user_token_client=user_token_client + ) + + if (not flow_state.abs_oauth_connection_name or + not flow_state.ms_app_id or + not flow_state.channel_id or + not flow_state.user_id): + raise ValueError("OAuthFlow.__init__: flow_state must have ms_app_id, channel_id, user_id, abs_oauth_connection_name defined") - if not flow_state: - self.flow_state = FlowState() - else: - self.flow_state = flow_state.model_copy() + self.__flow_state = flow_state.model_copy() + self.__abs_oauth_connection_name = self.__flow_state.abs_oauth_connection_name + self.__ms_app_id = self.__flow_state.ms_app_id + self.__channel_id = self.__flow_state.channel_id + self.__user_id = self.__flow_state.user_id - self.__abs_oauth_connection_name = abs_oauth_connection_name self.__user_token_client = user_token_client + self.__flow_duration = kwargs.get("flow_duration", 60000) # defaults to 60 seconds + self.__max_attempts = kwargs.get("max_attempts", 3) # defaults to 3 max attempts + + @property + def flow_state(self) -> FlowState: + return self.__flow_state.model_copy() + # async def __initialize_token_client(self, context: TurnContext) -> None: # # robrandao: TODO is this safe # # use cached value later # self.__user_token_client = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) - - def __get_ids_or_raise(self, context: TurnContext) -> TokenResponse: - if ( - not context.activity.channel_id or - not context.activity.from_property or - not context.activity.from_property.id - ): - raise ValueError("User ID or Channel ID is not set in the activity.") + + async def get_user_token(self, magic_code: str = None) -> TokenResponse: + """Get the user token based on the context. - return context.activity.channel_id, context.activity.from_property.id + Args: + magic_code (str, optional): Defaults to None. The magic code for user authentication. + + Returns: + TokenResponse + The user token response. - async def __get_user_token(self, context: TurnContext, magic_code=None) -> TokenResponse: - channel_id, from_id = self.__get_ids_or_raise(context) + Notes + ----- + flow_state.user_token is updated with the latest token. - return await self.__user_token_client.user_token.get_token( - user_id=from_id, + """ + token_response: TokenResponse = await self.__user_token_client.user_token.get_token( + user_id=self.__user_id, connection_name=self.__abs_oauth_connection_name, - channel_id=channel_id, + channel_id=self.__channel_id, magic_code=magic_code ) + if token_response: + self.__flow_state.user_token = token_response.token + return token_response - async def get_user_token(self, context: TurnContext) -> TokenResponse: - return await self.__get_user_token(context) - - async def sign_out(self, context: TurnContext) -> None: - channel_id, from_id = self.__get_ids_or_raise(context) - - return await self.__user_token_client.user_token.sign_out( - user_id=from_id, + async def sign_out(self) -> None: + """Sign out the user.""" + await self.__user_token_client.user_token.sign_out( + user_id=self.__user_id, connection_name=self.__abs_oauth_connection_name, - channel_id=channel_id + channel_id=self.__channel_id ) + self.__flow_state.user_token = "" + self.__flow_state.tag = FlowStateTag.NOT_STARTED - async def __use_attempt(self) -> None: - if self.flow_state.attempts_remaining <= 0: - self.flow_state.flow_state_tag = FlowStateTag.FAILURE + def __use_attempt(self) -> None: + self.__flow_state.attempts_remaining -= 1 + if self.__flow_state.attempts_remaining <= 0: + self.__flow_state.tag = FlowStateTag.FAILURE - async def begin_flow(self, context: TurnContext) -> FlowResponse: + async def begin_flow(self, activity: Activity) -> FlowResponse: # init flow state - token_response = await self.get_user_token(context) + token_response = await self.get_user_token() if token_response: return FlowResponse( - flow_state=self.flow_state, + flow_state=self.__flow_state, token_response=token_response ) - self.flow_state = FlowState( - flow_id=self.__abs_oauth_connection_name, - channel_id=context.activity.channel_id, - user_id=context.activity.from_property.id, - abs_oauth_connection_name=self.__abs_oauth_connection_name, - tag=FlowStateTag.BEGIN, - expires_at=datetime.now().timestamp() + 60000, # 60 seconds - ) + self.__flow_state.tag = FlowStateTag.BEGIN + self.__flow_state.expires_at = datetime.now().timestamp() + self.__flow_duration + self.__flow_state.attempts_remaining = self.__max_attempts + self.__flow_state.user_token = "" token_exchange_state = TokenExchangeState( connection_name=self.__abs_oauth_connection_name, - conversation=context.activity.get_conversation_reference(), - relates_to=context.activity.relates_to, - ms_app_id=context.turn_state.get(context.adapter.AGENT_IDENTITY_KEY).claims["aud"] # robrandao: TODO + conversation=activity.get_conversation_reference(), + relates_to=activity.relates_to, + ms_app_id=self.__ms_app_id # robrandao: TODO ) sign_in_resource = await self.__user_token_client.agent_sign_in.get_sign_in_resource( state=token_exchange_state.get_encoded_state()) - return FlowResponse(flow_state=self.flow_state, sign_in_resource=sign_in_resource) + return FlowResponse(flow_state=self.__flow_state, sign_in_resource=sign_in_resource) - async def __continue_from_message(self, context: TurnContext) -> tuple[TokenResponse, FlowErrorTag]: - magic_code: str = context.activity.text + async def __continue_from_message(self, activity: Activity) -> tuple[TokenResponse, FlowErrorTag]: + magic_code: str = activity.text if magic_code and magic_code.isdigit() and len(magic_code) == 6: - token_response: TokenResponse = await self.__get_user_token(context, magic_code) + token_response: TokenResponse = await self.get_user_token(magic_code) if token_response: return token_response, FlowErrorTag.NONE else: - return token_response, FlowErrorTag.MAGIC_CODE + return token_response, FlowErrorTag.MAGIC_CODE_INCORRECT else: return TokenResponse(), FlowErrorTag.MAGIC_FORMAT - async def __continue_from_invoke_verify_state(self, context: TurnContext) -> TokenResponse: - token_verify_state = context.activity.value + async def __continue_from_invoke_verify_state(self, activity: Activity) -> TokenResponse: + token_verify_state = activity.value magic_code: str = token_verify_state.get("state") - token_response: TokenResponse = await self.__get_user_token(context, magic_code) + token_response: TokenResponse = await self.get_user_token(magic_code) return token_response - async def __continue_from_invoke_token_exchange(self, context: TurnContext) -> TokenResponse: - channel_id, from_id = self.__get_ids_or_raise(context) - token_exchange_request = context.activity.value + async def __continue_from_invoke_token_exchange(self, activity: Activity) -> TokenResponse: + token_exchange_request = activity.value token_response = await self.__user_token_client.user_token.exchange_token( - user_id=from_id, + user_id=self.__user_id, connection_name=self.__abs_oauth_connection_name, - channel_id=channel_id, + channel_id=self.__channel_id, body=token_exchange_request ) - return token_response, FlowErrorTag.NONE + return token_response - async def continue_flow(self, context: TurnContext) -> FlowResponse: + async def continue_flow(self, activity: Activity) -> FlowResponse: logger.debug("Continuing auth flow...") - if not self.flow_state.is_active(): - self.flow_state.flow_state_tag = FlowStateTag.FAILURE - return FlowResponse(flow_state=self.flow_state) - - continue_flow_activity = context.activity + if not self.__flow_state.is_active(): + self.__flow_state.tag = FlowStateTag.FAILURE + return FlowResponse(flow_state=self.__flow_state) flow_error_tag = FlowErrorTag.NONE - if continue_flow_activity.type == ActivityTypes.message: - token_response, flow_error_tag = await self.__continue_from_message(context) - elif continue_flow_activity.type == ActivityTypes.invoke and continue_flow_activity.name == "signin/verifyState": - token_response = await self.__continue_from_invoke_verify_state(context) - elif continue_flow_activity.flow_error_tag == ActivityTypes.invoke and continue_flow_activity.name == "signin/tokenExchange": - token_response = await self.__continue_from_invoke_token_exchange(context) + if activity.type == ActivityTypes.message: + token_response, flow_error_tag = await self.__continue_from_message(activity) + elif (activity.type == ActivityTypes.invoke + and activity.name == "signin/verifyState"): + token_response = await self.__continue_from_invoke_verify_state(activity) + elif (activity.type == ActivityTypes.invoke + and activity.name == "signin/tokenExchange"): + token_response = await self.__continue_from_invoke_token_exchange(activity) else: raise ValueError("Unknown activity type") @@ -197,16 +203,22 @@ async def continue_flow(self, context: TurnContext) -> FlowResponse: flow_error_tag = FlowErrorTag.OTHER if flow_error_tag != FlowErrorTag.NONE: + self.__flow_state.tag = FlowStateTag.CONTINUE self.__use_attempt() + else: + self.__flow_state.tag = FlowStateTag.COMPLETE + self.__flow_state.expires_at = datetime.now().timestamp() + self.__flow_duration + self.__flow_state.user_token = token_response.token + return FlowResponse( - flow_state=self.flow_state, + flow_state=self.__flow_state.model_copy(), flow_error_tag=flow_error_tag, token_response=token_response ) - async def begin_or_continue_flow(self, context: TurnContext) -> FlowResponse: - if self.flow_state.is_active(): - return await self.continue_flow(context) + async def begin_or_continue_flow(self, activity: Activity) -> FlowResponse: + if self.__flow_state.is_active(): + return await self.continue_flow(activity) else: - return await self.begin_flow(context) \ No newline at end of file + return await self.begin_flow(activity) \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index e8483c78..0612bdb1 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -62,14 +62,14 @@ def __init__( # user_state = UserState(storage) - self._storage = storage - self._connection_manager = connection_manager + self.__storage = storage + self.__connection_manager = connection_manager auth_configuration: Dict = kwargs.get("AGENTAPPLICATION", {}).get( "USERAUTHORIZATION", {} ) - # self._auto_signin = ( + # self.__auto_signin = ( # auto_signin # if auto_signin is not None # else auth_configuration.get("AUTOSIGNIN", False) @@ -84,16 +84,16 @@ def __init__( for handler_name, config in handlers_config.items() } - self._auth_handlers = auth_handlers or {} - self._sign_in_handler: Optional[ + self.__auth_handlers = auth_handlers or {} + self.__sign_in_handler: Optional[ Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] ] = None - self._sign_in_failed_handler: Optional[ + self.__sign_in_failed_handler: Optional[ Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] ] = None # # Configure each auth handler - # for auth_handler in self._auth_handlers.values(): + # for auth_handler in self.__auth_handlers.values(): # # Create OAuth flow with configuration # messages_config = {} # if auth_handler.title: @@ -108,14 +108,34 @@ def __init__( # messages_configuration=messages_config if messages_config else None, # ) - async def _load_flow(self, context: TurnContext, auth_handler_id: str) -> AuthFlow: + def __check_for_ids(self, context: TurnContext): + if ( + not context.activity.channel_id or + not context.activity.from_property or + not context.activity.from_property.id + ): + raise ValueError("Channel ID and User ID are required") + + async def __load_flow(self, context: TurnContext, auth_handler_id: str) -> tuple[AuthFlow, FlowStorageClient, FlowState]: user_token_client: UserTokenClient = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) # robrandao: TODO auth_handler: AuthHandler = self.resolve_handler(auth_handler_id) + + self.__check_for_ids(context) + channel_id = context.activity.channel_id + user_id = context.activity.from_property.id - flow_storage_client = FlowStorageClient(context, self._storage) + flow_storage_client = FlowStorageClient(channel_id, user_id, self.__storage) flow_state: FlowState = await flow_storage_client.read(auth_handler_id) - flow = AuthFlow(auth_handler.abs_oauth_connection_name, user_token_client, flow_state) + if not flow_state: + flow_state = FlowState( + channel_id=channel_id, + user_id=user_id, + auth_handler_id=auth_handler_id, + abs_oauth_connection_name=auth_handler.abs_oauth_connection_name + ) + + flow = AuthFlow(flow_state, user_token_client) return flow, flow_storage_client, flow_state @asynccontextmanager @@ -133,7 +153,7 @@ async def open_flow(self, context: TurnContext, auth_handler_id: str = "", reado if not context or not auth_handler_id: raise ValueError("context and auth_handler_id are required") - flow, flow_storage_client, init_flow_state = self._load_flow(context, auth_handler_id) + flow, flow_storage_client, init_flow_state = self.__load_flow(context, auth_handler_id) yield flow if not readonly and flow.flow_state != init_flow_state: @@ -173,10 +193,10 @@ async def exchange_token( The token response from the OAuth provider. """ async with self.open_flow(context, auth_handler_id) as flow: - token_response = await flow.get_user_token(context) + token_response = await flow.get_user_token() - if token_response and self._is_exchangeable(token_response.token): - return await self._handle_obo(token_response.token, scopes, auth_handler_id) + if token_response and self.__is_exchangeable(token_response.token): + return await self.__handle_obo(token_response.token, scopes, auth_handler_id) return TokenResponse() @@ -187,12 +207,12 @@ async def exchange_token( # token_response = await auth_handler.flow.get_user_token(context) - # if self._is_exchangeable(token_response.token if token_response else None): - # return await self._handle_obo(token_response.token, scopes, auth_handler_id) + # if self.__is_exchangeable(token_response.token if token_response else None): + # return await self.__handle_obo(token_response.token, scopes, auth_handler_id) # return token_response - def _is_exchangeable(self, token: Optional[str]) -> bool: + def __is_exchangeable(self, token: Optional[str]) -> bool: """ Checks if a token is exchangeable (has api:// audience). @@ -214,7 +234,7 @@ def _is_exchangeable(self, token: Optional[str]) -> bool: logger.exception("Failed to decode token to check audience") return False - async def _handle_obo( + async def __handle_obo( self, token: str, scopes: list[str], handler_id: str = None ) -> TokenResponse: """ @@ -228,7 +248,7 @@ async def _handle_obo( Returns: The new token response. """ - if not self._connection_manager: + if not self.__connection_manager: logger.error("Connection manager is not configured", stack_info=True) raise ValueError("Connection manager is not configured") @@ -239,7 +259,7 @@ async def _handle_obo( # Use the flow's OBO method to exchange the token token_provider: AccessTokenProviderBase = ( - self._connection_manager.get_connection(auth_handler.obo_connection_name) + self.__connection_manager.get_connection(auth_handler.obo_connection_name) ) logger.info("Attempting to exchange token on behalf of user") token = await token_provider.aquire_token_on_behalf_of( @@ -252,8 +272,8 @@ async def _handle_obo( ) async def get_active_flow_state(self, context: TurnContext, turn_state: TurnState = None) -> Optional[FlowState]: - flow_storage_client = FlowStorageClient(context, self._storage) - for auth_handler_id in self._auth_handlers.keys(): + flow_storage_client = FlowStorageClient(context, self.__storage) + for auth_handler_id in self.__auth_handlers.keys(): flow_state = await flow_storage_client.read(auth_handler_id) if flow_state.is_active(): return flow_state @@ -285,9 +305,9 @@ async def begin_or_continue_flow( flow_state: FlowState = flow_response.flow_state if flow_state.tag == FlowStateTag.COMPLETE: - self._sign_in_success_handler(context, turn_state, flow_state.handler.id) + self.__sign_in_success_handler(context, turn_state, flow_state.handler.id) elif flow_state.tag == FlowStateTag.FAILURE: - self._sign_in_failure_handler(context, turn_state, flow_state.handler.id, err) + self.__sign_in_failure_handler(context, turn_state, flow_state.handler.id, err) return flow_response @@ -302,24 +322,24 @@ def resolve_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: The resolved auth handler. """ if auth_handler_id: - if auth_handler_id not in self._auth_handlers: + if auth_handler_id not in self.__auth_handlers: logger.error(f"Auth handler '{auth_handler_id}' not found") raise ValueError(f"Auth handler '{auth_handler_id}' not found") - return self._auth_handlers[auth_handler_id] + return self.__auth_handlers[auth_handler_id] # Return the first handler if no ID specified - return next(iter(self._auth_handlers.values)) + return next(iter(self.__auth_handlers.values)) - async def _sign_out( + async def __sign_out( self, context: TurnContext, auth_handler_ids: Iterable[str] = None, ) -> None: for auth_handler_id in auth_handler_ids: - flow, flow_storage_client, initial_flow_state = self._load_flow(context, auth_handler_id) + flow, flow_storage_client, initial_flow_state = self.__load_flow(context, auth_handler_id) if initial_flow_state: logger.info(f"Signing out from handler: {auth_handler_id}") - await flow.sign_out(context) + await flow.sign_out() flow_storage_client.delete(auth_handler_id) async def sign_out( @@ -338,9 +358,9 @@ async def sign_out( auth_handler_id: Optional ID of the auth handler to use for sign out. """ if auth_handler_id: - self._sign_out(context, [auth_handler_id]) + self.__sign_out(context, [auth_handler_id]) else: - self._sign_out(context, self._auth_handlers.keys()) + self.__sign_out(context, self.__auth_handlers.keys()) def on_sign_in_success( self, @@ -352,7 +372,7 @@ def on_sign_in_success( Args: handler: The handler function to call on successful sign-in. """ - self._sign_in_success_handler = handler + self.__sign_in_success_handler = handler def on_sign_in_failure( self, @@ -363,4 +383,4 @@ def on_sign_in_failure( Args: handler: The handler function to call on sign-in failure. """ - self._sign_in_failure_handler = handler \ No newline at end of file + self.__sign_in_failure_handler = handler \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py index 27147e0e..d61e4227 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py @@ -21,7 +21,8 @@ class FlowStorageClient: def __init__( self, - context: TurnContext, + channel_id: str, + user_id: str, storage: Storage ): """ @@ -32,16 +33,8 @@ def __init__( storage: The Storage instance used to persist flow state data. """ - if ( - not context.activity - or not context.activity.channel_id - or not context.activity.from_property - or not context.activity.from_property.id - ): - raise ValueError("context.activity -> channel_id and from.id must be set.") - - channel_id = context.activity.channel_id - user_id = context.activity.from_property.id + if not user_id or not channel_id: + raise ValueError("FlowStorageClient.__init__(): channel_id and user_id must be set.") self.__base_key = f"auth/{channel_id}/{user_id}/" self.__storage = storage @@ -50,6 +43,7 @@ def __init__( def base_key(self) -> str: return self.__base_key + @staticmethod def key(self, flow_id: str) -> str: """Creates a storage key for a specific sign-in handler.""" return f"{self.__base_key}{flow_id}" diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py index 9b6acfc0..08c36b29 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py @@ -18,18 +18,20 @@ class FlowStateTag(Enum): class FlowErrorTag(Enum): NONE = "none" MAGIC_FORMAT = "magic_format" - MAGIC_CODE = "magic_code" + MAGIC_CODE_INCORRECT = "magic_code_incorrect" OTHER = "OTHER" class FlowState(BaseModel, StoreItem): - flow_id: str = "" - flow_started: bool = False + flow_id: str = "" # robrandao: TODO user_token: str = "" expires_at: float = 0 + channel_id: str = "" + user_id: str = "" + ms_app_id: str = "" abs_oauth_connection_name: Optional[str] = None continuation_activity: Optional[Activity] = None - attempts_remaining: int = 3 + attempts_remaining: int = 0 tag: FlowStateTag = FlowStateTag.NOT_STARTED def refresh(self) -> None: diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/__authorization_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/__authorization_test.py new file mode 100644 index 00000000..63e68f3e --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/__authorization_test.py @@ -0,0 +1,340 @@ +import datetime + +import pytest +from pytest_lazyfixture import lazy_fixture + +from microsoft.agents.activity import ( + TokenResponse +) +from microsoft.agents.hosting.core import ( + Authorization, + MemoryStorage, + FlowStorageClient, + FlowState, + FlowErrorTag, + FlowStateTag, + FlowResponse +) +from microsoft.agents.hosting.core.app.oauth.auth_flow import ( + AuthFlow +) +from microsoft.agents.hosting.core.storage.storage_test_utils import StorageBaseline + +def mock_flow(mocker, flow_states: list[FlowState]): + flow = mocker.Mock(spec=AuthFlow) + flow.begin_or_continue_flow = mocker.AsyncMock( + side_effect=flow_states + ) + return flow + +STORAGE_SAMPLE_DICT = { + "user_id": "123", + "session_id": "abc", + "auth/channel_id/user_id/expired": FlowState( + id="expired", + expires=expired_time, + attempts_remaining=1, + tag=FlowStateTag.CONTINUE + ), + "auth/teams_id/Bob/no_retries": FlowState( + id="no_retries", + expires=valid_time, + attempts_remaining=0, + tag=FlowStateTag.FAILURE + ), + "auth/channel/Alice/begin": FlowState( + id="begin", + expired=valid_time, + attempts_remaining=3, + tag=FlowStateTag.BEGIN + ), + "auth/channel/Alice/continue": FlowState( + id="continue", + expires=valid_time, + attempts_remaining=2 + tag=FlowStateTag.CONTINUE + ), + "auth/channel/Alice/expired_and_retries": FlowState( + id="expired_and_retries" + expires=expired_time, + attempts_remaining=0 + tag=FlowStateTag.FAILURE + ), + "auth/channel/Alice/not_started": FlowState( + id="not_started", + tag=FlowStateTag.NOT_STARTED + ) +} + +class TestAuthorization: + + def build_context(self, mocker, channel_id, from_property_id): + turn_context = mocker.Mock() + turn_context.activity.channel_id = channel_id + turn_context.activity.from_property.id = from_property_id + return turn_context + + @pytest.fixture + + @pytest.fixture + def context(self, mocker): + return self.build_context(mocker, "__channel_id", "__user_id") + + @pytest.fixture + def valid_time(self): + return datetime.datetime.now() + 10000 + + @pytest.fixture + def expired_time(self): + return datetime.datetime.now() + + @pytest.fixture + def m_storage(self, mocker): + return mocker.Mock(spec=MemoryStorage) + + @pytest.fixture + def m_connection_manager(self, mocker): + return mocker.Mock(spec=ConnectionManager) + + @pytest.fixture + def auth_handler_ids(self): + return ["expired", "no_retries", "begin", "continue", "expired_and_retries", "not_started"] + + @pytest.fixture + def auth_handlers(self, mocker, auth_handler_ids): + return { + auth_handler_id: create_test_auth_handler(f"test-{auth_handler_id}") for auth_handler_id in auth_handler_ids + } + + @pytest.fixture + def storage(self, valid_time, expired_time): + return MemoryStorage(STORAGE_SAMPLE_DICT) + + @pytest.fixture + def connection_manager(self): + pass + + @pytest.fixture + def auth_handlers(self): + pass + + @pytest.fixture + def auth(self, storage, connection_manager, auth_handlers): + return Authorization( + storage, + connection_manager, + auth_handlers, + auto_signin=True + ) + + @pytest.fixture + def storage(self, mocker): + return MemoryStorage({ + + }) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "auth, context, auth_handler_id", + [ + ("auth", lazy_fixture("context"), ""), + ("auth", None, "handler"), + ("auth", None, "") + ("auth", lazy_fixture("context", "missing_handler")) + ], + indirect=["auth"] + ) + async def test_open_flow_value_error(self, auth, context, auth_handler_id): + with pytest.raises(ValueError): + async with auth.open_flow(context, auth_handler_id): + pass + + # async def test_open_flow_storage_readonly_storage_access(self, mocker, context, m_storage, m_connection_manager, m_auth_handlers): + # # setup + # m_storage.read.return_value = FlowState() + # auth = Authorization( + # m_storage, + # m_connection_manager, + # m_auth_handlers + # ) + + # # code + # async with auth.open_flow(context, "handler", readonly=True) as flow: + # actual_init_flow_state = flow.flow_state + + # # verify + # assert actual_init_flow_state is m_storage.read.return_value + # assert not m_storage.write.called + # assert not m_storage.delete.called + + # async def test_open_flow_storage_unchanged_not_readonly_storage_access(self, context, m_storage, m_connection_manager, m_auth_handlers): + # # setup + # m_storage.read.return_value = FlowState() + # auth = Authorization( + # m_storage, + # m_connection_manager, + # m_auth_handlers + # ) + + # # code + # async with auth.open_flow(context, "handler", readonly=False) as flow: + # # if no change is made to the flow state, then storage should not be updated + # actual_init_flow_state = flow.flow_state + + # # verify + # assert actual_init_flow_state is m_storage.read.return_value + # assert not m_storage.write.called + # assert not m_storage.delete.called + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mocker, connection_manager, channel_id, from_property_id, auth_handler_id", + [ + (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel_id", "user_id", "expired"), + (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "teams_id", "Bob", "no_retries"), + (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "begin"), + (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "continue"), + (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "expired_and_retries"), + (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "not_started"), + ] + ) + async def test_open_flow_readonly_run(self, mocker, connection_manager, channel_id, from_property_id, auth_handler_id): + # setup + storage = MemoryStorage(STORAGE_SAMPLE_DICT) + baseline = StorageBaseline(STORAGE_SAMPLE_DICT) + auth = Authorization( + storage, + connection_manager, + auth_handlers + ) + context = self.build_context(mocker, channel_id, from_property_id) + storage_client = FlowStorageClient(context, storage) + key = storage_client.key(auth_handler_id) + expected_init_flow_state = storage.read(key, FlowState) + + # code + async with auth.open_flow(context, "handler", readonly=True) as flow: + actual_init_flow_state = flow.flow_state.copy() + flow.flow_state.id = "garbage" + flow.flow_state.tag = FlowStateTag.FAILURE + flow.flow_state.expires = 0 + flow.flow_state.attempts_remaining = -1 + actual_final_flow_state = await storage.read([key], FlowState)[key] + + # verify + expected_final_flow_state = baseline.read(key, FlowState) + assert actual_init_flow_state == expected_init_flow_state + assert actual_final_flow_state == expected_final_flow_state + assert await baseline.equals(storage) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mocker, connection_manager, channel_id, from_property_id, auth_handler_id", + [ + (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel_id", "user_id", "expired"), + (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "teams_id", "Bob", "no_retries"), + (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "begin"), + (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "continue"), + (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "expired_and_retries"), + (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "not_started"), + ] + ) + async def test_open_flow_storage_run(self, mocker, connection_manager, channel_id, from_property_id, auth_handler_id): + # setup + storage = MemoryStorage(STORAGE_SAMPLE_DICT) + baseline = StorageBaseline(STORAGE_SAMPLE_DICT) + auth = Authorization( + storage, + connection_manager, + auth_handlers + ) + context = self.build_context(mocker, channel_id, from_property_id) + storage_client = FlowStorageClient(context, storage) + key = storage_client.key(auth_handler_id) + expected_init_flow_state = storage.read(key, FlowState) + + # code + async with auth.open_flow(context, "handler") as flow: + actual_init_flow_state = flow.flow_state.copy() + flow.flow_state.id = "garbage" + flow.flow_state.tag = FlowStateTag.FAILURE + flow.flow_state.expires = 0 + flow.flow_state.attempts_remaining = -1 + + # verify + baseline.write({ + "auth/channel/Alice/continue": flow.flow_state + }) + expected_final_flow_state = baseline.read(key, FlowState) + assert await baseline.equals(storage) + assert actual_init_flow_state == expected_init_flow_state + assert flow.flow_state == expected_final_flow_state + + @pytest.mark.asyncio + async def test_get_token(self, mocker, m_storage): + m_storage.read.return_value = FlowState( + id="auth_handler", + tag=FlowStateTag.ACTIVE, + expires=3600, + attempts_remaining=3 + ) + expected = TokenResponse( + access_token="access_token", + refresh_token="refresh_token", + expires_in=3600 + ) + mock_flow = mocker.AsyncMock() + mock_flow.get_user_token.return_value = expected + mocker.patch.object("OAuthFlow", "get_token", return_value=expected) + mocker.patch.object("OAuthFlow", "__init__", return_value=mock_flow) + + assert await auth.get_token("auth_handler") is expected + assert mock_flow.get_user_token.called_once() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "auth, context, auth_handler_id", + [ + (lazy_fixture("auth"), lazy_fixture("context"), "missing-handler"), + (lazy_fixture("auth"), lazy_fixture("context"), ""), + (lazy_fixture("auth"), None, "handler") + ] + ) + async def test_get_token_error(self, auth, context, auth_handler_id): + with pytest.raises(ValueError): + await auth.get_token(context, auth_handler_id) + + @pytest.fixture + def valid_token_response(self): + return TokenResponse( + connection_name="connection", + token="token" + ) + + @pytest.fixture + def invalid_exchange_token(self): + token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") + return TokenResponse( + connection_name="connection" + token=token + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize + async def test_exchange_token(self, mocker, auth): + + mocker.patch.object("OAuthFlow", + get_user_token=mocker.AsyncMock(return_value=TokenResponse( + access_token="access_token", + refresh_token="refresh_token", + expires_in=3600 + )) + ) + + + + + + @pytest.mark.asyncio + async def test_exchange_token(self): + pass \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/authorization_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/authorization_test.py deleted file mode 100644 index 110454bc..00000000 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/authorization_test.py +++ /dev/null @@ -1,331 +0,0 @@ -# import datetime - -# import pytest -# from pytest_lazyfixture import lazy_fixture - -# from microsoft.agents.activity import ( -# TokenResponse -# ) -# from microsoft.agents.hosting.core import ( -# Authorization, -# MemoryStorage, -# FlowStorageClient, -# FlowState, -# FlowErrorTag, -# FlowStateTag, -# FlowResponse -# ) -# from microsoft.agents.hosting.core.storage.storage_test_utils import StorageBaseline -# from tools.testing_authorization import ( - -# ) - -# STORAGE_SAMPLE_DICT = { -# "user_id": "123", -# "session_id": "abc", -# "auth/channel_id/user_id/expired": FlowState( -# id="expired", -# expires=expired_time, -# attempts_remaining=1, -# tag=FlowStateTag.CONTINUE -# ), -# "auth/teams_id/Bob/no_retries": FlowState( -# id="no_retries", -# expires=valid_time, -# attempts_remaining=0, -# tag=FlowStateTag.FAILURE -# ), -# "auth/channel/Alice/begin": FlowState( -# id="begin", -# expired=valid_time, -# attempts_remaining=3, -# tag=FlowStateTag.BEGIN -# ), -# "auth/channel/Alice/continue": FlowState( -# id="continue", -# expires=valid_time, -# attempts_remaining=2 -# tag=FlowStateTag.CONTINUE -# ), -# "auth/channel/Alice/expired_and_retries": FlowState( -# id="expired_and_retries" -# expires=expired_time, -# attempts_remaining=0 -# tag=FlowStateTag.FAILURE -# ), -# "auth/channel/Alice/not_started": FlowState( -# id="not_started", -# tag=FlowStateTag.NOT_STARTED -# ) -# } - -# class TestAuthorization: - -# def build_context(self, mocker, channel_id, from_property_id): -# turn_context = mocker.Mock() -# turn_context.activity.channel_id = channel_id -# turn_context.activity.from_property.id = from_property_id -# return turn_context - -# @pytest.fixture -# def context(self, mocker): -# return self.build_context(mocker, "__channel_id", "__user_id") - -# @pytest.fixture -# def valid_time(self): -# return datetime.datetime.now() + 10000 - -# @pytest.fixture -# def expired_time(self): -# return datetime.datetime.now() - -# @pytest.fixture -# def m_storage(self, mocker): -# return mocker.Mock(spec=MemoryStorage) - -# @pytest.fixture -# def m_connection_manager(self, mocker): -# return mocker.Mock(spec=ConnectionManager) - -# @pytest.fixture -# def auth_handler_ids(self): -# return ["expired", "no_retries", "begin", "continue", "expired_and_retries", "not_started"] - -# @pytest.fixture -# def auth_handlers(self, mocker, auth_handler_ids): -# return { -# auth_handler_id: create_test_auth_handler(f"test-{auth_handler_id}") for auth_handler_id in auth_handler_ids -# } - -# @pytest.fixture -# def storage(self, valid_time, expired_time): -# return MemoryStorage(STORAGE_SAMPLE_DICT) - -# @pytest.fixture -# def connection_manager(self): -# pass - -# @pytest.fixture -# def auth_handlers(self): -# pass - -# @pytest.fixture -# def auth(self, storage, connection_manager, auth_handlers): -# return Authorization( -# storage, -# connection_manager, -# auth_handlers, -# auto_signin=True -# ) - -# @pytest.fixture -# def storage(self, mocker): -# return MemoryStorage({ - -# }) - -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# "auth, context, auth_handler_id", -# [ -# ("auth", lazy_fixture("context"), ""), -# ("auth", None, "handler"), -# ("auth", None, "") -# ("auth", lazy_fixture("context", "missing_handler")) -# ], -# indirect=["auth"] -# ) -# async def test_open_flow_value_error(self, auth, context, auth_handler_id): -# with pytest.raises(ValueError): -# async with auth.open_flow(context, auth_handler_id): -# pass - -# # async def test_open_flow_storage_readonly_storage_access(self, mocker, context, m_storage, m_connection_manager, m_auth_handlers): -# # # setup -# # m_storage.read.return_value = FlowState() -# # auth = Authorization( -# # m_storage, -# # m_connection_manager, -# # m_auth_handlers -# # ) - -# # # code -# # async with auth.open_flow(context, "handler", readonly=True) as flow: -# # actual_init_flow_state = flow.flow_state - -# # # verify -# # assert actual_init_flow_state is m_storage.read.return_value -# # assert not m_storage.write.called -# # assert not m_storage.delete.called - -# # async def test_open_flow_storage_unchanged_not_readonly_storage_access(self, context, m_storage, m_connection_manager, m_auth_handlers): -# # # setup -# # m_storage.read.return_value = FlowState() -# # auth = Authorization( -# # m_storage, -# # m_connection_manager, -# # m_auth_handlers -# # ) - -# # # code -# # async with auth.open_flow(context, "handler", readonly=False) as flow: -# # # if no change is made to the flow state, then storage should not be updated -# # actual_init_flow_state = flow.flow_state - -# # # verify -# # assert actual_init_flow_state is m_storage.read.return_value -# # assert not m_storage.write.called -# # assert not m_storage.delete.called - -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# "mocker, connection_manager, channel_id, from_property_id, auth_handler_id", -# [ -# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel_id", "user_id", "expired"), -# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "teams_id", "Bob", "no_retries"), -# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "begin"), -# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "continue"), -# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "expired_and_retries"), -# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "not_started"), -# ] -# ) -# async def test_open_flow_readonly_run(self, mocker, connection_manager, channel_id, from_property_id, auth_handler_id): -# # setup -# storage = MemoryStorage(STORAGE_SAMPLE_DICT) -# baseline = StorageBaseline(STORAGE_SAMPLE_DICT) -# auth = Authorization( -# storage, -# connection_manager, -# auth_handlers -# ) -# context = self.build_context(mocker, channel_id, from_property_id) -# storage_client = FlowStorageClient(context, storage) -# key = storage_client.key(auth_handler_id) -# expected_init_flow_state = storage.read(key, FlowState) - -# # code -# async with auth.open_flow(context, "handler", readonly=True) as flow: -# actual_init_flow_state = flow.flow_state.copy() -# flow.flow_state.id = "garbage" -# flow.flow_state.tag = FlowStateTag.FAILURE -# flow.flow_state.expires = 0 -# flow.flow_state.attempts_remaining = -1 -# actual_final_flow_state = await storage.read([key], FlowState)[key] - -# # verify -# expected_final_flow_state = baseline.read(key, FlowState) -# assert actual_init_flow_state == expected_init_flow_state -# assert actual_final_flow_state == expected_final_flow_state -# assert await baseline.equals(storage) - -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# "mocker, connection_manager, channel_id, from_property_id, auth_handler_id", -# [ -# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel_id", "user_id", "expired"), -# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "teams_id", "Bob", "no_retries"), -# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "begin"), -# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "continue"), -# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "expired_and_retries"), -# (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "not_started"), -# ] -# ) -# async def test_open_flow_storage_run(self, mocker, connection_manager, channel_id, from_property_id, auth_handler_id): -# # setup -# storage = MemoryStorage(STORAGE_SAMPLE_DICT) -# baseline = StorageBaseline(STORAGE_SAMPLE_DICT) -# auth = Authorization( -# storage, -# connection_manager, -# auth_handlers -# ) -# context = self.build_context(mocker, channel_id, from_property_id) -# storage_client = FlowStorageClient(context, storage) -# key = storage_client.key(auth_handler_id) -# expected_init_flow_state = storage.read(key, FlowState) - -# # code -# async with auth.open_flow(context, "handler") as flow: -# actual_init_flow_state = flow.flow_state.copy() -# flow.flow_state.id = "garbage" -# flow.flow_state.tag = FlowStateTag.FAILURE -# flow.flow_state.expires = 0 -# flow.flow_state.attempts_remaining = -1 - -# # verify -# baseline.write({ -# "auth/channel/Alice/continue": flow.flow_state -# }) -# expected_final_flow_state = baseline.read(key, FlowState) -# assert await baseline.equals(storage) -# assert actual_init_flow_state == expected_init_flow_state -# assert flow.flow_state == expected_final_flow_state - -# @pytest.mark.asyncio -# async def test_get_token(self, mocker, m_storage): -# m_storage.read.return_value = FlowState( -# id="auth_handler", -# tag=FlowStateTag.ACTIVE, -# expires=3600, -# attempts_remaining=3 -# ) -# expected = TokenResponse( -# access_token="access_token", -# refresh_token="refresh_token", -# expires_in=3600 -# ) -# mock_flow = mocker.AsyncMock() -# mock_flow.get_user_token.return_value = expected -# mocker.patch.object("OAuthFlow", "get_token", return_value=expected) -# mocker.patch.object("OAuthFlow", "__init__", return_value=mock_flow) - -# assert await auth.get_token("auth_handler") is expected -# assert mock_flow.get_user_token.called_once() - -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# "auth, context, auth_handler_id", -# [ -# (lazy_fixture("auth"), lazy_fixture("context"), "missing-handler"), -# (lazy_fixture("auth"), lazy_fixture("context"), ""), -# (lazy_fixture("auth"), None, "handler") -# ] -# ) -# async def test_get_token_error(self, auth, context, auth_handler_id): -# with pytest.raises(ValueError): -# await auth.get_token(context, auth_handler_id) - -# @pytest.fixture -# def valid_token_response(self): -# return TokenResponse( -# connection_name="connection", -# token="token" -# ) - -# @pytest.fixture -# def invalid_exchange_token(self): -# token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") -# return TokenResponse( -# connection_name="connection" -# token=token -# ) - -# @pytest.mark.asyncio -# @pytest.mark.parametrize -# async def test_exchange_token(self, mocker, auth): - -# mocker.patch.object("OAuthFlow", -# get_user_token=mocker.AsyncMock(return_value=TokenResponse( -# access_token="access_token", -# refresh_token="refresh_token", -# expires_in=3600 -# )) -# ) - - - - - -# @pytest.mark.asyncio -# async def test_exchange_token(self): -# pass \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/conftest.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/conftest.py new file mode 100644 index 00000000..e69de29b diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/oauth_flow_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/oauth_flow_test.py deleted file mode 100644 index 0dc2af22..00000000 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/oauth_flow_test.py +++ /dev/null @@ -1,408 +0,0 @@ -from datetime import datetime - -import pytest - -from microsoft.agents.activity import ( - ActivityTypes, - TokenResponse, - SignInResource, - TokenExchangeState, - ConversationReference -) -from microsoft.agents.hosting.core.app.oauth.auth_flow import AuthFlow - -from microsoft.agents.hosting.core.app.oauth.models import ( - FlowErrorTag, - FlowState, - FlowStateTag, - FlowResponse -) - -class TestAuthFlow: - - @pytest.fixture - def turn_context(self, mocker): - context = mocker.Mock() - context.activity.channel_id = "__channel_id" - context.activity.from_property.id = "__user_id" - context.adapter.AGENT_IDENTITY_KEY = "__agent_id" - context.activity.relates_to = None - context.activity.get_conversation_reference = mocker.Mock() - context.activity.get_conversation_reference.return_value = mocker.Mock(spec=ConversationReference) - data = mocker.Mock() - data.claims = {"aud": "__app_id"} - context.turn_state = { - "__agent_id": data - } - return context - - def test_init_no_state(self, mocker, turn_context): - flow = AuthFlow(turn_context, mocker.Mock()) - assert flow.flow_state == FlowState() - - def test_init_with_state(self, mocker, turn_context): - flow_state = FlowState( - tag=FlowStateTag.CONTINUE, - attempts_remaining=1, - expires_at=datetime.now().timestamp() + 10000 - ) - flow = AuthFlow(turn_context, mocker.Mock(), flow_state=flow_state) - assert flow.flow_state == flow_state - - @pytest.mark.asyncio - async def test_get_user_token(self, mocker, turn_context): - # mock - user_token_client = mocker.Mock() - user_token_client.user_token.get_token = mocker.AsyncMock(return_value="test_token") - - # test - flow = AuthFlow( - abs_oauth_connection_name="test_connection", - user_token_client=user_token_client, - ) - token = await flow.get_user_token(turn_context) - - # verify - assert token == "test_token" - user_token_client.user_token.get_token.assert_called_once_with( - user_id="__user_id", - connection_name="test_connection", - channel_id="__channel_id", - magic_code=None - ) - - @pytest.mark.asyncio - async def test_sign_out(self, mocker, turn_context): - # mock - user_token_client = mocker.Mock() - user_token_client.user_token.sign_out = mocker.AsyncMock() - - # test - flow = AuthFlow( - abs_oauth_connection_name="connection", - user_token_client=user_token_client, - ) - await flow.sign_out(turn_context) - - # verify - user_token_client.user_token.sign_out.assert_called_once_with( - user_id="__user_id", - connection_name="connection", - channel_id="__channel_id" - ) - - @pytest.mark.asyncio - async def test_begin_flow_easy_case(self, mocker, turn_context): - # mock - user_token_client = mocker.Mock() - user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse(token="test_token")) - - # test - flow = AuthFlow( - abs_oauth_connection_name="test_connection", - user_token_client=user_token_client, - flow_state=FlowState( - tag=FlowStateTag.COMPLETE, - user_token="test_token", # robrandao: TODO -> what are all these fields for? - expires_at=datetime.now().timestamp() + 10000, - attempts_remaining=2 - ) - ) - response = await flow.begin_flow(turn_context) - - # verify flow_state - flow_state = flow.flow_state - assert flow_state.tag == FlowStateTag.COMPLETE - assert flow_state.user_token == "test_token" - # assert flow_state.flow_started is False # robrandao: TODO? - - # verify FlowResponse - assert response.flow_state == flow_state - assert response.sign_in_resource is None # No sign-in resource in this case - assert response.flow_error_tag == FlowErrorTag.NONE - assert response.token_response.token == "test_token" - user_token_client.user_token.get_token.assert_called_once_with( - user_id="__user_id", - connection_name="test_connection", - channel_id="__channel_id", - # magic_code=None is an implementation detail, and ideally - # shouldn't be part of the test - magic_code=None - ) - - @pytest.mark.asyncio - async def test_begin_flow_long_case(self, mocker, turn_context): - # mock - dummy_sign_in_resource = SignInResource( - sign_in_link="https://example.com/signin", - token_exchange_state=TokenExchangeState(connection_name="test_connection") - ) - user_token_client = mocker.Mock() - user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) - user_token_client.agent_sign_in.get_sign_in_resource = mocker.AsyncMock(return_value=dummy_sign_in_resource) - - # test - flow = AuthFlow( - abs_oauth_connection_name="test_connection", - user_token_client=user_token_client, - ) - response = await flow.begin_flow(turn_context) - - # verify flow_state - flow_state = flow.flow_state - assert flow_state.tag == FlowStateTag.BEGIN - assert flow_state.token == "" - assert flow_state.flow_started is True - - # verify FlowResponse - assert response.flow_state == flow_state - assert response.sign_in_resource == dummy_sign_in_resource - assert response.flow_error_tag == FlowErrorTag.NONE - assert not response.token_response - # robrandao: TODO more assertions on sign_in_resource - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "flow_state", - [ - FlowState( - tag=FlowStateTag.BEGIN, - token="", - expires_at=datetime.now().timestamp() - 1, - attempts_remaining=3 - ), - FlowState( - tag=FlowStateTag.CONTINUE, - token="", - expires_at=datetime.now().timestamp() + 1000, - attempts_remaining=0 - ), - FlowState( - tag=FlowStateTag.FAILURE, - token="", - expires_at=datetime.now().timestamp() + 1000, - attempts_remaining=3 - ), - FlowState( - tag=FlowStateTag.COMPLETE, - token="", - expires_at=datetime.now().timestamp() + 1000, - attempts_remaining=2 - ), - ], - ) - async def test_continue_flow_not_active(self, mocker, turn_context, flow_state): - user_token_client = mocker.Mock() - flow = AuthFlow( - abs_oauth_connection_name="test_connection", - user_token_client=user_token_client, - flow_state=flow_state - ) - flow_response = await flow.continue_flow(turn_context) - assert flow_response.flow_state == flow_state - assert not flow_response.token_response - - @pytest.fixture(params=[ - (FlowStateTag.CONTINUE, "test_token", 2), - (FlowStateTag.BEGIN, "", 1), - ]) - def active_flow_state(self, request): - tag, token, attempts_remaining = request.param - return FlowState( - tag=tag, - token=token, - expires_at=datetime.now().timestamp() + 1000, - attempts_remaining=attempts_remaining - ) - - async def test_continue_flow_message(self, mocker, turn_context, active_flow_state): - # mock - turn_context.activity.type = ActivityTypes.message - turn_context.activity.text = "magic-message" - user_token_client = mocker.Mock() - user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) - user_token_client.agent_sign_in.get_sign_in_resource = mocker.AsyncMock(return_value=None) - - # test - flow = AuthFlow( - abs_oauth_connection_name="test_connection", - user_token_client=mocker.Mock(), - flow_state=active_flow_state - ) - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "magic_code", - ["magic-message", "", "abcdef", "@#0324", "231", None] - ) - async def test_continue_flow_message_format_error(self, mocker, turn_context, active_flow_state, magic_code): - # mock - turn_context.activity.type = ActivityTypes.message - turn_context.activity.text = magic_code - - # test - flow = AuthFlow( - abs_oauth_connection_name="test_connection", - user_token_client=mocker.Mock(), - flow_state=active_flow_state - ) - flow_response = flow.continue_flow(turn_context) - - # verify - assert active_flow_state.attempts_remaining - 1 == flow_response.flow_state.attempts_remaining - assert not flow_response.token_response - assert flow_response.tag == FlowStateTag.FAILURE - assert flow_response.flow_error_tag == FlowErrorTag.MAGIC_FORMAT - - @pytest.mark.asyncio - async def test_continue_flow_message_magic_code_error(self, mocker, turn_context, active_flow_state): - # mock - turn_context.activity.type = ActivityTypes.message - turn_context.activity.text = "123456" - user_token_client = mocker.Mock() - user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) - - # test - flow = AuthFlow( - abs_oauth_connection_name="test_connection", - user_token_client=user_token_client, - flow_state=active_flow_state - ) - flow_response = await flow.continue_flow(turn_context) - - # verify - assert active_flow_state.attempts_remaining - 1 == flow_response.flow_state.attempts_remaining - assert not flow_response.token_response - assert flow_response.flow_error_tag == FlowErrorTag.MAGIC_CODE - user_token_client.user_token.get_token.assert_called_once_with( - user_id="__user_id", - connection_name="test_connection", - channel_id="__channel_id", - magic_code="123456" - ) - - @pytest.mark.asyncio - async def test_continue_flow_invoke_verify_state(self, mocker, turn_context, active_flow_state): - # mock - turn_context.activity.type = ActivityTypes.message - turn_context.activity.name = "signin/verifyState" - turn_context.activity.value = {"state": "987654"} - user_token_client = mocker.Mock() - user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse(token="some-token")) - - # test - flow = AuthFlow( - abs_oauth_connection_name="test_connection", - user_token_client=user_token_client, - flow_state=active_flow_state - ) - flow_response = await flow.continue_flow(turn_context) - - # verify - assert active_flow_state.attempts_remaining == flow_response.flow_state.attempts_remaining - assert flow_response.token_response.token == "some-token" - assert flow_response.flow_state.tag == FlowStateTag.COMPLETE - assert flow_response.flow_error_tag == FlowErrorTag.NONE - user_token_client.user_token.get_token.assert_called_once_with( - user_id="__user_id", - connection_name="test_connection", - channel_id="__channel_id", - magic_code="987654" - ) - - @pytest.mark.asyncio - async def test_continue_flow_invoke_verify_state_no_token(self, mocker, turn_context, active_flow_state): - # mock - turn_context.activity.type = ActivityTypes.message - turn_context.activity.name = "signin/verifyState" - turn_context.activity.value = {"state": "987654"} - user_token_client = mocker.Mock() - user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) - - # test - flow = AuthFlow( - abs_oauth_connection_name="test_connection", - user_token_client=user_token_client, - flow_state=active_flow_state - ) - flow_response = await flow.continue_flow(turn_context) - - # verify - assert active_flow_state.attempts_remaining - 1 == flow_response.flow_state.attempts_remaining - assert not flow_response.token_response.token - if active_flow_state.attempts_remaining == 1: - assert flow_response.flow_state.tag == FlowStateTag.FAILURE - else: - assert flow_response.flow_state.tag == FlowStateTag.OTHER - assert flow_response.flow_error_tag == FlowErrorTag.OTHER - user_token_client.user_token.get_token.assert_called_once_with( - user_id="__user_id", - connection_name="test_connection", - channel_id="__channel_id", - magic_code="987654" - ) - - @pytest.mark.asyncio - async def test_continue_flow_invoke_token_exchange(self, mocker, turn_context, active_flow_state): - # mock - turn_context.activity.type = ActivityTypes.message - turn_context.activity.name = "signin/exchangeState" - turn_context.activity.value = "request_body" - user_token_client = mocker.Mock() - user_token_client.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse(token="exchange-token")) - - # test - flow = AuthFlow( - abs_oauth_connection_name="test_connection", - user_token_client=user_token_client, - flow_state=active_flow_state - ) - flow_response = await flow.continue_flow(turn_context) - - # verify - assert active_flow_state.attempts_remaining == flow_response.flow_state.attempts_remaining - assert flow_response.token_response.token == "exchange-token" - assert flow_response.flow_state.tag == FlowStateTag.COMPLETE - assert flow_response.flow_error_tag == FlowErrorTag.NONE - user_token_client.user_token.get_token.assert_called_once_with( - user_id="__user_id", - connection_name="test_connection", - channel_id="__channel_id", - body="request_body" - ) - - @pytest.mark.asyncio - async def test_continue_flow_invoke_token_exchange_no_token(self, mocker, turn_context, active_flow_state): - # mock - turn_context.activity.type = ActivityTypes.message - turn_context.activity.name = "signin/exchangeState" - turn_context.activity.value = "request_body" - user_token_client = mocker.Mock() - user_token_client.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse()) - - # test - flow = AuthFlow( - abs_oauth_connection_name="test_connection", - user_token_client=user_token_client, - flow_state=active_flow_state - ) - flow_response = await flow.continue_flow(turn_context) - - # verify - assert active_flow_state.attempts_remaining - 1 == flow_response.flow_state.attempts_remaining - assert not flow_response.token_response - if active_flow_state.attempts_remaining == 1: - assert flow_response.flow_state.tag == FlowStateTag.FAILURE - else: - assert flow_response.flow_state.tag == FlowStateTag.CONTINUE - assert flow_response.flow_error_tag == FlowErrorTag.OTHER - user_token_client.user_token.get_token.assert_called_once_with( - user_id="__user_id", - connection_name="test_connection", - channel_id="__channel_id", - body="request_body" - ) - - @pytest.mark.asyncio - async def test_begin_or_continue_flow(self): - assert True # robrandao: TODO \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/utils.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/utils.py new file mode 100644 index 00000000..40fb8762 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/utils.py @@ -0,0 +1,10 @@ +import inspect + +def raise_if_empty_or_None(func_name, err=ValueError, **kwargs): + s = "" + for key, value in kwargs.items(): + if not value: + s += f"\tArgument '{key}' is required and cannot be None or empty.\n" + if s: + header = f"{func_name}: called with empty arguments:" + raise err(header + "\n" + s) \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/flow_storage_client_test.py b/libraries/microsoft-agents-hosting-core/tests/flow_storage_client_test.py similarity index 100% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/flow_storage_client_test.py rename to libraries/microsoft-agents-hosting-core/tests/flow_storage_client_test.py diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/models_test.py b/libraries/microsoft-agents-hosting-core/tests/models_test.py similarity index 100% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/models_test.py rename to libraries/microsoft-agents-hosting-core/tests/models_test.py diff --git a/libraries/microsoft-agents-hosting-core/tests/test_auth_flow.py b/libraries/microsoft-agents-hosting-core/tests/test_auth_flow.py new file mode 100644 index 00000000..77b4693f --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/tests/test_auth_flow.py @@ -0,0 +1,402 @@ +from datetime import datetime +from typing import Callable + +import pytest +from pydantic import BaseModel + +from microsoft.agents.activity import ( + Activity, + ActivityTypes, + TokenResponse, + SignInResource, + TokenExchangeState, + ConversationReference, + ChannelAccount, + ConversationAccount +) +from microsoft.agents.hosting.core.app.oauth.auth_flow import AuthFlow + +from microsoft.agents.hosting.core.app.oauth.models import ( + FlowErrorTag, + FlowState, + FlowStateTag, +) +from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase +from microsoft.agents.hosting.core.connector.user_token_client_base import UserTokenClientBase + +from .tools.oauth_test_utils import TEST_DEFAULTS + + +class TestAuthFlowUtils: + + def create_user_token_client(self, mocker, get_token_return=None): + + user_token_client = mocker.Mock(spec=UserTokenClientBase) + user_token_client.user_token = mocker.Mock(spec=UserTokenBase) + user_token_client.user_token.get_token = mocker.AsyncMock() + user_token_client.user_token.sign_out = mocker.AsyncMock() + + return_value = TokenResponse() + if get_token_return: + return_value = TokenResponse(token=get_token_return) + user_token_client.user_token.get_token.return_value = return_value + + return user_token_client + + @pytest.fixture + def user_token_client(self, mocker): + return self.create_user_token_client(mocker, get_token_return=TEST_DEFAULTS.RES_TOKEN) + + def create_activity(self, mocker, activity_type=ActivityTypes.message, name="a", value=None, text="a"): + # def conv_ref(): + # return mocker.MagicMock(spec=ConversationReference) + mock_conversation_ref = mocker.MagicMock(ConversationReference) + mocker.patch.object(Activity, "get_conversation_reference", return_value=mocker.MagicMock(ConversationReference)) + # mocker.patch.object(ConversationReference, "create", return_value=conv_ref()) + return Activity( + type=activity_type, + name=name, + from_property=ChannelAccount(id=TEST_DEFAULTS.USER_ID), + channel_id=TEST_DEFAULTS.CHANNEL_ID, + # get_conversation_reference=mocker.Mock(return_value=conv_ref), + relates_to=mocker.MagicMock(ConversationReference), + value=value, + text=text + ) + + @pytest.fixture(params=TEST_DEFAULTS.ALL()) + def sample_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture(params=TEST_DEFAULTS.FAILED()) + def sample_failed_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture(params=TEST_DEFAULTS.INACTIVE()) + def sample_inactive_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture(params=TEST_DEFAULTS.ACTIVE()) + def sample_active_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture + def flow(self, sample_flow_state, user_token_client): + return AuthFlow(sample_flow_state, user_token_client) + + +class TestAuthFlow(TestAuthFlowUtils): + + def test_init_no_user_token_client(self, sample_flow_state): + with pytest.raises(ValueError): + AuthFlow(sample_flow_state, None) + + @pytest.mark.parametrize("missing_value", [ + "abs_oauth_connection_name", + "ms_app_id", + "channel_id", + "user_id" + ]) + def test_init_errors(self, missing_value, user_token_client): + flow_state = TEST_DEFAULTS.STARTED_FLOW.model_copy() + flow_state.__setattr__(missing_value, None) + with pytest.raises(ValueError): + AuthFlow(flow_state, user_token_client) + flow_state.__setattr__(missing_value, "") + with pytest.raises(ValueError): + AuthFlow(flow_state, user_token_client) + + def test_init_with_state(self, sample_flow_state, user_token_client): + flow = AuthFlow(sample_flow_state, user_token_client) + assert flow.flow_state == sample_flow_state + + def test_flow_state_prop_copy(self, flow): + flow_state = flow.flow_state + flow_state.user_id = (flow_state.user_id + "_copy") + assert flow.flow_state.user_id == TEST_DEFAULTS.USER_ID + assert flow_state.user_id == f"{TEST_DEFAULTS.USER_ID}_copy" + + @pytest.mark.asyncio + async def test_get_user_token_success(self, sample_flow_state, user_token_client): + # setup + flow = AuthFlow(sample_flow_state, user_token_client) + expected_final_flow_state = sample_flow_state + expected_final_flow_state.user_token = TEST_DEFAULTS.RES_TOKEN + + # test + token_response = await flow.get_user_token() + token = token_response.token + + # verify + assert token == TEST_DEFAULTS.RES_TOKEN + assert flow.flow_state == expected_final_flow_state + user_token_client.user_token.get_token.assert_called_once_with( + user_id=TEST_DEFAULTS.USER_ID, + connection_name=TEST_DEFAULTS.ABS_OAUTH_CONNECTION_NAME, + channel_id=TEST_DEFAULTS.CHANNEL_ID, + magic_code=None + ) + + @pytest.mark.asyncio + async def test_get_user_token_failure(self, mocker, sample_flow_state): + # setup + user_token_client = self.create_user_token_client(mocker, get_token_return=None) + flow = AuthFlow(sample_flow_state, user_token_client) + expected_final_flow_state = flow.flow_state # robrandao: TODO -> what happens if fails and has user_token? + + # test + token_response = await flow.get_user_token() + + # verify + assert token_response == TokenResponse() + assert flow.flow_state == expected_final_flow_state + user_token_client.user_token.get_token.assert_called_once_with( + user_id=TEST_DEFAULTS.USER_ID, + connection_name=TEST_DEFAULTS.ABS_OAUTH_CONNECTION_NAME, + channel_id=TEST_DEFAULTS.CHANNEL_ID, + magic_code=None + ) + + @pytest.mark.asyncio + async def test_sign_out(self, sample_flow_state, user_token_client): + # setup + flow = AuthFlow(sample_flow_state, user_token_client) + expected_flow_state = sample_flow_state + expected_flow_state.user_token = "" + expected_flow_state.tag = FlowStateTag.NOT_STARTED + + # test + await flow.sign_out() + + # verify + user_token_client.user_token.sign_out.assert_called_once_with( + user_id=TEST_DEFAULTS.USER_ID, + connection_name=TEST_DEFAULTS.ABS_OAUTH_CONNECTION_NAME, + channel_id=TEST_DEFAULTS.CHANNEL_ID + ) + assert flow.flow_state == expected_flow_state + + @pytest.mark.asyncio + async def test_begin_flow_easy_case(self, mocker, sample_flow_state, user_token_client): + # setup + flow = AuthFlow(sample_flow_state, user_token_client) + activity = mocker.Mock(spec=Activity) + expected_flow_state = sample_flow_state + expected_flow_state.user_token = TEST_DEFAULTS.RES_TOKEN + + # test + response = await flow.begin_flow(activity) + + # verify + flow_state = flow.flow_state + assert flow_state == expected_flow_state + # assert flow_state.flow_started is False # robrandao: TODO? + + assert response.flow_state == flow_state + assert response.sign_in_resource is None # No sign-in resource in this case + assert response.flow_error_tag == FlowErrorTag.NONE + assert response.token_response.token == TEST_DEFAULTS.RES_TOKEN + user_token_client.user_token.get_token.assert_called_once_with( + user_id=TEST_DEFAULTS.USER_ID, + connection_name=TEST_DEFAULTS.ABS_OAUTH_CONNECTION_NAME, + channel_id=TEST_DEFAULTS.CHANNEL_ID, + # magic_code=None is an implementation detail, and ideally + # shouldn't be part of the test + magic_code=None + ) + + @pytest.mark.asyncio + async def test_begin_flow_long_case(self, mocker, sample_flow_state, user_token_client): + # mock + # tes = mocker.Mock(TokenExchangeState) + # tes.get_encoded_state = mocker.Mock(return_value="encoded_state") + mocker.patch.object(TokenExchangeState, "get_encoded_state", return_value="encoded_state") + dummy_sign_in_resource = SignInResource( + sign_in_link="https://example.com/signin", + token_exchange_state=mocker.Mock( + TokenExchangeState, get_encoded_state=mocker.Mock(return_value="encoded_state") + ) + ) + user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) + user_token_client.agent_sign_in.get_sign_in_resource = mocker.AsyncMock( + return_value=dummy_sign_in_resource) + activity = self.create_activity(mocker) + + # setup + flow = AuthFlow(sample_flow_state, user_token_client) + expected_flow_state = sample_flow_state + expected_flow_state.user_token = "" + expected_flow_state.tag = FlowStateTag.BEGIN + expected_flow_state.attempts_remaining = 3 + + # test + response = await flow.begin_flow(activity) + + # verify flow_state + flow_state = flow.flow_state + expected_flow_state.expires_at = flow_state.expires_at # robrandao: TODO -> ignore this for now + assert flow_state == response.flow_state + assert flow_state == expected_flow_state + + # verify FlowResponse + assert response.sign_in_resource == dummy_sign_in_resource + assert response.flow_error_tag == FlowErrorTag.NONE + assert not response.token_response + # robrandao: TODO more assertions on sign_in_resource + + @pytest.mark.asyncio + async def test_continue_flow_not_active(self, mocker, sample_inactive_flow_state, user_token_client): + # setup + activity = mocker.Mock() + flow = AuthFlow(sample_inactive_flow_state, user_token_client) + expected_flow_state = sample_inactive_flow_state + expected_flow_state.tag = FlowStateTag.FAILURE + + # test + flow_response = await flow.continue_flow(activity) + flow_state = flow.flow_state + + # verify + # robrandao: TODO -> revise + assert flow_state == expected_flow_state + assert flow_response.flow_state == flow_state + assert not flow_response.token_response + + async def helper_continue_flow_failure(self, active_flow_state, user_token_client, activity, flow_error_tag): + # setup + flow = AuthFlow(active_flow_state, user_token_client) + expected_flow_state = active_flow_state + expected_flow_state.tag = FlowStateTag.CONTINUE if active_flow_state.attempts_remaining > 1 else FlowStateTag.FAILURE + expected_flow_state.attempts_remaining = active_flow_state.attempts_remaining - 1 + + # test + flow_response = await flow.continue_flow(activity) + flow_state = flow.flow_state + + # verify + assert flow_response.flow_state == flow_state + assert expected_flow_state == flow_state + assert flow_response.token_response == TokenResponse() + assert flow_response.flow_error_tag == flow_error_tag + + async def helper_continue_flow_success(self, active_flow_state, user_token_client, activity): + # setup + flow = AuthFlow(active_flow_state, user_token_client) + expected_flow_state = active_flow_state + expected_flow_state.tag = FlowStateTag.COMPLETE + expected_flow_state.user_token = TEST_DEFAULTS.RES_TOKEN + expected_flow_state.attempts_remaining = active_flow_state.attempts_remaining + + # test + flow_response = await flow.continue_flow(activity) + flow_state = flow.flow_state + expected_flow_state.expires_at = flow_state.expires_at # robrandao: TODO -> ignore this for now + + # verify + assert flow_response.flow_state == flow_state + assert expected_flow_state == flow_state + assert flow_response.token_response == TokenResponse(token=TEST_DEFAULTS.RES_TOKEN) + assert flow_response.flow_error_tag == FlowErrorTag.NONE + + @pytest.mark.asyncio + @pytest.mark.parametrize("magic_code", ["magic", "123", "", "1239453"]) + async def test_continue_flow_active_message_magic_format_error(self, mocker, sample_active_flow_state, user_token_client, magic_code): + # setup + activity = self.create_activity(mocker, ActivityTypes.message, text=magic_code) + await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.MAGIC_FORMAT) + user_token_client.assert_not_called() + + @pytest.mark.asyncio + async def test_continue_flow_active_message_magic_code_error(self, mocker, sample_active_flow_state, user_token_client): + # setup + user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) + activity = self.create_activity(mocker, ActivityTypes.message, text="123456") + await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.MAGIC_CODE_INCORRECT) + user_token_client.user_token.get_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.abs_oauth_connection_name, + channel_id=sample_active_flow_state.channel_id, + magic_code="123456" + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_message_success(self, mocker, sample_active_flow_state, user_token_client): + # setup + activity = self.create_activity(mocker, ActivityTypes.message, text="123456") + await self.helper_continue_flow_success(sample_active_flow_state, user_token_client, activity) + user_token_client.user_token.get_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.abs_oauth_connection_name, + channel_id=sample_active_flow_state.channel_id, + magic_code="123456" + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_sign_in_verify_state_error(self, mocker, sample_active_flow_state, user_token_client): + # setup + user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) + activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/verifyState", value={ + "state": "magic_code" + }) + await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.OTHER) + user_token_client.user_token.get_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.abs_oauth_connection_name, + channel_id=sample_active_flow_state.channel_id, + magic_code="magic_code" + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_sign_in_verify_success(self, mocker, sample_active_flow_state, user_token_client): + activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/verifyState", value={ + "state": "magic_code" + }) + await self.helper_continue_flow_success(sample_active_flow_state, user_token_client, activity) + user_token_client.user_token.get_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.abs_oauth_connection_name, + channel_id=sample_active_flow_state.channel_id, + magic_code="magic_code" + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_sign_in_token_exchange_error(self, mocker, sample_active_flow_state, user_token_client): + token_exchange_request = {} + user_token_client.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse()) + activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/tokenExchange", value=token_exchange_request) + await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.OTHER) + user_token_client.user_token.exchange_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.abs_oauth_connection_name, + channel_id=sample_active_flow_state.channel_id, + body=token_exchange_request + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_sign_in_token_exchange_success(self, mocker, sample_active_flow_state, user_token_client): + token_exchange_request = {} + user_token_client.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse(token=TEST_DEFAULTS.RES_TOKEN)) + activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/tokenExchange", value=token_exchange_request) + await self.helper_continue_flow_success(sample_active_flow_state, user_token_client, activity) + user_token_client.user_token.exchange_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.abs_oauth_connection_name, + channel_id=sample_active_flow_state.channel_id, + body=token_exchange_request + ) + + @pytest.mark.asyncio + async def test_continue_flow_invalid_invoke_name(self, mocker, sample_active_flow_state, user_token_client): + with pytest.raises(ValueError): + activity = self.create_activity(mocker, ActivityTypes.invoke, name="other", value={}) + flow = AuthFlow(sample_active_flow_state, user_token_client) + await flow.continue_flow(activity) + + @pytest.mark.asyncio + async def test_continue_flow_invalid_activity_type(self, mocker, sample_active_flow_state, user_token_client): + with pytest.raises(ValueError): + activity = self.create_activity(mocker, ActivityTypes.command, name="other", value={}) + flow = AuthFlow(sample_active_flow_state, user_token_client) + await flow.continue_flow(activity) + + # robrandao: TODO -> test begin_or_continue_flow \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py index e7ee8b69..fbaa5e1b 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py @@ -1,213 +1,456 @@ +import datetime + import pytest -from .tools.testing_authorization import ( - TestingAuthorization, - create_test_auth_handler, + +from microsoft.agents.activity import ( + TokenResponse +) +from microsoft.agents.hosting.core import ( + Authorization, + MemoryStorage, + FlowStorageClient, + FlowState, + FlowErrorTag, + FlowStateTag, + FlowResponse +) +from microsoft.agents.hosting.core.storage.storage_test_utils import StorageBaseline +from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase +from microsoft.agents.hosting.core.connector.user_token_client_base import UserTokenClientBase + +from microsoft.agents.hosting.core.app.oauth.auth_flow import AuthFlow + +from tools.oauth_test_utils import ( + TEST_DEFAULTS, + STORAGE_INIT_DATA ) -from .tools.testing_utility import TestingUtility -import jwt -from unittest.mock import Mock, AsyncMock -from microsoft.agents.hosting.core import SignInState -from microsoft.agents.hosting.core.oauth_flow import FlowState +from tools.testing_authorization import ( + TestingTokenProvider, + TestingConnectionManager, + create_test_auth_handler +) -class TestAuthorization: - def setup_method(self): - self.turn_context = TestingUtility.create_empty_context() +class TestAuthFlowUtils: - @pytest.mark.asyncio - async def test_get_token_single_handler(self): - """ - Test Authorization - get_token() with single Auth Handler - """ - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - } + def create_context(self, + mocker, + channel_id="__channel_id", + user_id="__user_id", + abs_oauth_connection_name="graph", + user_token_client=None): + + if not user_token_client: + user_token_client = self.create_mock_user_token_client(mocker) + + turn_context = mocker.Mock() + turn_context.activity.channel_id = channel_id + turn_context.activity.from_property.id = user_id + turn_context.adapter.USER_TOKEN_CLIENT_KEY = "__user_token_client" + turn_context.turn_state = { + "__user_token_client": user_token_client + } + return context + + def create_mock_user_token_client( + self, + mocker, + token=None, + ): + mock_user_token_client_class = mocker.Mock(spec=UserTokenClientBase) + mock_user_token_client_class.user_token = mocker.Mock(spec=UserTokenBase) + mock_user_token_client_class.user_token.get_token = mocker.AsyncMock( + return_value=TokenResponse(token=token) ) + mock_user_token_client_class.user_token.sign_out = mocker.AsyncMock() + return mock_user_token_client_class + + @pytest.fixture + def mock_user_token_client_class(self, mocker): + return self.create_mock_user_token_client_class(mocker) + + @pytest.fixture + def mock_flow_class(self, mocker): + mock_flow_class = mocker.Mock(spec=AuthFlow) - token_res = await auth.get_token(self.turn_context) - auth_handler = auth.resolver_handler("auth-handler") - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" + mocker.patch.object(AuthFlow, "__init__", return_value=mock_flow_class) + mock_flow_class.get_user_token = mocker.AsyncMock() + mock_flow_class.sign_out = mocker.AsyncMock() + + return mock_flow_class + + @pytest.fixture + def turn_context(self, mocker): + return self.create_context(mocker, "__channel_id", "__user_id", "__connection") - @pytest.mark.asyncio - async def test_get_token_multiple_handlers(self): - """ - Test Authorization - get_token() with multiple Auth Handlers - """ - auth_handlers = { - "auth-handler": create_test_auth_handler("test-auth-a"), - "auth-handler-obo": create_test_auth_handler("test-auth-b", obo=True), - "auth-handler-with-title": create_test_auth_handler( - "test-auth-c", title="test-title" - ), - "auth-handler-with-title-text": create_test_auth_handler( - "test-auth-d", title="test-title", text="test-text" + def create_user_token_client(self, mocker, get_token_return=None): + + user_token_client = mocker.Mock(spec=UserTokenClientBase) + user_token_client.user_token = mocker.Mock(spec=UserTokenBase) + user_token_client.user_token.get_token = mocker.AsyncMock() + user_token_client.user_token.sign_out = mocker.AsyncMock() + + return_value = TokenResponse() + if get_token_return: + return_value = TokenResponse(token=get_token_return) + user_token_client.user_token.get_token.return_value = return_value + + return user_token_client + + @pytest.fixture + def user_token_client(self, mocker): + return self.create_user_token_client(mocker, get_token_return=TEST_DEFAULTS.RES_TOKEN) + + def create_activity(self, mocker, activity_type=ActivityTypes.message, name="a", value=None, text="a"): + # def conv_ref(): + # return mocker.MagicMock(spec=ConversationReference) + mock_conversation_ref = mocker.MagicMock(ConversationReference) + mocker.patch.object(Activity, "get_conversation_reference", return_value=mocker.MagicMock(ConversationReference)) + # mocker.patch.object(ConversationReference, "create", return_value=conv_ref()) + return Activity( + type=activity_type, + name=name, + from_property=ChannelAccount(id=TEST_DEFAULTS.USER_ID), + channel_id=TEST_DEFAULTS.CHANNEL_ID, + # get_conversation_reference=mocker.Mock(return_value=conv_ref), + relates_to=mocker.MagicMock(ConversationReference), + value=value, + text=text + ) + + @pytest.fixture(params=TEST_DEFAULTS.ALL()) + def sample_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture(params=TEST_DEFAULTS.FAILED()) + def sample_failed_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture(params=TEST_DEFAULTS.INACTIVE()) + def sample_inactive_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture(params=TEST_DEFAULTS.ACTIVE()) + def sample_active_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture + def flow(self, sample_flow_state, user_token_client): + return AuthFlow(sample_flow_state, user_token_client) + + @pytest.fixture + def connection_manager(self): + pass + + @pytest.fixture + def auth_handlers(self): + return { + "handler": AuthHandler( + name="handler", + title="Test Handler", + text="Text" + abs_oauth_connection_name="handler", + obo_connection_name="obo" ), + "connection": AuthHandler( + name="connection", + title="Test Handler", + text="Text" + abs_oauth_connection_name="connection", + obo_connection_name="obo" + ) } - auth = TestingAuthorization(auth_handlers=auth_handlers) - for id, auth_handler in auth_handlers.items(): - # test value propogation - token_res = await auth.get_token(self.turn_context, id) - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" + + @pytest.fixture + def auth(self, connection_manager, storage, auth_handlers): + return Authorization(connection_manager, storage, auth_handlers) + +class TestAuthorizationUtils: + + def create_user_token_client(self, mocker, get_token_return=None): + + user_token_client = mocker.Mock(spec=UserTokenClientBase) + user_token_client.user_token = mocker.Mock(spec=UserTokenBase) + user_token_client.user_token.get_token = mocker.AsyncMock() + user_token_client.user_token.sign_out = mocker.AsyncMock() + + return_value = TokenResponse() + if get_token_return: + return_value = TokenResponse(token=get_token_return) + user_token_client.user_token.get_token.return_value = return_value + + return user_token_client + + def create_storage(self): + return MemoryStorage(STORAGE_INIT_DATA()) + + @pytest.fixture + def storage(self): + return self.create_storage() + + @pytest.fixture + def baseline_storage(self): + return StorageBaseline(STORAGE_INIT_DATA()) + + def mock_user_token_provider + + def patch_flow(self, mocker, flow_response=None, token=None,): + mocker.patch.object(AuthFlow, "get_user_token", return_value=TokenResponse(token=token)) + mocker.patch.object(AuthFlow, "sign_out") + mocker.patch.object(AuthFlow, "begin_or_continue_flow", return_value=flow_response) + +class TestAuthorization(TestAuthorizationUtils): + + def test_init(self, mocker): + pass @pytest.mark.asyncio - async def test_exchange_token_valid_token(self): - valid_token = jwt.encode({"aud": "api://botframework.test.api"}, "") - scopes = ["scope-a"] - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth", obo=True), - }, - token=valid_token, - ) - token_res = await auth.exchange_token(self.turn_context, scopes=scopes) - assert ( - token_res.token - == f"{auth.resolver_handler().obo_connection_name}-obo-token" - ) + @pytest.mark.parametrize("auth_handler_id", ["", "handler", "missing_handler"]) + async def test_open_flow_value_error(self, auth, context, auth_handler_id): + with pytest.raises(ValueError): + async with auth.open_flow(context, auth_handler_id): + pass @pytest.mark.asyncio - async def test_exchange_token_invalid_token(self): - invalid_token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") - scopes = ["scope-a"] - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth"), - }, - token=invalid_token, + @pytest.mark.parametrize( + ", from_property_id, auth_handler_id", + [ + ("channel_id", "user_id", "expired"), + ("teams_id", "Bob", "no_retries"), + ("channel", "Alice", "begin"), + ("channel", "Alice", "continue"), + ("channel", "Alice", "expired_and_retries"), + ("channel", "Alice", "not_started"), + ] + ) + async def test_open_flow_readonly_run(self, mocker, connection_manager, channel_id, from_property_id, auth_handler_id): + # setup + storage = MemoryStorage(STORAGE_SAMPLE_DICT) + baseline = StorageBaseline(STORAGE_SAMPLE_DICT) + auth = Authorization( + storage, + connection_manager, + auth_handlers ) - token_res = await auth.exchange_token(self.turn_context, scopes=scopes) - assert token_res.token == invalid_token + context = self.build_context(mocker, channel_id, from_property_id) + storage_client = FlowStorageClient(context, storage) + key = storage_client.key(auth_handler_id) + expected_init_flow_state = storage.read(key, FlowState) + + # code + async with auth.open_flow(context, "handler", readonly=True) as flow: + actual_init_flow_state = flow.flow_state.copy() + flow.flow_state.id = "garbage" + flow.flow_state.tag = FlowStateTag.FAILURE + flow.flow_state.expires = 0 + flow.flow_state.attempts_remaining = -1 + actual_final_flow_state = await storage.read([key], FlowState)[key] + + # verify + expected_final_flow_state = baseline.read(key, FlowState) + assert actual_init_flow_state == expected_init_flow_state + assert actual_final_flow_state == expected_final_flow_state + assert await baseline.equals(storage) @pytest.mark.asyncio - async def test_get_flow_state_unavailable(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - } + @pytest.mark.parametrize( + "channel_id, from_property_id, auth_handler_id", + [ + ("channel_id", "user_id", "expired"), + ("teams_id", "Bob", "no_retries"), + ("channel", "Alice", "begin"), + ("channel", "Alice", "continue"), + ("channel", "Alice", "expired_and_retries"), + ("channel", "Alice", "not_started"), + ] + ) + async def test_open_flow_storage_run(self, mocker, connection_manager, channel_id, from_property_id, auth_handler_id): + # setup + storage = MemoryStorage(STORAGE_SAMPLE_DICT) + baseline = StorageBaseline(STORAGE_SAMPLE_DICT) + auth = Authorization( + storage, + connection_manager, + auth_handlers ) + context = self.build_context(mocker, channel_id, from_property_id) + storage_client = FlowStorageClient(context, storage) + key = storage_client.key(auth_handler_id) + expected_init_flow_state = storage.read(key, FlowState) + + # code + async with auth.open_flow(context, "handler") as flow: + actual_init_flow_state = flow.flow_state.copy() + flow.flow_state.id = "garbage" + flow.flow_state.tag = FlowStateTag.FAILURE + flow.flow_state.expires = 0 + flow.flow_state.attempts_remaining = -1 - assert auth.get_flow_state() == FlowState() + # verify + baseline.write({ + "auth/channel/Alice/continue": flow.flow_state + }) + expected_final_flow_state = baseline.read(key, FlowState) + assert await baseline.equals(storage) + assert actual_init_flow_state == expected_init_flow_state + assert flow.flow_state == expected_final_flow_state @pytest.mark.asyncio - async def test_begin_or_continue_flow_not_started(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - }, - token=None, + async def test_get_token(self, mocker, m_storage): + m_storage.read.return_value = FlowState( + id="auth_handler", + tag=FlowStateTag.ACTIVE, + expires=3600, + attempts_remaining=3 ) - mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) + expected = TokenResponse( + access_token="access_token", + refresh_token="refresh_token", + expires_in=3600 + ) + mock_flow = mocker.AsyncMock() + mock_flow.get_user_token.return_value = expected + mocker.patch.object("OAuthFlow", "get_token", return_value=expected) + mocker.patch.object("OAuthFlow", "__init__", return_value=mock_flow) + + assert await auth.get_token("auth_handler") is expected + assert mock_flow.get_user_token.called_once() - token_res = await auth.begin_or_continue_flow( - self.turn_context, - mock_turn_state, - "auth-handler", + @pytest.mark.asyncio + @pytest.mark.parametrize( + "auth, context, auth_handler_id", + [ + (lazy_fixture("auth"), lazy_fixture("context"), "missing-handler"), + (lazy_fixture("auth"), lazy_fixture("context"), ""), + (lazy_fixture("auth"), None, "handler") + ] + ) + async def test_get_token_error(self, auth, context, auth_handler_id): + with pytest.raises(ValueError): + await auth.get_token(context, auth_handler_id) + + @pytest.fixture + def valid_token_response(self): + return TokenResponse( + connection_name="connection", + token="token" ) - # Test value propogation - auth_handler = auth.resolver_handler("auth-handler") - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" - - # Test function calls - auth_handler.flow._get_flow_state.assert_called_once() - auth_handler.flow.begin_flow.assert_called_once() - mock_turn_state.save.assert_called_once_with(self.turn_context) - mock_turn_state.set_value.assert_called_once_with( - auth.SIGN_IN_STATE_KEY, - SignInState( - continuation_activity=self.turn_context.activity, - handler_id="auth-handler", - ), + + @pytest.fixture + def invalid_exchange_token(self): + token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") + return TokenResponse( + connection_name="connection" + token=token ) @pytest.mark.asyncio - async def test_begin_or_continue_flow_started(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - }, - token=None, - flow_started=True, - ) - mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) - token_res = await auth.begin_or_continue_flow( - self.turn_context, - mock_turn_state, - "auth-handler", + async def test_exchange_token( + self, + mock_user_token_client_class, + ): + + mocker.patch.object("OAuthFlow", + get_user_token=mocker.AsyncMock(return_value=TokenResponse( + access_token="access_token", + refresh_token="refresh_token", + expires_in=3600 + )) ) + mock_user_token_client_class - # Test value propogation - auth_handler = auth.resolver_handler("auth-handler") - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" + @pytest.mark.asyncio + @pytest.mark.parametrize( + "channel_id, user_id, expected_flow_state", + [ + [] + ] + ) + async def test_get_active_flow_state(self, mocker, auth, channel_id, user_id, expected_flow_state): + context = self.create_context(mocker, channel_id, user_id) + actual_flow_state = await auth.get_active_flow_state(context) + assert actual_flow_state == expected_flow_state - # Test function calls - auth_handler.flow._get_flow_state.assert_called_once() - auth_handler.flow.continue_flow.assert_called_once() - mock_turn_state.save.assert_called_once_with(self.turn_context) - mock_turn_state.delete_value.assert_called_once_with(auth.SIGN_IN_STATE_KEY) + @pytest.mark.asyncio + async def test_get_active_flow_state_missing(self, mocker, auth): + context = self.create_context(mocker, "__channel_id", "__user_id") + res = await auth.get_active_flow_state(context) + assert res is None @pytest.mark.asyncio - async def test_begin_or_continue_flow_started_sign_in_success(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - }, - token=None, - flow_started=True, - ) - mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) - auth.on_sign_in_success(AsyncMock()) + async def begin_or_continue_flow( + self, + mocker, + turn_context, + storage, + baseline_storage, + connection_manager, + auth_handlers + ): + pass - token_res = await auth.begin_or_continue_flow( - self.turn_context, - mock_turn_state, - "auth-handler", - ) + @pytest.mark.parametrize("auth_handler_id", ["handler", "connection"]) + def test_resolve_handler_specified(self, auth, auth_handlers, auth_handler_id): + assert auth.resolve_handler(auth_handler_id) == auth_handlers[auth_handler_id] - # Test value propogation - auth_handler = auth.resolver_handler("auth-handler") - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" - - # Test function calls - auth_handler.flow._get_flow_state.assert_called_once() - auth_handler.flow.continue_flow.assert_called_once() - mock_turn_state.save.assert_called_once_with(self.turn_context) - mock_turn_state.delete_value.assert_called_once_with(auth.SIGN_IN_STATE_KEY) - auth._sign_in_handler.assert_called_once_with( - self.turn_context, mock_turn_state, "auth-handler" - ) + def test_resolve_handler_error(self, auth): + with pytest.raises(ValueError): + auth.resolve_handler("missing-handler") + + def test_resolve_handler_first(self, auth, auth_handlers_list): + assert auth.resolve_handler() == auth_handlers_list[0] @pytest.mark.asyncio - async def test_begin_or_continue_flow_started_sign_in_failure(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - }, - token=None, - sign_in_failed=True, - ) - mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) - auth.on_sign_in_failure(AsyncMock()) + async def test_sign_out_individual( + self, + mock_user_token_client_class, + mock_flow_class, + turn_context, + storage, + baseline_storage, + connection_manager, + auth_handlers + ): + # setup + storage_client = FlowStorageClient(turn_context, storage) - token_res = await auth.begin_or_continue_flow( - self.turn_context, - mock_turn_state, - "auth-handler", - ) + auth = Authorization(storage, connection_manager, auth_handlers) + await auth.sign_out("handler") - # Test value propogation - auth_handler = auth.resolver_handler("auth-handler") - assert not token_res + await baseline_storage.delete([storage_client.key("handler")]) - # Test function calls - auth_handler.flow._get_flow_state.assert_called_once() - auth_handler.flow.continue_flow.assert_called_once() - mock_turn_state.save.assert_called_once_with(self.turn_context) - auth._sign_in_failed_handler.assert_called_once_with( - self.turn_context, mock_turn_state, "auth-handler" - ) + # verify storage + assert await baseline_storage.equals(storage) + + # verify flow + mock_flow_class.sign_out.assert_called_once_with("handler") + mock_user_token_client_class.user_token.sign_out.assert_called_once() + + @pytest.mark.asyncio + async def test_sign_out_all( + self, + mock_user_token_client_class, + mock_flow_class, + turn_context, + storage, + baseline_storage, + connection_manager, + auth_handlers + ): + # setup + storage_client = FlowStorageClient(turn_context, storage) + + auth = Authorization(storage, connection_manager, auth_handlers) + await auth.sign_out("handler") + + await baseline_storage.delete([storage_client.key("handler"), storage_client.key("connection")]) + + # verify storage + assert await baseline_storage.equals(storage) + + # verify flow + mock_flow_class.sign_out.assert_called_once_with("handler") + mock_flow_class.sign_out.assert_called_once_with("connection") + + + # robrandao: TODO -> handlers \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py b/libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py new file mode 100644 index 00000000..c4fb6c27 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import asyncio +import uuid +from datetime import datetime, timezone +from typing import Callable, List, Optional, Awaitable +from collections import deque + +from microsoft.agents.hosting.core.authorization import ClaimsIdentity +from microsoft.agents.activity import ( + Activity, + ActivityTypes, + ChannelAccount, + ConversationAccount, + ConversationReference, + Channels, + ResourceResponse, + RoleTypes, + InvokeResponse, +) +from microsoft.agents.hosting.core.channel_adapter import ChannelAdapter +from microsoft.agents.hosting.core.turn_context import TurnContext +from microsoft.agents.hosting.core.connector import UserTokenClient + +AgentCallbackHandler = Callable[[TurnContext], Awaitable] + + +# patch userTokenclient constructor +class MockUserTokenClient(UserTokenClient): + """A mock user token client for testing.""" + + def __init__(self, ...): + self._store = {} + self._exchange_store = {} + self._throw_on_exchange = {} + self._user_token = mocker.Mock() + self._agent_sign_in = mocker.Mock() + + def add_user_token( + self, + connection_name: str, + channel_id: str, + user_id: str, + token: str, + magic_code: str = None, + ): + """Add a token for a user that can be retrieved during testing.""" + key = self._get_key(connection_name, channel_id, user_id) + self._store[key] = (token, magic_code) + + def add_exchangeable_token( + self, + connection_name: str, + channel_id: str, + user_id: str, + exchangeable_item: str, + token: str, + ): + """Add an exchangeable token for a user that can be exchanged during testing.""" + key = self._get_exchange_key( + connection_name, channel_id, user_id, exchangeable_item + ) + self._exchange_store[key] = token + + def throw_on_exchange_request( + self, + connection_name: str, + channel_id: str, + user_id: str, + exchangeable_item: str, + ): + """Add an instruction to throw an exception during exchange requests.""" + key = self._get_exchange_key( + connection_name, channel_id, user_id, exchangeable_item + ) + self._throw_on_exchange[key] = True + + def _get_key(self, connection_name: str, channel_id: str, user_id: str) -> str: + return f"{connection_name}:{channel_id}:{user_id}" + + def _get_exchange_key( + self, + connection_name: str, + channel_id: str, + user_id: str, + exchangeable_item: str, + ) -> str: + return f"{connection_name}:{channel_id}:{user_id}:{exchangeable_item}" diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/oauth_test_env.py b/libraries/microsoft-agents-hosting-core/tests/tools/oauth_test_env.py new file mode 100644 index 00000000..db40c96d --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/tests/tools/oauth_test_env.py @@ -0,0 +1,142 @@ +from datetime import datetime +from microsoft.agents.hosting.core.app.oauth.models import FlowState, FlowStateTag + +class TEST_DEFAULTS: + + MS_APP_ID = "__ms_app_id" + CHANNEL_ID = "__channel_id" + USER_ID = "__user_id" + ABS_OAUTH_CONNECTION_NAME = "__connection_name" + RES_TOKEN = "__res_token" + + DEF_ARGS = { + "ms_app_id": MS_APP_ID, + "channel_id": CHANNEL_ID, + "user_id": USER_ID, + "abs_oauth_connection_name": ABS_OAUTH_CONNECTION_NAME + } + + STARTED_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.BEGIN, + attempts_remaining=1, + user_token="____", + expires_at=datetime.now().timestamp() + 1000000 + ) + STARTED_FLOW_ONE_RETRY = FlowState( + **DEF_ARGS, + tag=FlowStateTag.BEGIN, + attempts_remaining=2, + user_token="____", + expires_at=datetime.now().timestamp() + 1000000 + ) + ACTIVE_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.CONTINUE, + attempts_remaining=2, + user_token="__token", + expires_at=datetime.now().timestamp() + 1000000 + ) + ACTIVE_FLOW_ONE_RETRY = FlowState( + **DEF_ARGS, + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + user_token="__token", + expires_at=datetime.now().timestamp() + 1000000 + ) + ACTIVE_EXP_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.CONTINUE, + attempts_remaining=2, + user_token="__token", + expires_at=datetime.now().timestamp() + ) + COMPLETED_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.COMPLETE, + attempts_remaining=2, + user_token="test_token", + expires_at=datetime.now().timestamp() + 1000000 + ) + FAIL_BY_ATTEMPTS_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.FAILURE, + attempts_remaining=0, + expires_at=datetime.now().timestamp() + 1000000 + ) + + FAIL_BY_EXP_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.FAILURE, + attempts_remaining=2, + expires_at=0 + ) + + @classmethod + def __format(cls, lst): + return [ flow_state.model_copy() for flow_state in lst ] + + @classmethod + def ALL(cls): + return cls.__format([ + cls.STARTED_FLOW, + cls.STARTED_FLOW_ONE_RETRY, + cls.ACTIVE_FLOW, + cls.ACTIVE_FLOW_ONE_RETRY, + cls.ACTIVE_EXP_FLOW, + cls.COMPLETED_FLOW, + cls.FAIL_BY_ATTEMPTS_FLOW, + cls.FAIL_BY_EXP_FLOW + ]) + + @classmethod + def FAILED(cls): + return cls.__format([ + cls.ACTIVE_EXP_FLOW, + cls.FAIL_BY_ATTEMPTS_FLOW, + cls.FAIL_BY_EXP_FLOW + ]) + + @classmethod + def ACTIVE(cls): + return cls.__format([ + cls.STARTED_FLOW, + cls.STARTED_FLOW_ONE_RETRY, + cls.ACTIVE_FLOW, + cls.ACTIVE_FLOW_ONE_RETRY, + ]) + + @classmethod + def INACTIVE(cls): + return cls.__format([ + cls.ACTIVE_EXP_FLOW, + cls.COMPLETED_FLOW, + cls.FAIL_BY_ATTEMPTS_FLOW, + cls.FAIL_BY_EXP_FLOW + ]) + +def flow_key(channel_id, user_id, handler_id): + return f"auth/{channel_id}/{user_id}/{handler_id}" + +STORAGE_SAMPLE_DICT = { + "user_id": "123", + "session_id": "abc", + flow_key("webchat", "Alice", "graph"): TEST_DEFAULTS.COMPLETED_FLOW.model_copy(), + flow_key("webchat", "Alice", "github"): TEST_DEFAULTS.ACTIVE_FLOW_ONE_RETRY.model_copy(), + flow_key("teams", "Alice", "graph"): TEST_DEFAULTS.STARTED_FLOW.model_copy(), + flow_key("webchat", "Bob", "graph"): TEST_DEFAULTS.ACTIVE_EXP_FLOW.model_copy(), + flow_key("teams", "Bob", "slack"): TEST_DEFAULTS.STARTED_FLOW.model_copy(), + flow_key("webchat", "Chuck", "github"): TEST_DEFAULTS.FAIL_BY_ATTEMPTS_FLOW.model_copy(), +} + +def STORAGE_INIT_DATA(): + data = STORAGE_SAMPLE_DICT.copy() + for key, value in data.items(): + data[key] = value.model_copy() if isinstance(value, FlowState) else value + return data + +def update_data_with_flow_state(data, channel_id, user_id, auth_handler_id, flow_state): + data = data.copy() + key = f"auth/{channel_id}/{user_id}/{auth_handler_id}" + data[key] = flow_state.model_copy() + return data \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/testing_adapter.py b/libraries/microsoft-agents-hosting-core/tests/tools/testing_adapter.py index b3574b8f..fb0fa9b0 100644 --- a/libraries/microsoft-agents-hosting-core/tests/tools/testing_adapter.py +++ b/libraries/microsoft-agents-hosting-core/tests/tools/testing_adapter.py @@ -25,66 +25,7 @@ AgentCallbackHandler = Callable[[TurnContext], Awaitable] - -class MockUserTokenClient(UserTokenClient): - """A mock user token client for testing.""" - - def __init__(self): - self._store = {} - self._exchange_store = {} - self._throw_on_exchange = {} - - def add_user_token( - self, - connection_name: str, - channel_id: str, - user_id: str, - token: str, - magic_code: str = None, - ): - """Add a token for a user that can be retrieved during testing.""" - key = self._get_key(connection_name, channel_id, user_id) - self._store[key] = (token, magic_code) - - def add_exchangeable_token( - self, - connection_name: str, - channel_id: str, - user_id: str, - exchangeable_item: str, - token: str, - ): - """Add an exchangeable token for a user that can be exchanged during testing.""" - key = self._get_exchange_key( - connection_name, channel_id, user_id, exchangeable_item - ) - self._exchange_store[key] = token - - def throw_on_exchange_request( - self, - connection_name: str, - channel_id: str, - user_id: str, - exchangeable_item: str, - ): - """Add an instruction to throw an exception during exchange requests.""" - key = self._get_exchange_key( - connection_name, channel_id, user_id, exchangeable_item - ) - self._throw_on_exchange[key] = True - - def _get_key(self, connection_name: str, channel_id: str, user_id: str) -> str: - return f"{connection_name}:{channel_id}:{user_id}" - - def _get_exchange_key( - self, - connection_name: str, - channel_id: str, - user_id: str, - exchangeable_item: str, - ) -> str: - return f"{connection_name}:{channel_id}:{user_id}:{exchangeable_item}" - +from .mock_user_token_client import MockUserTokenClient class TestingAdapter(ChannelAdapter): """ diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py b/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py index ac90b6c6..69f04ab3 100644 --- a/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py +++ b/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py @@ -207,6 +207,7 @@ def __init__( storage=storage, auth_handlers=auth_handlers, connection_manager=connection_manager, + service_url="a" ) # Configure each auth handler with mock OAuth flow behavior From 7390df12893c723559643452a87f21b1edf8473f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Wed, 20 Aug 2025 21:30:00 -0700 Subject: [PATCH 11/32] Adding more authorization tests --- .../tests/old_test_authorization.py | 213 +++++++++ .../tests/test_authorization.py | 411 ++++++++++++------ 2 files changed, 502 insertions(+), 122 deletions(-) create mode 100644 libraries/microsoft-agents-hosting-core/tests/old_test_authorization.py diff --git a/libraries/microsoft-agents-hosting-core/tests/old_test_authorization.py b/libraries/microsoft-agents-hosting-core/tests/old_test_authorization.py new file mode 100644 index 00000000..e7ee8b69 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/tests/old_test_authorization.py @@ -0,0 +1,213 @@ +import pytest +from .tools.testing_authorization import ( + TestingAuthorization, + create_test_auth_handler, +) +from .tools.testing_utility import TestingUtility +import jwt +from unittest.mock import Mock, AsyncMock +from microsoft.agents.hosting.core import SignInState +from microsoft.agents.hosting.core.oauth_flow import FlowState + + +class TestAuthorization: + def setup_method(self): + self.turn_context = TestingUtility.create_empty_context() + + @pytest.mark.asyncio + async def test_get_token_single_handler(self): + """ + Test Authorization - get_token() with single Auth Handler + """ + auth = TestingAuthorization( + auth_handlers={ + "auth-handler": create_test_auth_handler("test-auth-a"), + } + ) + + token_res = await auth.get_token(self.turn_context) + auth_handler = auth.resolver_handler("auth-handler") + assert token_res.connection_name == auth_handler.abs_oauth_connection_name + assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" + + @pytest.mark.asyncio + async def test_get_token_multiple_handlers(self): + """ + Test Authorization - get_token() with multiple Auth Handlers + """ + auth_handlers = { + "auth-handler": create_test_auth_handler("test-auth-a"), + "auth-handler-obo": create_test_auth_handler("test-auth-b", obo=True), + "auth-handler-with-title": create_test_auth_handler( + "test-auth-c", title="test-title" + ), + "auth-handler-with-title-text": create_test_auth_handler( + "test-auth-d", title="test-title", text="test-text" + ), + } + auth = TestingAuthorization(auth_handlers=auth_handlers) + for id, auth_handler in auth_handlers.items(): + # test value propogation + token_res = await auth.get_token(self.turn_context, id) + assert token_res.connection_name == auth_handler.abs_oauth_connection_name + assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" + + @pytest.mark.asyncio + async def test_exchange_token_valid_token(self): + valid_token = jwt.encode({"aud": "api://botframework.test.api"}, "") + scopes = ["scope-a"] + auth = TestingAuthorization( + auth_handlers={ + "auth-handler": create_test_auth_handler("test-auth", obo=True), + }, + token=valid_token, + ) + token_res = await auth.exchange_token(self.turn_context, scopes=scopes) + assert ( + token_res.token + == f"{auth.resolver_handler().obo_connection_name}-obo-token" + ) + + @pytest.mark.asyncio + async def test_exchange_token_invalid_token(self): + invalid_token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") + scopes = ["scope-a"] + auth = TestingAuthorization( + auth_handlers={ + "auth-handler": create_test_auth_handler("test-auth"), + }, + token=invalid_token, + ) + token_res = await auth.exchange_token(self.turn_context, scopes=scopes) + assert token_res.token == invalid_token + + @pytest.mark.asyncio + async def test_get_flow_state_unavailable(self): + auth = TestingAuthorization( + auth_handlers={ + "auth-handler": create_test_auth_handler("test-auth-a"), + } + ) + + assert auth.get_flow_state() == FlowState() + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_not_started(self): + auth = TestingAuthorization( + auth_handlers={ + "auth-handler": create_test_auth_handler("test-auth-a"), + }, + token=None, + ) + mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) + + token_res = await auth.begin_or_continue_flow( + self.turn_context, + mock_turn_state, + "auth-handler", + ) + # Test value propogation + auth_handler = auth.resolver_handler("auth-handler") + assert token_res.connection_name == auth_handler.abs_oauth_connection_name + assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" + + # Test function calls + auth_handler.flow._get_flow_state.assert_called_once() + auth_handler.flow.begin_flow.assert_called_once() + mock_turn_state.save.assert_called_once_with(self.turn_context) + mock_turn_state.set_value.assert_called_once_with( + auth.SIGN_IN_STATE_KEY, + SignInState( + continuation_activity=self.turn_context.activity, + handler_id="auth-handler", + ), + ) + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_started(self): + auth = TestingAuthorization( + auth_handlers={ + "auth-handler": create_test_auth_handler("test-auth-a"), + }, + token=None, + flow_started=True, + ) + mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) + token_res = await auth.begin_or_continue_flow( + self.turn_context, + mock_turn_state, + "auth-handler", + ) + + # Test value propogation + auth_handler = auth.resolver_handler("auth-handler") + assert token_res.connection_name == auth_handler.abs_oauth_connection_name + assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" + + # Test function calls + auth_handler.flow._get_flow_state.assert_called_once() + auth_handler.flow.continue_flow.assert_called_once() + mock_turn_state.save.assert_called_once_with(self.turn_context) + mock_turn_state.delete_value.assert_called_once_with(auth.SIGN_IN_STATE_KEY) + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_started_sign_in_success(self): + auth = TestingAuthorization( + auth_handlers={ + "auth-handler": create_test_auth_handler("test-auth-a"), + }, + token=None, + flow_started=True, + ) + mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) + auth.on_sign_in_success(AsyncMock()) + + token_res = await auth.begin_or_continue_flow( + self.turn_context, + mock_turn_state, + "auth-handler", + ) + + # Test value propogation + auth_handler = auth.resolver_handler("auth-handler") + assert token_res.connection_name == auth_handler.abs_oauth_connection_name + assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" + + # Test function calls + auth_handler.flow._get_flow_state.assert_called_once() + auth_handler.flow.continue_flow.assert_called_once() + mock_turn_state.save.assert_called_once_with(self.turn_context) + mock_turn_state.delete_value.assert_called_once_with(auth.SIGN_IN_STATE_KEY) + auth._sign_in_handler.assert_called_once_with( + self.turn_context, mock_turn_state, "auth-handler" + ) + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_started_sign_in_failure(self): + auth = TestingAuthorization( + auth_handlers={ + "auth-handler": create_test_auth_handler("test-auth-a"), + }, + token=None, + sign_in_failed=True, + ) + mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) + auth.on_sign_in_failure(AsyncMock()) + + token_res = await auth.begin_or_continue_flow( + self.turn_context, + mock_turn_state, + "auth-handler", + ) + + # Test value propogation + auth_handler = auth.resolver_handler("auth-handler") + assert not token_res + + # Test function calls + auth_handler.flow._get_flow_state.assert_called_once() + auth_handler.flow.continue_flow.assert_called_once() + mock_turn_state.save.assert_called_once_with(self.turn_context) + auth._sign_in_failed_handler.assert_called_once_with( + self.turn_context, mock_turn_state, "auth-handler" + ) diff --git a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py index fbaa5e1b..7ed13c58 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py @@ -3,6 +3,7 @@ import pytest from microsoft.agents.activity import ( + ActivityTypes, TokenResponse ) from microsoft.agents.hosting.core import ( @@ -12,7 +13,8 @@ FlowState, FlowErrorTag, FlowStateTag, - FlowResponse + FlowResponse, + storage ) from microsoft.agents.hosting.core.storage.storage_test_utils import StorageBaseline from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase @@ -50,7 +52,7 @@ def create_context(self, turn_context.turn_state = { "__user_token_client": user_token_client } - return context + return turn_context def create_mock_user_token_client( self, @@ -138,28 +140,18 @@ def sample_active_flow_state(self, request): def flow(self, sample_flow_state, user_token_client): return AuthFlow(sample_flow_state, user_token_client) - @pytest.fixture - def connection_manager(self): - pass - @pytest.fixture def auth_handlers(self): - return { - "handler": AuthHandler( - name="handler", - title="Test Handler", - text="Text" - abs_oauth_connection_name="handler", - obo_connection_name="obo" - ), - "connection": AuthHandler( - name="connection", - title="Test Handler", - text="Text" - abs_oauth_connection_name="connection", - obo_connection_name="obo" - ) - } + handlers = {} + for key in STORAGE_INIT_DATA().keys(): + if key.startswith("auth/"): + auth_handler_name = key[key.rindex("/")+1:] + handlers[auth_handler_name] = create_test_auth_handler(auth_handler_name, True) + return handlers + + @pytest.fixture + def connection_manager(self): + return TestingConnectionManager() @pytest.fixture def auth(self, connection_manager, storage, auth_handlers): @@ -201,133 +193,308 @@ def patch_flow(self, mocker, flow_response=None, token=None,): class TestAuthorization(TestAuthorizationUtils): - def test_init(self, mocker): - pass + def test_init_configuration_variants(self,storage, connection_manager, auth_handlers): + AGENTAPPLICATION = { + "USERAUTHORIZATION": { + "HANDLERS": { + handler_name: { + "SETTINGS": { + "title": handler.title, + "text": handler.text, + "abs_oauth_connection_name": handler.abs_oauth_connection_name, + "obo_connection_name": handler.obo_connection_name + } + } for handler_name, handler in auth_handlers.items() + } + } + } + auth_with_config_obj = Authorization( + storage, + connection_manager, + auth_handlers=None, + AGENTAPPLICATION=AGENTAPPLICATION + ) + auth_with_handlers_list = Authorization( + storage, + connection_manager, + auth_handlers=auth_handlers + ) + for auth_handler_name in auth_handlers.keys(): + auth_handler_a = auth_with_config_obj.resolve_handler(auth_handler_name) + auth_handler_b = auth_with_handlers_list.resolve_handler(auth_handler_name) + assert auth_handler_a == auth_handler_b @pytest.mark.asyncio - @pytest.mark.parametrize("auth_handler_id", ["", "handler", "missing_handler"]) - async def test_open_flow_value_error(self, auth, context, auth_handler_id): + @pytest.mark.parametrize("auth_handler_id, channel_id, user_id", + [ + ["", "webchat", "Alice"] + ["handler", "teams", "Bob"] + ]) + async def test_open_flow_value_error( + self, + mocker, + auth, + auth_handler_id, + channel_id, + user_id + ): + context = self.create_context(mocker, channel_id, user_id) with pytest.raises(ValueError): async with auth.open_flow(context, auth_handler_id): pass @pytest.mark.asyncio - @pytest.mark.parametrize( - ", from_property_id, auth_handler_id", + @pytest.mark.parametrize("auth_handler_id, channel_id, user_id", [ - ("channel_id", "user_id", "expired"), - ("teams_id", "Bob", "no_retries"), - ("channel", "Alice", "begin"), - ("channel", "Alice", "continue"), - ("channel", "Alice", "expired_and_retries"), - ("channel", "Alice", "not_started"), - ] - ) - async def test_open_flow_readonly_run(self, mocker, connection_manager, channel_id, from_property_id, auth_handler_id): + ["", "webchat", "Alice"], + ["handler", "teams", "Bob"] + ]) + async def test_open_flow_readonly( + self, + storage, + connection_client, + auth_handlers, + auth_handler_id, + channel_id, + user_id + ): # setup - storage = MemoryStorage(STORAGE_SAMPLE_DICT) - baseline = StorageBaseline(STORAGE_SAMPLE_DICT) - auth = Authorization( - storage, - connection_manager, - auth_handlers - ) - context = self.build_context(mocker, channel_id, from_property_id) - storage_client = FlowStorageClient(context, storage) - key = storage_client.key(auth_handler_id) - expected_init_flow_state = storage.read(key, FlowState) - - # code - async with auth.open_flow(context, "handler", readonly=True) as flow: - actual_init_flow_state = flow.flow_state.copy() - flow.flow_state.id = "garbage" - flow.flow_state.tag = FlowStateTag.FAILURE - flow.flow_state.expires = 0 - flow.flow_state.attempts_remaining = -1 - actual_final_flow_state = await storage.read([key], FlowState)[key] + context = self.create_context(mocker, channel_id, user_id) + auth = Authorization(storage, connection_client, auth_handlers) + flow_storage_client = FlowStorageClient(context, storage) + + # test + async with auth.open_flow(context, auth_handler_id) as flow: + expected_flow_state = flow.flow_state # verify - expected_final_flow_state = baseline.read(key, FlowState) - assert actual_init_flow_state == expected_init_flow_state - assert actual_final_flow_state == expected_final_flow_state - assert await baseline.equals(storage) + actual_flow_state = await flow_storage_client.read(auth_handler_id) + assert actual_flow_state == expected_flow_state @pytest.mark.asyncio - @pytest.mark.parametrize( - "channel_id, from_property_id, auth_handler_id", - [ - ("channel_id", "user_id", "expired"), - ("teams_id", "Bob", "no_retries"), - ("channel", "Alice", "begin"), - ("channel", "Alice", "continue"), - ("channel", "Alice", "expired_and_retries"), - ("channel", "Alice", "not_started"), - ] - ) - async def test_open_flow_storage_run(self, mocker, connection_manager, channel_id, from_property_id, auth_handler_id): + async def test_open_flow_not_in_storage( + self, + mocker, + storage, + connection_manager, + auth_handlers + ): # setup - storage = MemoryStorage(STORAGE_SAMPLE_DICT) - baseline = StorageBaseline(STORAGE_SAMPLE_DICT) - auth = Authorization( - storage, - connection_manager, - auth_handlers + context = self.create_context(mocker, "__channel_id", "__user_id") + auth = Authorization(storage, connection_manager, auth_handlers) + flow_storage_client = FlowStorageClient(context, storage) + + # test + async with auth.open_flow(context, "__auth_handler_id") as flow: + assert flow is not None + assert isinstance(flow, AuthFlow) + flow_state = await flow_storage_client.read("__auth_handler_id") + + # verify + assert flow_state.channel_id == "__channel_id" + assert flow_state.user_id == "__user_id" + assert flow_state.auth_handler_id == "__auth_handler_id" + assert flow_state.tag == FlowStateTag.NOT_STARTED + + @pytest.mark.asyncio + async def test_open_flow_success_modified_complete_flow( + self, + mocker, + storage, + connection_client, + auth_handlers, + auth_handler_id, + channel_id, + user_id + ): + # setup + channel_id = "teams" + user_id = "Alice" + auth_handler_id = "graph" + + self.create_user_token_client( + mocker, + get_token_return=TokenResponse(token=TEST_DEFAULTS.RES_TOKEN) ) - context = self.build_context(mocker, channel_id, from_property_id) - storage_client = FlowStorageClient(context, storage) - key = storage_client.key(auth_handler_id) - expected_init_flow_state = storage.read(key, FlowState) - - # code - async with auth.open_flow(context, "handler") as flow: - actual_init_flow_state = flow.flow_state.copy() - flow.flow_state.id = "garbage" - flow.flow_state.tag = FlowStateTag.FAILURE - flow.flow_state.expires = 0 - flow.flow_state.attempts_remaining = -1 + + context = self.create_context(mocker, channel_id, user_id) + context.activity.type = ActivityTypes.message + context.activity.text = "123456" + + auth = Authorization(storage, connection_client, auth_handlers) + flow_storage_client = FlowStorageClient(context, storage) + + # test + async with auth.open_flow(context, auth_handler_id) as flow: + expected_flow_state = flow.flow_state + expected_flow_state.tag = FlowStateTag.COMPLETE + expected_flow_state.user_token = TEST_DEFAULTS.RES_TOKEN + + flow_response = await flow.begin_or_continue_flow(context.activity) + res_flow_state = flow_response.flow_state + + # verify + actual_flow_state = await flow_storage_client.read(auth_handler_id) + expected_flow_state.expires_at = actual_flow_state.expires_at # we won't check this for now + + assert res_flow_state == expected_flow_state + assert actual_flow_state == expected_flow_state + + @pytest.mark.asyncio + async def test_open_flow_success_modified_failure( + self, + mocker, + baseline_storage, + storage, + connection_client, + auth_handlers, + auth_handler_id, + channel_id, + user_id + ): + # setup + channel_id = "webchat" + user_id = "Bob" + auth_handler_id = "graph" + + context = self.create_context(mocker, channel_id, user_id) + + auth = Authorization(storage, connection_client, auth_handlers) + flow_storage_client = FlowStorageClient(context, storage) + + # test + async with auth.open_flow(context, auth_handler_id) as flow: + expected_flow_state = flow.flow_state + expected_flow_state.tag = FlowStateTag.FAILURE + expected_flow_state.attempts_remaining = 0 + + flow_response = await flow.begin_or_continue_flow(context.activity) + res_flow_state = flow_response.flow_state + + # verify + actual_flow_state = await flow_storage_client.read(auth_handler_id) + expected_flow_state.expires_at = actual_flow_state.expires_at # we won't check this for now + + assert flow_response.flow_error_tag == FlowErrorTag.MAGIC_FORMAT + assert res_flow_state == expected_flow_state + assert actual_flow_state == expected_flow_state + + baseline_storage.write(res_flow_state.model_copy()) + assert await baseline_storage.equals(storage) + + @pytest.mark.asyncio + async def test_open_flow_success_modified_signout( + self, + mocker, + storage, + connection_client, + auth_handlers, + auth_handler_id, + channel_id, + user_id + ): + # setup + channel_id = "webchat" + user_id = "Alice" + auth_handler_id = "graph" + + context = self.create_context(mocker, channel_id, user_id) + + auth = Authorization(storage, connection_client, auth_handlers) + flow_storage_client = FlowStorageClient(context, storage) + + # test + async with auth.open_flow(context, auth_handler_id) as flow: + expected_flow_state = flow.flow_state + expected_flow_state.tag = FlowStateTag.FAILURE + expected_flow_state.user_token = "" + + flow_response = await flow.sign_out() + res_flow_state = flow_response.flow_state # verify - baseline.write({ - "auth/channel/Alice/continue": flow.flow_state - }) - expected_final_flow_state = baseline.read(key, FlowState) - assert await baseline.equals(storage) - assert actual_init_flow_state == expected_init_flow_state - assert flow.flow_state == expected_final_flow_state + actual_flow_state = await flow_storage_client.read(auth_handler_id) + expected_flow_state.expires_at = actual_flow_state.expires_at # we won't check this for now + + assert flow_response.flow_error_tag == FlowErrorTag.MAGIC_FORMAT + assert res_flow_state == expected_flow_state + assert actual_flow_state == expected_flow_state @pytest.mark.asyncio - async def test_get_token(self, mocker, m_storage): - m_storage.read.return_value = FlowState( - id="auth_handler", - tag=FlowStateTag.ACTIVE, - expires=3600, - attempts_remaining=3 + async def test_get_token_success( + self, + mocker, + auth + ): + mock_user_token_client_class = self.create_user_token_client( + mocker, + get_token_return=TokenResponse(token="token") ) - expected = TokenResponse( - access_token="access_token", - refresh_token="refresh_token", - expires_in=3600 + context = self.create_context(mocker, "__channel_id", "__user_id") + assert await auth.get_token(context, "auth_handler") == TokenResponse(token="token") + mock_user_token_client_class.get_user_token.called_once() + + @pytest.mark.asyncio + async def test_get_token_empty_response( + self, + mocker, + auth + ): + mock_user_token_client_class = self.create_user_token_client( + mocker, + get_token_return=TokenResponse() ) - mock_flow = mocker.AsyncMock() - mock_flow.get_user_token.return_value = expected - mocker.patch.object("OAuthFlow", "get_token", return_value=expected) - mocker.patch.object("OAuthFlow", "__init__", return_value=mock_flow) + context = self.create_context(mocker, "__channel_id", "__user_id") + assert await auth.get_token(context, "auth_handler") == TokenResponse() + mock_user_token_client_class.get_user_token.called_once() - assert await auth.get_token("auth_handler") is expected - assert mock_flow.get_user_token.called_once() + @pytest.mark.asyncio + async def test_get_token_error( + self, + turn_context, + storage, + connection_manager, + auth_handlers + ): + auth = Authorization(storage, connection_manager, auth_handlers) + with pytest.raises(ValueError): + await auth.get_token(turn_context, "missing-handler") + + @pytest.mark.asyncio + async def test_exchange_token_no_token( + self, + turn_context, + mock_auth_flow_class, + mocker, + auth + ): + mock_auth_flow_class.get_user_token = mocker.AsyncMock( + return_value=TokenResponse() + ) + res = await auth.exchange_token(turn_context, ["scope"], "github") + assert res == TokenResponse() @pytest.mark.asyncio @pytest.mark.parametrize( - "auth, context, auth_handler_id", + "token", [ - (lazy_fixture("auth"), lazy_fixture("context"), "missing-handler"), - (lazy_fixture("auth"), lazy_fixture("context"), ""), - (lazy_fixture("auth"), None, "handler") - ] + "token", + "" + ] # robrandao: TODOTODO ) - async def test_get_token_error(self, auth, context, auth_handler_id): - with pytest.raises(ValueError): - await auth.get_token(context, auth_handler_id) + async def test_exchange_token_not_exchangeable( + self, + mock_auth_flow_class, + turn_context, + mocker, + auth, + token + ): + mock_auth_flow_class.get_user_token = mocker.AsyncMock( + return_value=TokenResponse(token=token) + ) + res = await auth.exchange_token(turn_context, ["scope"], "github") + assert res == TokenResponse() @pytest.fixture def valid_token_response(self): From 7a48c90f31e711522448bb592fff75585ee862da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Wed, 20 Aug 2025 21:46:13 -0700 Subject: [PATCH 12/32] Adding documentation --- .../hosting/core/app/oauth/auth_flow.py | 54 +++++++++++++++--- .../hosting/core/app/oauth/authorization.py | 57 ++++++++++++++----- .../core/app/oauth/flow_storage_client.py | 2 +- 3 files changed, 88 insertions(+), 25 deletions(-) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py index 3115b384..ba16fbc0 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py @@ -46,11 +46,13 @@ def __init__( ): """ Arguments: - abs_oauth_connection_name: + flow_state: The state of the flow. + user_token_client: The user token client to use for token operations. - user_token_client: - - flow_state: + Keyword Arguments: + flow_duration: The duration of the flow in milliseconds (default: 60000). + max_attempts: The maximum number of attempts for the flow + set when starting a flow (default: 3). """ raise_if_empty_or_None( self.__init__.__name__, @@ -94,10 +96,8 @@ async def get_user_token(self, magic_code: str = None) -> TokenResponse: TokenResponse The user token response. - Notes - ----- - flow_state.user_token is updated with the latest token. - + Notes: + flow_state.user_token is updated with the latest token. """ token_response: TokenResponse = await self.__user_token_client.user_token.get_token( user_id=self.__user_id, @@ -110,7 +110,11 @@ async def get_user_token(self, magic_code: str = None) -> TokenResponse: return token_response async def sign_out(self) -> None: - """Sign out the user.""" + """Sign out the user. + + Sets the flow state tag to NOT_STARTED + Resets the flow state user_token field + """ await self.__user_token_client.user_token.sign_out( user_id=self.__user_id, connection_name=self.__abs_oauth_connection_name, @@ -120,11 +124,23 @@ async def sign_out(self) -> None: self.__flow_state.tag = FlowStateTag.NOT_STARTED def __use_attempt(self) -> None: + """Decrements the remaining attempts for the flow, checking for failure.""" self.__flow_state.attempts_remaining -= 1 if self.__flow_state.attempts_remaining <= 0: self.__flow_state.tag = FlowStateTag.FAILURE async def begin_flow(self, activity: Activity) -> FlowResponse: + """Begins the OAuthFlow. + + Args: + activity: The activity that initiated the flow. + + Returns: + The response containing the flow state and sign-in resource if applicable. + + Notes: + The flow state is reset if a token is not obtained from cache. + """ # init flow state @@ -153,6 +169,7 @@ async def begin_flow(self, activity: Activity) -> FlowResponse: return FlowResponse(flow_state=self.__flow_state, sign_in_resource=sign_in_resource) async def __continue_from_message(self, activity: Activity) -> tuple[TokenResponse, FlowErrorTag]: + """Handles the continuation of the flow from a message activity.""" magic_code: str = activity.text if magic_code and magic_code.isdigit() and len(magic_code) == 6: token_response: TokenResponse = await self.get_user_token(magic_code) @@ -165,12 +182,14 @@ async def __continue_from_message(self, activity: Activity) -> tuple[TokenRespon return TokenResponse(), FlowErrorTag.MAGIC_FORMAT async def __continue_from_invoke_verify_state(self, activity: Activity) -> TokenResponse: + """Handles the continuation of the flow from an invoke activity for verifying state.""" token_verify_state = activity.value magic_code: str = token_verify_state.get("state") token_response: TokenResponse = await self.get_user_token(magic_code) return token_response async def __continue_from_invoke_token_exchange(self, activity: Activity) -> TokenResponse: + """Handles the continuation of the flow from an invoke activity for token exchange.""" token_exchange_request = activity.value token_response = await self.__user_token_client.user_token.exchange_token( user_id=self.__user_id, @@ -181,6 +200,15 @@ async def __continue_from_invoke_token_exchange(self, activity: Activity) -> Tok return token_response async def continue_flow(self, activity: Activity) -> FlowResponse: + """Continues the OAuth flow based on the incoming activity. + + Args: + activity: The incoming activity to continue the flow with. + + Returns: + A FlowResponse object containing the updated flow state and any token response. + + """ logger.debug("Continuing auth flow...") if not self.__flow_state.is_active(): @@ -218,6 +246,14 @@ async def continue_flow(self, activity: Activity) -> FlowResponse: ) async def begin_or_continue_flow(self, activity: Activity) -> FlowResponse: + """Begins a new OAuth flow or continues an existing one based on the activity. + + Args: + activity: The incoming activity to begin or continue the flow with. + + Returns: + A FlowResponse object containing the updated flow state and any token response. + """ if self.__flow_state.is_active(): return await self.continue_flow(activity) else: diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index 0612bdb1..1169e0e5 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -109,6 +109,7 @@ def __init__( # ) def __check_for_ids(self, context: TurnContext): + """Checks for IDs necessary to load a new or existing flow.""" if ( not context.activity.channel_id or not context.activity.from_property or @@ -117,6 +118,20 @@ def __check_for_ids(self, context: TurnContext): raise ValueError("Channel ID and User ID are required") async def __load_flow(self, context: TurnContext, auth_handler_id: str) -> tuple[AuthFlow, FlowStorageClient, FlowState]: + """Loads the OAuth flow for a specific auth handler. + + Args: + context: The context object for the current turn. + auth_handler_id: The ID of the auth handler to use. + + Returns: + The AuthFlow returned corresponds to the flow associated with the + chosen handler, and the channel and user info found in the context. + + The FlowStorageClient corresponds to the channel and user info. + The FlowState returned is the flow state for the given channel/user/handler + triple at the time of creating the flow. + """ user_token_client: UserTokenClient = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) # robrandao: TODO auth_handler: AuthHandler = self.resolve_handler(auth_handler_id) @@ -139,16 +154,17 @@ async def __load_flow(self, context: TurnContext, auth_handler_id: str) -> tuple return flow, flow_storage_client, flow_state @asynccontextmanager - async def open_flow(self, context: TurnContext, auth_handler_id: str = "", readonly: bool = False) -> FlowResponse: - """ - Starts the OAuth flow for a specific auth handler. + async def open_flow(self, context: TurnContext, auth_handler_id: str) -> AuthFlow: + """Loads an Auth flow and saves changes the changes to storage if any are made. Args: context: The context object for the current turn. - auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. + auth_handler_id: ID of the auth handler to use. - Returns: - The flow response from the OAuth provider. + Yields: + AuthFlow: + The AuthFlow instance loaded from storage or newly created + if not yet present in storage. """ if not context or not auth_handler_id: raise ValueError("context and auth_handler_id are required") @@ -156,17 +172,18 @@ async def open_flow(self, context: TurnContext, auth_handler_id: str = "", reado flow, flow_storage_client, init_flow_state = self.__load_flow(context, auth_handler_id) yield flow - if not readonly and flow.flow_state != init_flow_state: - flow_storage_client.write(flow.flow_state) + new_flow_state = flow.flow_state + if new_flow_state != init_flow_state: + flow_storage_client.write(new_flow_state) async def get_token( - self, context: TurnContext, auth_handler_id: Optional[str] = None + self, context: TurnContext, auth_handler_id: str ) -> TokenResponse: """ Gets the token for a specific auth handler. Args: - context: The context object for the current turn. + context: The context object for the current turn. auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. Returns: @@ -271,7 +288,8 @@ async def __handle_obo( scopes=scopes, # Expiration can be set based on the token provider's response ) - async def get_active_flow_state(self, context: TurnContext, turn_state: TurnState = None) -> Optional[FlowState]: + async def get_active_flow_state(self, context: TurnContext) -> Optional[FlowState]: + """Gets the first active flow state for the current context.""" flow_storage_client = FlowStorageClient(context, self.__storage) for auth_handler_id in self.__auth_handlers.keys(): flow_state = await flow_storage_client.read(auth_handler_id) @@ -333,8 +351,16 @@ def resolve_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: async def __sign_out( self, context: TurnContext, - auth_handler_ids: Iterable[str] = None, + auth_handler_ids: Iterable[str], ) -> None: + """Signs out from the specified auth handlers. + + Args: + context: The context object for the current turn. + auth_handler_ids: List of auth handler IDs to sign out from. + + Deletes the associated flow states from storage. + """ for auth_handler_id in auth_handler_ids: flow, flow_storage_client, initial_flow_state = self.__load_flow(context, auth_handler_id) if initial_flow_state: @@ -345,7 +371,6 @@ async def __sign_out( async def sign_out( self, context: TurnContext, - _state: TurnState, auth_handler_id: Optional[str] = None, ) -> None: """ @@ -354,8 +379,10 @@ async def sign_out( Args: context: The context object for the current turn. - state: The state object for the current turn. - auth_handler_id: Optional ID of the auth handler to use for sign out. + auth_handler_id: Optional ID of the auth handler to use for sign out. If None, + signs out from all the handlers. + + Deletes the associated flow state(s) from storage. """ if auth_handler_id: self.__sign_out(context, [auth_handler_id]) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py index d61e4227..52e49156 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py @@ -41,9 +41,9 @@ def __init__( @property def base_key(self) -> str: + """Returns the prefix used for flow state storage isolation.""" return self.__base_key - @staticmethod def key(self, flow_id: str) -> str: """Creates a storage key for a specific sign-in handler.""" return f"{self.__base_key}{flow_id}" From 338aed9a5aaa49485c908fab412da493c97326e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Thu, 21 Aug 2025 05:14:41 -0700 Subject: [PATCH 13/32] Added more Authorizationt tests and fixed inconsistencies in storage client --- .../hosting/core/app/oauth/authorization.py | 67 +++---- .../agents/hosting/core/app/oauth/utils.py | 2 - .../tests/flow_storage_client_test.py | 170 ------------------ .../tests/test_authorization.py | 147 +++++---------- .../tests/test_flow_storage_client.py | 150 ++++++++++++++++ 5 files changed, 229 insertions(+), 307 deletions(-) delete mode 100644 libraries/microsoft-agents-hosting-core/tests/flow_storage_client_test.py create mode 100644 libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index 1169e0e5..b14a3e00 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -4,7 +4,7 @@ from __future__ import annotations import logging import jwt -from typing import Dict, Optional, Callable, Awaitable +from typing import Dict, Optional, Callable, Awaitable, AsyncIterator from collections.abc import Iterable from contextlib import asynccontextmanager @@ -97,7 +97,7 @@ def __init__( # # Create OAuth flow with configuration # messages_config = {} # if auth_handler.title: - # messages_config["card_title"] = auth_handler.title + # ["card_title"] = auth_handler.title # if auth_handler.text: # messages_config["button_text"] = auth_handler.text @@ -108,16 +108,25 @@ def __init__( # messages_configuration=messages_config if messages_config else None, # ) - def __check_for_ids(self, context: TurnContext): - """Checks for IDs necessary to load a new or existing flow.""" + def __ids_from_context(self, context: TurnContext) -> tuple[str, str]: + """Checks and returns IDs necessary to load a new or existing flow. + + Raises a ValueError if channel ID or user ID are missing. + """ if ( not context.activity.channel_id or not context.activity.from_property or not context.activity.from_property.id ): raise ValueError("Channel ID and User ID are required") + + return context.activity.channel_id, context.activity.from_property.id - async def __load_flow(self, context: TurnContext, auth_handler_id: str) -> tuple[AuthFlow, FlowStorageClient, FlowState]: + async def __load_flow( + self, + context: TurnContext, + auth_handler_id: str = "" + ) -> tuple[AuthFlow, FlowStorageClient, FlowState]: """Loads the OAuth flow for a specific auth handler. Args: @@ -133,12 +142,14 @@ async def __load_flow(self, context: TurnContext, auth_handler_id: str) -> tuple triple at the time of creating the flow. """ user_token_client: UserTokenClient = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) # robrandao: TODO - auth_handler: AuthHandler = self.resolve_handler(auth_handler_id) - self.__check_for_ids(context) - channel_id = context.activity.channel_id - user_id = context.activity.from_property.id + # resolve handler id + auth_handler: AuthHandler = self.resolve_handler(auth_handler_id) + auth_handler_id = auth_handler.id + + channel_id, user_id = self.__ids_from_context(context) + # try to load existing state flow_storage_client = FlowStorageClient(channel_id, user_id, self.__storage) flow_state: FlowState = await flow_storage_client.read(auth_handler_id) @@ -154,7 +165,7 @@ async def __load_flow(self, context: TurnContext, auth_handler_id: str) -> tuple return flow, flow_storage_client, flow_state @asynccontextmanager - async def open_flow(self, context: TurnContext, auth_handler_id: str) -> AuthFlow: + async def open_flow(self, context: TurnContext, auth_handler_id: str = "") -> AsyncIterator[AuthFlow]: """Loads an Auth flow and saves changes the changes to storage if any are made. Args: @@ -166,12 +177,13 @@ async def open_flow(self, context: TurnContext, auth_handler_id: str) -> AuthFlo The AuthFlow instance loaded from storage or newly created if not yet present in storage. """ - if not context or not auth_handler_id: - raise ValueError("context and auth_handler_id are required") - + if not context: + raise ValueError("context is required") + flow, flow_storage_client, init_flow_state = self.__load_flow(context, auth_handler_id) yield flow + # persist state new_flow_state = flow.flow_state if new_flow_state != init_flow_state: flow_storage_client.write(new_flow_state) @@ -229,7 +241,7 @@ async def exchange_token( # return token_response - def __is_exchangeable(self, token: Optional[str]) -> bool: + def __is_exchangeable(self, token: str) -> bool: """ Checks if a token is exchangeable (has api:// audience). @@ -239,9 +251,6 @@ def __is_exchangeable(self, token: Optional[str]) -> bool: Returns: True if the token is exchangeable, False otherwise. """ - if not token: - return False - try: # Decode without verification to check the audience payload = jwt.decode(token, options={"verify_signature": False}) @@ -264,33 +273,27 @@ async def __handle_obo( Returns: The new token response. + """ - if not self.__connection_manager: - logger.error("Connection manager is not configured", stack_info=True) - raise ValueError("Connection manager is not configured") - - auth_handler = self.resolver_handler(handler_id) - if auth_handler.flow is None: - logger.error("OAuth flow is not configured for the auth handler") - raise ValueError("OAuth flow is not configured for the auth handler") - - # Use the flow's OBO method to exchange the token - token_provider: AccessTokenProviderBase = ( - self.__connection_manager.get_connection(auth_handler.obo_connection_name) + auth_handler = self.resolve_handler(handler_id) + token_provider: AccessTokenProviderBase = self.__connection_manager.get_connection( + auth_handler.obo_connection_name ) + logger.info("Attempting to exchange token on behalf of user") - token = await token_provider.aquire_token_on_behalf_of( + new_token = await token_provider.aquire_token_on_behalf_of( scopes=scopes, user_assertion=token, ) return TokenResponse( - token=token, + token=new_token, scopes=scopes, # Expiration can be set based on the token provider's response ) async def get_active_flow_state(self, context: TurnContext) -> Optional[FlowState]: """Gets the first active flow state for the current context.""" - flow_storage_client = FlowStorageClient(context, self.__storage) + channel_id, user_id = self.__ids_from_context(context) + flow_storage_client = FlowStorageClient(channel_id, user_id, self.__storage) for auth_handler_id in self.__auth_handlers.keys(): flow_state = await flow_storage_client.read(auth_handler_id) if flow_state.is_active(): diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/utils.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/utils.py index 40fb8762..ad4af7ba 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/utils.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/utils.py @@ -1,5 +1,3 @@ -import inspect - def raise_if_empty_or_None(func_name, err=ValueError, **kwargs): s = "" for key, value in kwargs.items(): diff --git a/libraries/microsoft-agents-hosting-core/tests/flow_storage_client_test.py b/libraries/microsoft-agents-hosting-core/tests/flow_storage_client_test.py deleted file mode 100644 index 0c2a593b..00000000 --- a/libraries/microsoft-agents-hosting-core/tests/flow_storage_client_test.py +++ /dev/null @@ -1,170 +0,0 @@ -import pytest -from unittest.mock import sentinel - -from microsoft.agents.hosting.core.storage import MemoryStorage -from microsoft.agents.hosting.core.storage.storage_test_utils import MockStoreItem -from microsoft.agents.hosting.core.app.oauth import ( - FlowState, - FlowStorageClient, -) - -class TestFlowStorageClient: - - @pytest.fixture - def turn_context(self, mocker): - context = mocker.Mock() - context.activity.channel_id = "__channel_id" - context.activity.from_property.id = "__user_id" - return context - - @pytest.fixture - def storage(self): - return MemoryStorage() - - @pytest.fixture - def client(self, turn_context, storage): - return FlowStorageClient(turn_context, storage) - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "channel_id, from_property_id", - [ - ("channel_id", "from_property_id"), - ("teams_id", "Bob"), - ("channel", "Alice"), - ], - ) - async def test_init_base_key(self, mocker, channel_id, from_property_id): - context = mocker.Mock() - context.activity.channel_id = channel_id - context.activity.from_property.id = from_property_id - client = FlowStorageClient(context, mocker.Mock()) - assert client.base_key == f"auth/{channel_id}/{from_property_id}/" - - @pytest.mark.asyncio - async def test_init_fails_without_from_id(self, mocker, storage): - with pytest.raises(ValueError): - context = mocker.Mock() - context.activity.channel_id = "channel_id" - context.activity.from_property = mocker.Mock(id=None) - FlowStorageClient(context, storage) - - @pytest.mark.asyncio - async def test_init_fails_without_channel_id(self, mocker, storage): - with pytest.raises(ValueError): - context = mocker.Mock() - context.activity.channel_id = None - context.activity.from_property.id = "from_id" - FlowStorageClient(context, storage) - - @pytest.mark.parametrize( - "auth_handler_id, expected", - [ - ("handler", "auth/__channel_id/__user_id/handler"), - ("auth_handler", "auth/__channel_id/__user_id/auth_handler"), - ] - ) - def test_key(self, client, auth_handler_id, expected): - assert client.key(auth_handler_id) == expected - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "auth_handler_id", - [ - ("handler",), - ("auth_handler",), - ] - ) - async def test_read(self, mocker, turn_context, auth_handler_id): - storage = mocker.AsyncMock() - key = f"auth/__channel_id/__user_id/{auth_handler_id}" - storage.read.return_value = {key: FlowState()} - client = FlowStorageClient(turn_context, storage) - res = await client.read(auth_handler_id) - assert res is storage.read.return_value[key] - storage.read.assert_called_once_with([f"auth/__channel_id/__user_id/{auth_handler_id}"], FlowState) - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "auth_handler_id, key", - [ - ("handler", "auth/__channel_id/__user_id/handler"), - ("auth_handler", "auth/__channel_id/__user_id/auth_handler"), - ] - ) - async def test_write(self, mocker, turn_context, auth_handler_id, key): - storage = mocker.AsyncMock() - storage.write.return_value = None - client = FlowStorageClient(turn_context, storage) - flow_state = mocker.Mock(spec=FlowState) - flow_state.flow_id = auth_handler_id - await client.write(flow_state) - storage.write.assert_called_once_with({ key: flow_state }) - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "auth_handler_id, key", - [ - ("handler", "auth/__channel_id/__user_id/handler"), - ("auth_handler", "auth/__channel_id/__user_id/auth_handler"), - ] - ) - async def test_delete(self, mocker, turn_context, auth_handler_id, key): - storage = mocker.AsyncMock() - storage.delete.return_value = None - client = FlowStorageClient(turn_context, storage) - await client.delete(auth_handler_id) - storage.delete.assert_called_once_with([client.key(auth_handler_id)]) - - @pytest.mark.asyncio - async def test_integration_with_memory_storage(self, turn_context): - - flow_state_alpha = FlowState(flow_id="handler", flow_started=True) - flow_state_beta = FlowState(flow_id="auth_handler", flow_started=True, user_token="token") - - storage = MemoryStorage({ - "some_data": MockStoreItem({"value": "test"}), - "auth/__channel_id/__user_id/handler": flow_state_alpha, - "auth/__channel_id/__user_id/auth_handler": flow_state_beta, - }) - baseline = MemoryStorage({ - "some_data": MockStoreItem({"value": "test"}), - "auth/__channel_id/__user_id/handler": flow_state_alpha, - "auth/__channel_id/__user_id/auth_handler": flow_state_beta, - }) - - # helpers - async def read_check(*args, **kwargs): - res_storage = await storage.read(*args, **kwargs) - res_baseline = await baseline.read(*args, **kwargs) - assert res_storage == res_baseline - - async def write_both(*args, **kwargs): - await storage.write(*args, **kwargs) - await baseline.write(*args, **kwargs) - - async def delete_both(*args, **kwargs): - await storage.delete(*args, **kwargs) - await baseline.delete(*args, **kwargs) - - client = FlowStorageClient(turn_context, storage) - - new_flow_state_alpha = FlowState(flow_id="handler") - flow_state_chi = FlowState(flow_id="chi") - - await client.write(new_flow_state_alpha) - await client.write(flow_state_chi) - await baseline.write({"auth/__channel_id/__user_id/handler": new_flow_state_alpha.model_copy()}) - await baseline.write({"auth/__channel_id/__user_id/chi": flow_state_chi.model_copy()}) - - await write_both({"auth/__channel_id/__user_id/handler": new_flow_state_alpha.model_copy()}) - await write_both({"auth/__channel_id/__user_id/auth_handler": flow_state_beta.model_copy()}) - await write_both({"other_data": MockStoreItem({"value": "more"})}) - - await delete_both(["some_data"]) - - await read_check(["auth/__channel_id/__user_id/handler"], target_cls=FlowState) - await read_check(["auth/__channel_id/__user_id/auth_handler"], target_cls=FlowState) - await read_check(["auth/__channel_id/__user_id/chi"], target_cls=FlowState) - await read_check(["other_data"], target_cls=MockStoreItem) - await read_check(["some_data"], target_cls=MockStoreItem) diff --git a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py index 7ed13c58..0291a413 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py @@ -1,4 +1,4 @@ -import datetime +import jwt import pytest @@ -6,34 +6,30 @@ ActivityTypes, TokenResponse ) -from microsoft.agents.hosting.core import ( +from microsoft.agents.hosting.core.app.oauth import ( Authorization, - MemoryStorage, FlowStorageClient, - FlowState, FlowErrorTag, FlowStateTag, - FlowResponse, - storage ) +from microsoft.agents.hosting.core import MemoryStorage from microsoft.agents.hosting.core.storage.storage_test_utils import StorageBaseline from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase from microsoft.agents.hosting.core.connector.user_token_client_base import UserTokenClientBase from microsoft.agents.hosting.core.app.oauth.auth_flow import AuthFlow -from tools.oauth_test_utils import ( +from .tools.oauth_test_env import ( TEST_DEFAULTS, STORAGE_INIT_DATA ) -from tools.testing_authorization import ( - TestingTokenProvider, +from .tools.testing_authorization import ( TestingConnectionManager, create_test_auth_handler ) -class TestAuthFlowUtils: +class TestUtils: def create_context(self, mocker, @@ -69,10 +65,10 @@ def create_mock_user_token_client( @pytest.fixture def mock_user_token_client_class(self, mocker): - return self.create_mock_user_token_client_class(mocker) + return self.create_mock_user_token_client(mocker) @pytest.fixture - def mock_flow_class(self, mocker): + def mock_auth_flow_class(self, mocker): mock_flow_class = mocker.Mock(spec=AuthFlow) mocker.patch.object(AuthFlow, "__init__", return_value=mock_flow_class) @@ -102,43 +98,6 @@ def create_user_token_client(self, mocker, get_token_return=None): @pytest.fixture def user_token_client(self, mocker): return self.create_user_token_client(mocker, get_token_return=TEST_DEFAULTS.RES_TOKEN) - - def create_activity(self, mocker, activity_type=ActivityTypes.message, name="a", value=None, text="a"): - # def conv_ref(): - # return mocker.MagicMock(spec=ConversationReference) - mock_conversation_ref = mocker.MagicMock(ConversationReference) - mocker.patch.object(Activity, "get_conversation_reference", return_value=mocker.MagicMock(ConversationReference)) - # mocker.patch.object(ConversationReference, "create", return_value=conv_ref()) - return Activity( - type=activity_type, - name=name, - from_property=ChannelAccount(id=TEST_DEFAULTS.USER_ID), - channel_id=TEST_DEFAULTS.CHANNEL_ID, - # get_conversation_reference=mocker.Mock(return_value=conv_ref), - relates_to=mocker.MagicMock(ConversationReference), - value=value, - text=text - ) - - @pytest.fixture(params=TEST_DEFAULTS.ALL()) - def sample_flow_state(self, request): - return request.param.model_copy() - - @pytest.fixture(params=TEST_DEFAULTS.FAILED()) - def sample_failed_flow_state(self, request): - return request.param.model_copy() - - @pytest.fixture(params=TEST_DEFAULTS.INACTIVE()) - def sample_inactive_flow_state(self, request): - return request.param.model_copy() - - @pytest.fixture(params=TEST_DEFAULTS.ACTIVE()) - def sample_active_flow_state(self, request): - return request.param.model_copy() - - @pytest.fixture - def flow(self, sample_flow_state, user_token_client): - return AuthFlow(sample_flow_state, user_token_client) @pytest.fixture def auth_handlers(self): @@ -157,7 +116,7 @@ def connection_manager(self): def auth(self, connection_manager, storage, auth_handlers): return Authorization(connection_manager, storage, auth_handlers) -class TestAuthorizationUtils: +class TestAuthorizationUtils(TestUtils): def create_user_token_client(self, mocker, get_token_return=None): @@ -184,8 +143,6 @@ def storage(self): def baseline_storage(self): return StorageBaseline(STORAGE_INIT_DATA()) - def mock_user_token_provider - def patch_flow(self, mocker, flow_response=None, token=None,): mocker.patch.object(AuthFlow, "get_user_token", return_value=TokenResponse(token=token)) mocker.patch.object(AuthFlow, "sign_out") @@ -227,7 +184,7 @@ def test_init_configuration_variants(self,storage, connection_manager, auth_hand @pytest.mark.asyncio @pytest.mark.parametrize("auth_handler_id, channel_id, user_id", [ - ["", "webchat", "Alice"] + ["", "webchat", "Alice"], ["handler", "teams", "Bob"] ]) async def test_open_flow_value_error( @@ -251,6 +208,7 @@ async def test_open_flow_value_error( ]) async def test_open_flow_readonly( self, + mocker, storage, connection_client, auth_handlers, @@ -261,7 +219,7 @@ async def test_open_flow_readonly( # setup context = self.create_context(mocker, channel_id, user_id) auth = Authorization(storage, connection_client, auth_handlers) - flow_storage_client = FlowStorageClient(context, storage) + flow_storage_client = FlowStorageClient(channel_id, user_id, storage) # test async with auth.open_flow(context, auth_handler_id) as flow: @@ -282,7 +240,7 @@ async def test_open_flow_not_in_storage( # setup context = self.create_context(mocker, "__channel_id", "__user_id") auth = Authorization(storage, connection_manager, auth_handlers) - flow_storage_client = FlowStorageClient(context, storage) + flow_storage_client = FlowStorageClient("__channel_id", "__user_id", storage) # test async with auth.open_flow(context, "__auth_handler_id") as flow: @@ -322,7 +280,7 @@ async def test_open_flow_success_modified_complete_flow( context.activity.text = "123456" auth = Authorization(storage, connection_client, auth_handlers) - flow_storage_client = FlowStorageClient(context, storage) + flow_storage_client = FlowStorageClient(channel_id, user_id, storage) # test async with auth.open_flow(context, auth_handler_id) as flow: @@ -360,7 +318,7 @@ async def test_open_flow_success_modified_failure( context = self.create_context(mocker, channel_id, user_id) auth = Authorization(storage, connection_client, auth_handlers) - flow_storage_client = FlowStorageClient(context, storage) + flow_storage_client = FlowStorageClient(channel_id, user_id, storage) # test async with auth.open_flow(context, auth_handler_id) as flow: @@ -401,7 +359,7 @@ async def test_open_flow_success_modified_signout( context = self.create_context(mocker, channel_id, user_id) auth = Authorization(storage, connection_client, auth_handlers) - flow_storage_client = FlowStorageClient(context, storage) + flow_storage_client = FlowStorageClient(channel_id, user_id, storage) # test async with auth.open_flow(context, auth_handler_id) as flow: @@ -416,7 +374,7 @@ async def test_open_flow_success_modified_signout( actual_flow_state = await flow_storage_client.read(auth_handler_id) expected_flow_state.expires_at = actual_flow_state.expires_at # we won't check this for now - assert flow_response.flow_error_tag == FlowErrorTag.MAGIC_FORMAT + assert flow_response.flow_error_tag == FlowErrorTag.NONE assert res_flow_state == expected_flow_state assert actual_flow_state == expected_flow_state @@ -475,13 +433,6 @@ async def test_exchange_token_no_token( assert res == TokenResponse() @pytest.mark.asyncio - @pytest.mark.parametrize( - "token", - [ - "token", - "" - ] # robrandao: TODOTODO - ) async def test_exchange_token_not_exchangeable( self, mock_auth_flow_class, @@ -490,41 +441,28 @@ async def test_exchange_token_not_exchangeable( auth, token ): + token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") mock_auth_flow_class.get_user_token = mocker.AsyncMock( - return_value=TokenResponse(token=token) + return_value=TokenResponse(connection_name="github", token=token) ) res = await auth.exchange_token(turn_context, ["scope"], "github") assert res == TokenResponse() - - @pytest.fixture - def valid_token_response(self): - return TokenResponse( - connection_name="connection", - token="token" - ) - - @pytest.fixture - def invalid_exchange_token(self): - token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") - return TokenResponse( - connection_name="connection" - token=token - ) @pytest.mark.asyncio - async def test_exchange_token( + async def test_exchange_token_valid_exchangeable( self, - mock_user_token_client_class, + mock_auth_flow_class, + turn_context, + mocker, + auth, + token ): - - mocker.patch.object("OAuthFlow", - get_user_token=mocker.AsyncMock(return_value=TokenResponse( - access_token="access_token", - refresh_token="refresh_token", - expires_in=3600 - )) + token = jwt.encode({"aud": "valid://botframework.test.api"}, "") + mock_auth_flow_class.get_user_token = mocker.AsyncMock( + return_value=TokenResponse(connection_name="github", token=token) ) - mock_user_token_client_class + res = await auth.exchange_token(turn_context, ["scope"], "github") + assert res == TokenResponse(scopes=["scope"], token=token, connection_name="github") @pytest.mark.asyncio @pytest.mark.parametrize( @@ -570,34 +508,36 @@ def test_resolve_handler_first(self, auth, auth_handlers_list): @pytest.mark.asyncio async def test_sign_out_individual( self, + mocker, mock_user_token_client_class, - mock_flow_class, - turn_context, + mock_auth_flow_class, storage, baseline_storage, connection_manager, auth_handlers ): # setup - storage_client = FlowStorageClient(turn_context, storage) + storage_client = FlowStorageClient("teams", "Alice", storage) + context = self.create_context(mocker, "teams", "Alice") auth = Authorization(storage, connection_manager, auth_handlers) - await auth.sign_out("handler") + await auth.sign_out(context, "graph") - await baseline_storage.delete([storage_client.key("handler")]) + await baseline_storage.delete([storage_client.key("graph")]) # verify storage assert await baseline_storage.equals(storage) # verify flow - mock_flow_class.sign_out.assert_called_once_with("handler") + mock_auth_flow_class.sign_out.assert_called_once_with("graph") mock_user_token_client_class.user_token.sign_out.assert_called_once() @pytest.mark.asyncio async def test_sign_out_all( self, + mocker, mock_user_token_client_class, - mock_flow_class, + mock_auth_flow_class, turn_context, storage, baseline_storage, @@ -605,10 +545,11 @@ async def test_sign_out_all( auth_handlers ): # setup - storage_client = FlowStorageClient(turn_context, storage) + storage_client = FlowStorageClient("webchat", "Alice", storage) auth = Authorization(storage, connection_manager, auth_handlers) - await auth.sign_out("handler") + context = self.create_context(mocker, "webchat", "Alice") + await auth.sign_out(context) await baseline_storage.delete([storage_client.key("handler"), storage_client.key("connection")]) @@ -616,8 +557,8 @@ async def test_sign_out_all( assert await baseline_storage.equals(storage) # verify flow - mock_flow_class.sign_out.assert_called_once_with("handler") - mock_flow_class.sign_out.assert_called_once_with("connection") + mock_auth_flow_class.sign_out.assert_called_once_with("handler") + mock_auth_flow_class.sign_out.assert_called_once_with("connection") # robrandao: TODO -> handlers \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py b/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py new file mode 100644 index 00000000..57a65258 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py @@ -0,0 +1,150 @@ +import pytest +from unittest.mock import sentinel + +from microsoft.agents.hosting.core.storage import MemoryStorage +from microsoft.agents.hosting.core.storage.storage_test_utils import MockStoreItem +from microsoft.agents.hosting.core.app.oauth import ( + FlowState, + FlowStorageClient, +) + +class TestFlowStorageClient: + + @pytest.fixture + def channel_id(self): + return "__channel_id" + + @pytest.fixture + def user_id(self): + return "__user_id" + + @pytest.fixture + def storage(self): + return MemoryStorage() + + @pytest.fixture + def client(self, channel_id, user_id, storage): + return FlowStorageClient(channel_id, user_id, storage) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "channel_id, from_property_id", + [ + ("channel_id", "from_property_id"), + ("teams_id", "Bob"), + ("channel", "Alice"), + ], + ) + async def test_init_base_key(self, mocker, channel_id, user_id): + client = FlowStorageClient(channel_id, user_id, mocker.Mock()) + assert client.base_key == f"auth/{channel_id}/{user_id}/" + + @pytest.mark.asyncio + async def test_init_fails_without_user_id(self, channel_id, storage): + with pytest.raises(ValueError): + FlowStorageClient(channel_id, "", storage) + + @pytest.mark.asyncio + async def test_init_fails_without_channel_id(self, user_id, storage): + with pytest.raises(ValueError): + FlowStorageClient("", user_id, storage) + + @pytest.mark.parametrize( + "auth_handler_id, expected", + [ + ("handler", "auth/__channel_id/__user_id/handler"), + ("auth_handler", "auth/__channel_id/__user_id/auth_handler"), + ] + ) + def test_key(self, client, auth_handler_id, expected): + assert client.key(auth_handler_id) == expected + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "auth_handler_id", ["handler", "auth_handler"] + ) + async def test_read(self, mocker, user_id, channel_id, auth_handler_id): + storage = mocker.AsyncMock() + key = f"auth/{channel_id}/{user_id}/{auth_handler_id}" + storage.read.return_value = {key: FlowState()} + client = FlowStorageClient(channel_id, user_id, storage) + res = await client.read(auth_handler_id) + assert res is storage.read.return_value[key] + storage.read.assert_called_once_with([client.key(auth_handler_id)], FlowState) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "auth_handler_id", ["handler", "auth_handler"] + ) + async def test_write(self, mocker, channel_id, user_id, auth_handler_id): + storage = mocker.AsyncMock() + storage.write.return_value = None + client = FlowStorageClient(channel_id, user_id, storage) + flow_state = mocker.Mock(spec=FlowState) + flow_state.flow_id = auth_handler_id + await client.write(flow_state) + storage.write.assert_called_once_with({ client.key(auth_handler_id): flow_state }) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "auth_handler_id", ["handler", "auth_handler"] + ) + async def test_delete(self, mocker, channel_id, user_id, auth_handler_id): + storage = mocker.AsyncMock() + storage.delete.return_value = None + client = FlowStorageClient(channel_id, user_id, storage) + await client.delete(auth_handler_id) + storage.delete.assert_called_once_with([client.key(auth_handler_id)]) + + @pytest.mark.asyncio + async def test_integration_with_memory_storage(self, channel_id, user_id): + + flow_state_alpha = FlowState(flow_id="handler", flow_started=True) + flow_state_beta = FlowState(flow_id="auth_handler", flow_started=True, user_token="token") + + storage = MemoryStorage({ + "some_data": MockStoreItem({"value": "test"}), + f"auth/{channel_id}/{user_id}/handler": flow_state_alpha, + f"auth/{channel_id}/{user_id}/auth_handler": flow_state_beta, + }) + baseline = MemoryStorage({ + "some_data": MockStoreItem({"value": "test"}), + f"auth/{channel_id}/{user_id}/handler": flow_state_alpha, + "fauth/{channel_id}/{user_id}/auth_handler": flow_state_beta, + }) + + # helpers + async def read_check(*args, **kwargs): + res_storage = await storage.read(*args, **kwargs) + res_baseline = await baseline.read(*args, **kwargs) + assert res_storage == res_baseline + + async def write_both(*args, **kwargs): + await storage.write(*args, **kwargs) + await baseline.write(*args, **kwargs) + + async def delete_both(*args, **kwargs): + await storage.delete(*args, **kwargs) + await baseline.delete(*args, **kwargs) + + client = FlowStorageClient(channel_id, user_id, storage) + + new_flow_state_alpha = FlowState(flow_id="handler") + flow_state_chi = FlowState(flow_id="chi") + + await client.write(new_flow_state_alpha) + await client.write(flow_state_chi) + await baseline.write({f"auth/{channel_id}/{user_id}/handler": new_flow_state_alpha.model_copy()}) + await baseline.write({f"auth/{channel_id}/{user_id}/chi": flow_state_chi.model_copy()}) + + await write_both({f"auth/{channel_id}/{user_id}/handler": new_flow_state_alpha.model_copy()}) + await write_both({f"auth/{channel_id}/{user_id}/auth_handler": flow_state_beta.model_copy()}) + await write_both({"other_data": MockStoreItem({"value": "more"})}) + + await delete_both(["some_data"]) + + await read_check([f"auth/{channel_id}/{user_id}/handler"], target_cls=FlowState) + await read_check([f"auth/{channel_id}/{user_id}/auth_handler"], target_cls=FlowState) + await read_check([f"auth/{channel_id}/{user_id}/chi"], target_cls=FlowState) + await read_check(["other_data"], target_cls=MockStoreItem) + await read_check(["some_data"], target_cls=MockStoreItem) From d4e7f72829c858fe435c9e38ad43f0203ed7d8ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Thu, 21 Aug 2025 09:03:02 -0700 Subject: [PATCH 14/32] Renaming, restructing, and revising imports --- .../microsoft/agents/hosting/core/__init__.py | 23 +- .../agents/hosting/core/app/__init__.py | 10 +- .../hosting/core/app/agent_application.py | 5 +- .../agents/hosting/core/app/auth/__init__.py | 13 + .../core/app/{oauth => auth}/auth_handler.py | 3 + .../core/app/{oauth => auth}/authorization.py | 33 +- .../tests/__authorization_test.py | 0 .../app/{oauth => auth}/tests/conftest.py | 0 .../agents/hosting/core/app/oauth/__init__.py | 25 - .../agents/hosting/core/app/oauth/conftest.py | 3 - .../agents/hosting/core/app/oauth/utils.py | 8 - .../agents/hosting/core/oauth/__init__.py | 17 + .../oauth/models.py => oauth/flow_state.py} | 31 +- .../{app => }/oauth/flow_storage_client.py | 20 +- .../auth_flow.py => oauth/oauth_flow.py} | 28 +- .../agents/hosting/core/oauth_flow.py | 442 ------------------ .../tests/test_auth_flow.py | 109 ++--- .../tests/test_authorization.py | 23 +- .../{models_test.py => test_flow_state.py} | 6 +- .../tests/test_flow_storage_client.py | 6 +- .../{oauth_test_env.py => testing_oauth.py} | 117 ++--- 21 files changed, 237 insertions(+), 685 deletions(-) create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/__init__.py rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/{oauth => auth}/auth_handler.py (93%) rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/{oauth => auth}/authorization.py (94%) rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/{oauth => auth}/tests/__authorization_test.py (100%) rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/{oauth => auth}/tests/conftest.py (100%) delete mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py delete mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/conftest.py delete mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/utils.py create mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/{app/oauth/models.py => oauth/flow_state.py} (71%) rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/{app => }/oauth/flow_storage_client.py (80%) rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/{app/oauth/auth_flow.py => oauth/oauth_flow.py} (93%) delete mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth_flow.py rename libraries/microsoft-agents-hosting-core/tests/{models_test.py => test_flow_state.py} (91%) rename libraries/microsoft-agents-hosting-core/tests/tools/{oauth_test_env.py => testing_oauth.py} (56%) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py index 5d95f4e6..59ef6e1c 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py @@ -1,6 +1,5 @@ from .activity_handler import ActivityHandler from .agent import Agent -from .oauth_flow import OAuthFlow from .card_factory import CardFactory from .channel_adapter import ChannelAdapter from .channel_api_handler_protocol import ChannelApiHandlerProtocol @@ -20,12 +19,11 @@ from .app.route import Route, RouteHandler from .app.typing_indicator import TypingIndicator -# OAuth -from .app.oauth.authorization import ( +# App Auth +from .app.auth import ( Authorization, AuthorizationHandlers, AuthHandler, - SignInState, ) # App State @@ -44,6 +42,16 @@ from .authorization.jwt_token_validator import JwtTokenValidator from .authorization.auth_types import AuthTypes +#OAuth +from .oauth import ( + FlowState, + FlowStateTag, + FlowErrorTag, + FlowResponse, + FlowStorageClient, + OAuthFlow +) + # Client API from .client.agent_conversation_reference import AgentConversationReference from .client.channel_factory_protocol import ChannelFactoryProtocol @@ -88,7 +96,6 @@ __all__ = [ "ActivityHandler", "Agent", - "OAuthFlow", "CardFactory", "ChannelAdapter", "ChannelApiHandlerProtocol", @@ -155,4 +162,10 @@ "StoreItem", "Storage", "MemoryStorage", + "FlowState", + "FlowStateTag", + "FlowErrorTag", + "FlowResponse", + "FlowStorageClient", + "OAuthFlow" ] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/__init__.py index 124de7c5..67d1df98 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/__init__.py @@ -13,12 +13,11 @@ from .route import Route, RouteHandler from .typing_indicator import TypingIndicator -# OAuth -from .oauth.authorization import ( +# Auth +from .auth import ( Authorization, - AuthorizationHandlers, AuthHandler, - SignInState, + AuthorizationHandlers, ) # App State @@ -49,7 +48,6 @@ "TurnState", "TempState", "Authorization", - "AuthorizationHandlers", "AuthHandler", - "SignInState", + "AuthorizationHandlers", ] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py index 6dcf6449..13d3be94 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py @@ -46,13 +46,12 @@ from .route import Route, RouteHandler from .state import TurnState from ..channel_service_adapter import ChannelServiceAdapter -from .oauth import ( - Authorization, +from ..oauth import ( FlowResponse, FlowState, FlowStateTag, - FlowErrorTag ) +from .auth import Authorization from .typing_indicator import TypingIndicator logger = logging.getLogger(__name__) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/__init__.py new file mode 100644 index 00000000..c964ae2f --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/__init__.py @@ -0,0 +1,13 @@ +from .authorization import ( + Authorization +) +from .auth_handler import ( + AuthHandler, + AuthorizationHandlers +) + +__all__ = [ + "Authorization", + "AuthHandler", + "AuthorizationHandlers", +] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/auth_handler.py similarity index 93% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/auth_handler.py index 203cb08c..8ad53bce 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/auth_handler.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + import logging from typing import Dict diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/authorization.py similarity index 94% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/authorization.py index b14a3e00..10f5b17d 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/authorization.py @@ -13,21 +13,20 @@ AccessTokenProviderBase, ) from microsoft.agents.hosting.core.storage import Storage -from microsoft.agents.activity import TokenResponse, Activity -from microsoft.agents.hosting.core.storage import StoreItem +from microsoft.agents.activity import TokenResponse from microsoft.agents.hosting.core.connector.client import UserTokenClient -from pydantic import BaseModel from ...turn_context import TurnContext -from ...app.state.turn_state import TurnState -# from ...oauth_flow import AuthFlow -from ...state.user_state import UserState +from ...oauth import ( + OAuthFlow, + FlowResponse, + FlowState, + FlowStateTag, + FlowStorageClient +) +from ..state.turn_state import TurnState from .auth_handler import AuthHandler -from .models import FlowResponse, FlowState, FlowStateTag, FlowErrorTag -from .flow_storage_client import FlowStorageClient -from .auth_flow import AuthFlow - logger = logging.getLogger(__name__) @@ -126,7 +125,7 @@ async def __load_flow( self, context: TurnContext, auth_handler_id: str = "" - ) -> tuple[AuthFlow, FlowStorageClient, FlowState]: + ) -> tuple[OAuthFlow, FlowStorageClient, FlowState]: """Loads the OAuth flow for a specific auth handler. Args: @@ -134,7 +133,7 @@ async def __load_flow( auth_handler_id: The ID of the auth handler to use. Returns: - The AuthFlow returned corresponds to the flow associated with the + The OAuthFlow returned corresponds to the flow associated with the chosen handler, and the channel and user info found in the context. The FlowStorageClient corresponds to the channel and user info. @@ -161,20 +160,20 @@ async def __load_flow( abs_oauth_connection_name=auth_handler.abs_oauth_connection_name ) - flow = AuthFlow(flow_state, user_token_client) + flow = OAuthFlow(flow_state, user_token_client) return flow, flow_storage_client, flow_state @asynccontextmanager - async def open_flow(self, context: TurnContext, auth_handler_id: str = "") -> AsyncIterator[AuthFlow]: - """Loads an Auth flow and saves changes the changes to storage if any are made. + async def open_flow(self, context: TurnContext, auth_handler_id: str = "") -> AsyncIterator[OAuthFlow]: + """Loads an OAuth flow and saves changes the changes to storage if any are made. Args: context: The context object for the current turn. auth_handler_id: ID of the auth handler to use. Yields: - AuthFlow: - The AuthFlow instance loaded from storage or newly created + OAuthFlow: + The OAuthFlow instance loaded from storage or newly created if not yet present in storage. """ if not context: diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/__authorization_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/tests/__authorization_test.py similarity index 100% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/__authorization_test.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/tests/__authorization_test.py diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/conftest.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/tests/conftest.py similarity index 100% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/tests/conftest.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/tests/conftest.py diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py deleted file mode 100644 index 5817c861..00000000 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -from .authorization import ( - Authorization -) -from .auth_handler import ( - AuthHandler, - AuthorizationHandlers -) -from .models import ( - FlowState, - FlowStateTag, - FlowErrorTag, - FlowResponse, -) -from .flow_storage_client import FlowStorageClient - -__all__ = [ - "Authorization", - "AuthHandler", - "AuthorizationHandlers", - "FlowState", - "FlowStateTag", - "FlowErrorTag", - "FlowResponse", - "FlowStorageClient", -] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/conftest.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/conftest.py deleted file mode 100644 index aa08436e..00000000 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/conftest.py +++ /dev/null @@ -1,3 +0,0 @@ -def turn_context(): - - context = TurnContext() \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/utils.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/utils.py deleted file mode 100644 index ad4af7ba..00000000 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/utils.py +++ /dev/null @@ -1,8 +0,0 @@ -def raise_if_empty_or_None(func_name, err=ValueError, **kwargs): - s = "" - for key, value in kwargs.items(): - if not value: - s += f"\tArgument '{key}' is required and cannot be None or empty.\n" - if s: - header = f"{func_name}: called with empty arguments:" - raise err(header + "\n" + s) \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py new file mode 100644 index 00000000..e2db50a2 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py @@ -0,0 +1,17 @@ +from .flow_state import ( + FlowState, + FlowStateTag, + FlowErrorTag, + FlowResponse +) +from .flow_storage_client import FlowStorageClient +from .oauth_flow import OAuthFlow + +__all__ = [ + "FlowState", + "FlowStateTag", + "FlowErrorTag", + "FlowResponse", + "FlowStorageClient", + "OAuthFlow" +] \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py similarity index 71% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py index 08c36b29..de3a3c6e 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/models.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py @@ -1,14 +1,24 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + from datetime import datetime from enum import Enum from typing import Optional from pydantic import BaseModel -from pydantic.types import PositiveInt -from microsoft.agents.activity import Activity, SignInResource, TokenResponse -from microsoft.agents.hosting.core.storage import StoreItem +from microsoft.agents.activity import Activity + +from ..storage import StoreItem class FlowStateTag(Enum): + """Represents the top-level state of an OAuthFlow + + For instance, a flow can arrive at an error, but its + broader state may still be CONTINUE if the flow can + still progress + """ + BEGIN = "begin" CONTINUE = "continue" NOT_STARTED = "not_started" @@ -16,12 +26,14 @@ class FlowStateTag(Enum): COMPLETE = "complete" class FlowErrorTag(Enum): + """Represents the various error states that can occur during an OAuthFlow""" NONE = "none" MAGIC_FORMAT = "magic_format" MAGIC_CODE_INCORRECT = "magic_code_incorrect" OTHER = "OTHER" class FlowState(BaseModel, StoreItem): + """Represents the state of an OAuthFlow""" flow_id: str = "" # robrandao: TODO user_token: str = "" @@ -34,10 +46,6 @@ class FlowState(BaseModel, StoreItem): attempts_remaining: int = 0 tag: FlowStateTag = FlowStateTag.NOT_STARTED - def refresh(self) -> None: - if self.is_expired() or self.reached_max_attempts(): - self.tag = FlowStateTag.FAILURE - def store_item_to_json(self) -> dict: return self.model_dump() @@ -52,11 +60,4 @@ def reached_max_attempts(self) -> bool: return self.attempts_remaining <= 0 def is_active(self) -> bool: - return not self.is_expired() and not self.reached_max_attempts() and self.tag in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE] - -class FlowResponse(BaseModel): - - flow_state: FlowState = FlowState() - flow_error_tag: FlowErrorTag = FlowErrorTag.NONE - token_response: Optional[TokenResponse] = None - sign_in_resource: Optional[SignInResource] = None + return not self.is_expired() and not self.reached_max_attempts() and self.tag in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE] \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py similarity index 80% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py index 52e49156..4a385cf0 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py @@ -1,22 +1,16 @@ -from typing import Optional +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. -from ... import TurnContext -from ...storage import Storage +from typing import Optional -from .models import FlowState +from ..storage import Storage +from .flow_state import FlowState -# robrandao: TODO -> context.activity.from_property +# this could be generalized, if needed class FlowStorageClient: - """ - Wrapper around storage that manages sign-in state specific to each user and channel. + """Wrapper around Storage that manages sign-in state specific to each user and channel. Uses the activity's channel_id and from.id to create a key prefix for storage operations. - - Contract with other classes (usage of other classes is enforced in unit tests): - TurnContext.activity.channel_id - TurnContext.activity.from_property.id - - Storage: read(), write(), delete() """ def __init__( diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py similarity index 93% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py index ba16fbc0..de167c9a 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py @@ -5,26 +5,31 @@ import logging -from enum import Enum +from pydantic import BaseModel from datetime import datetime -from typing import Dict, Optional +from typing import Optional -from microsoft.agents.hosting.core.connector.client import UserTokenClient from microsoft.agents.activity import ( Activity, ActivityTypes, TokenExchangeState, TokenResponse, + SignInResource ) -from microsoft.agents.hosting.core.storage import StoreItem, Storage -from pydantic import BaseModel, PositiveInt -from .models import FlowResponse, FlowState, FlowStateTag, FlowErrorTag -from .utils import raise_if_empty_or_None +from ..connector.client import UserTokenClient +from .flow_state import FlowState, FlowStateTag, FlowErrorTag logger = logging.getLogger(__name__) -class AuthFlow: +class FlowResponse(BaseModel): + """Represents the response for a flow operation.""" + flow_state: FlowState = FlowState() + flow_error_tag: FlowErrorTag = FlowErrorTag.NONE + token_response: Optional[TokenResponse] = None + sign_in_resource: Optional[SignInResource] = None + +class OAuthFlow: """ Manages the OAuth flow. @@ -54,11 +59,8 @@ def __init__( max_attempts: The maximum number of attempts for the flow set when starting a flow (default: 3). """ - raise_if_empty_or_None( - self.__init__.__name__, - flow_state=flow_state, - user_token_client=user_token_client - ) + if not self.flow_state or not user_token_client: + raise ValueError("OAuthFlow.__init__(): flow_state and user_token_client are required") if (not flow_state.abs_oauth_connection_name or not flow_state.ms_app_id or diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth_flow.py deleted file mode 100644 index 29c1b247..00000000 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth_flow.py +++ /dev/null @@ -1,442 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -from __future__ import annotations - -import logging - -from enum import Enum -from datetime import datetime -from typing import Dict, Optional - -from microsoft.agents.hosting.core.connector.client import UserTokenClient -from microsoft.agents.activity import ( - ActionTypes, - ActivityTypes, - CardAction, - Attachment, - OAuthCard, - TokenExchangeState, - TokenResponse, - Activity, -) -from microsoft.agents.activity import ( - TurnContextProtocol as TurnContext, -) -from microsoft.agents.hosting.core.storage import StoreItem, Storage -from pydantic import BaseModel, PositiveInt - -from .message_factory import MessageFactory -from .card_factory import CardFactory - -logger = logging.getLogger(__name__) - -# class FlowStatus(Enum): -# IN_ACTIVE = "not_started" -# IN_PROGRESS = "in_progress" -# COMPLETED = "completed" -# ERROR = "error" -# EXPIRED = "expired" - - -class FlowState(StoreItem, BaseModel): - flow_started: bool = False - # flow_status: FLOW_STATUS - user_token: str = "" - flow_expires: float = 0 - abs_oauth_connection_name: Optional[str] = None - continuation_activity: Optional[Activity] = None - attempts_remaining: PositiveInt = 1 - - def store_item_to_json(self) -> dict: - return self.model_dump() - - @staticmethod - def from_json_to_store_item(json_data: dict) -> "StoreItem": - return FlowState.model_validate(json_data) - - -class OAuthFlow: - """ - Manages the OAuth flow. - """ - - def __init__( - self, - storage: Storage, - abs_oauth_connection_name: str, - user_token_client: Optional[UserTokenClient] = None, - messages_configuration: dict[str, str] = None, - **kwargs, - ): - """ - Creates a new instance of OAuthFlow. - - Args: - user_state: The user state. - abs_oauth_connection_name: The OAuth connection name. - user_token_client: Optional user token client. - messages_configuration: Optional messages configuration for backward compatibility. - - Kwargs: - flow_total_tries: The total number of auth attempts made by the user during a single flow - """ - if not abs_oauth_connection_name: - raise ValueError( - "OAuthFlow.__init__: connectionName expected but not found" - ) - - # Handle backward compatibility with messages_configuration - self.messages_configuration = messages_configuration or {} - - # Initialize properties - self.abs_oauth_connection_name = abs_oauth_connection_name - self.user_token_client = user_token_client - self.token_exchange_id: Optional[str] = None - - # Initialize state and flow state - self._storage = storage - self.flow_state = None - - self.initial_attempts_remaining = kwargs.get("initial_attempts_remaining", 3) - - async def get_user_token(self, context: TurnContext) -> TokenResponse: - """ - Retrieves the user token from the user token service. - - Args: - context: The turn context containing the activity information. - - Returns: - The user token response. - - Raises: - ValueError: If the channelId or from properties are not set in the activity. - """ - await self._initialize_token_client(context) - - if not context.activity.from_property: - raise ValueError("User ID is not set in the activity.") - - if not context.activity.channel_id: - raise ValueError("Channel ID is not set in the activity.") - - return await self.user_token_client.user_token.get_token( - user_id=context.activity.from_property.id, - connection_name=self.abs_oauth_connection_name, - channel_id=context.activity.channel_id, - ) - - async def reset_to_initial_flow_state(self, context: TurnContext) -> None: - self.flow_state.flow_started = True - self.flow_state.flow_expires = datetime.now().timestamp() + 30000 - self.flow_state.abs_oauth_connection_name = self.abs_oauth_connection_name - self.flow_state.attempts_remaining = self.initial_attempts_remaining - await self._save_flow_state(context) - - async def reset_to_finished_flow_state(self, context: TurnContext) -> None: - self.flow_state.flow_started = False - self.flow_state.flow_expires = 0 - self.flow_state.attempts_remaining = 0 - self.flow_state.abs_oauth_connection_name = self.abs_oauth_connection_name - await self._save_flow_state(context) - - async def begin_flow(self, context: TurnContext) -> TokenResponse: - """ - Begins the OAuth flow. - - Args: - context: The turn context. - - Returns: - A TokenResponse object. - """ - self.flow_state = FlowState() - - if not self.abs_oauth_connection_name: - raise ValueError("connectionName is not set") - - await self._initialize_token_client(context) - - activity = context.activity - - # Try to get existing token first - user_token = await self.user_token_client.user_token.get_token( - user_id=activity.from_property.id, - connection_name=self.abs_oauth_connection_name, - channel_id=activity.channel_id, - ) - - if user_token and user_token.token: - # Already have token, return it - self.flow_state.flow_started = False - self.flow_state.flow_expires = 0 - self.flow_state.attempts_remaining = 0 - self.flow_state.abs_oauth_connection_name = self.abs_oauth_connection_name - await self._save_flow_state(context) - return user_token - - # No token, need to start sign-in flow - token_exchange_state = TokenExchangeState( - connection_name=self.abs_oauth_connection_name, - conversation=activity.get_conversation_reference(), - relates_to=activity.relates_to, - ms_app_id=context.turn_state.get(context.adapter.AGENT_IDENTITY_KEY).claims[ - "aud" - ], - ) - - signing_resource = ( - await self.user_token_client.agent_sign_in.get_sign_in_resource( - state=token_exchange_state.get_encoded_state(), - ) - ) - - # Create the OAuth card - o_card: Attachment = CardFactory.oauth_card( - OAuthCard( - text=self.messages_configuration.get("card_title", "Sign in"), - connection_name=self.abs_oauth_connection_name, - buttons=[ - CardAction( - title=self.messages_configuration.get("button_text", "Sign in"), - type=ActionTypes.signin, - value=signing_resource.sign_in_link, - channel_data=None, - ) - ], - token_exchange_resource=signing_resource.token_exchange_resource, - token_post_resource=signing_resource.token_post_resource, - ) - ) - - # Send the card to the user - await context.send_activity(MessageFactory.attachment(o_card)) - - # Update flow state - await self.reset_to_initial_flow_state(context) - - # Return in-progress response - return TokenResponse() - - async def continue_flow(self, context: TurnContext) -> TokenResponse: - """ - Continues the OAuth flow. - - Args: - context: The turn context. - - Returns: - A TokenResponse object. - """ - await self._initialize_token_client(context) - - if self.flow_state and ( - ( - self.flow_state.flow_expires != 0 - and datetime.now().timestamp() > self.flow_state.flow_expires - ) - or (self.flow_state.attempts_remaining <= 0) - ): - # self.flow_state = False - await context.send_activity( - MessageFactory.text( - self.messages_configuration.get( - "session_expired_messages", - "Sign-in session expired. Please try again.", - ) - ) - ) - return TokenResponse() - - cont_flow_activity = context.activity - - # Handle message type activities (typically when the user enters a code) - if cont_flow_activity.type == ActivityTypes.message: - self.flow_state.attempts_remaining -= 1 - logger.info( - f"Attempts remaining in this flow: {self.flow_state.attempts_remaining}" - ) - - magic_code = cont_flow_activity.text - - # Validate magic code format (6 digits) - if magic_code and magic_code.isdigit() and len(magic_code) == 6: - result = await self.user_token_client.user_token.get_token( - user_id=cont_flow_activity.from_property.id, - connection_name=self.abs_oauth_connection_name, - channel_id=cont_flow_activity.channel_id, - code=magic_code, - ) - - if result and result.token: - await self.reset_to_finished_flow_state(context) - return result - else: - await context.send_activity( - MessageFactory.text("Invalid code. Please try again.") - ) - self.flow_state.flow_started = True - self.flow_state.flow_expires = datetime.now().timestamp() + 30000 - await self._save_flow_state(context) - return TokenResponse() - else: - await context.send_activity( - MessageFactory.text( - "Invalid code format. Please enter a 6-digit code." - ) - ) - return TokenResponse() - - # Handle verify state invoke activity - if ( - cont_flow_activity.type == ActivityTypes.invoke - and cont_flow_activity.name == "signin/verifyState" - ): - self.flow_state.attempts_remaining -= 1 - logger.info( - f"Attempts remaining in this flow: {self.flow_state.attempts_remaining}" - ) - - token_verify_state = cont_flow_activity.value - magic_code = token_verify_state.get("state") - - result = await self.user_token_client.user_token.get_token( - user_id=cont_flow_activity.from_property.id, - connection_name=self.abs_oauth_connection_name, - channel_id=cont_flow_activity.channel_id, - code=magic_code, - ) - - if result and result.token: - self.flow_state.flow_started = False - self.flow_state.abs_oauth_connection_name = ( - self.abs_oauth_connection_name - ) - await self._save_flow_state(context) - return result - return TokenResponse() - - # Handle token exchange invoke activity - if ( - cont_flow_activity.type == ActivityTypes.invoke - and cont_flow_activity.name == "signin/tokenExchange" - ): - self.flow_state.attempts_remaining -= 1 - logger.info( - f"Attempts remaining in this flow: {self.flow_state.attempts_remaining}" - ) - - token_exchange_request = cont_flow_activity.value - - # Dedupe checks to prevent duplicate processing - token_exchange_id = token_exchange_request.get("id") - if self.token_exchange_id == token_exchange_id: - # Already processed this request - return TokenResponse() - - # Store this request ID - self.token_exchange_id = token_exchange_id - - # Exchange the token - user_token_resp = await self.user_token_client.user_token.exchange_token( - user_id=cont_flow_activity.from_property.id, - connection_name=self.abs_oauth_connection_name, - channel_id=cont_flow_activity.channel_id, - body=token_exchange_request, - ) - - if user_token_resp and user_token_resp.token: - self.flow_state.flow_started = False - await self._save_flow_state(context) - return user_token_resp - else: - self.flow_state.flow_started = True - return TokenResponse() - - return TokenResponse() - - async def sign_out(self, context: TurnContext) -> None: - """ - Signs the user out. - - Args: - context: The turn context. - """ - await self._initialize_token_client(context) - - await self.user_token_client.user_token.sign_out( - user_id=context.activity.from_property.id, - connection_name=self.abs_oauth_connection_name, - channel_id=context.activity.channel_id, - ) - - if self.flow_state: - self.flow_state.flow_expires = 0 - self.flow_state.attempts_remaining = 0 - await self._save_flow_state(context) - - async def _get_flow_state(self, context: TurnContext) -> FlowState: - """ - Gets the user state. - - Args: - context: The turn context. - - Returns: - The user state. - """ - storage_key = self._get_storage_key(context) - - storage_result: Dict[str, FlowState] | None = await self._storage.read( - [storage_key], target_cls=FlowState - ) - if not storage_result or storage_key not in storage_result: - return FlowState() - return storage_result[storage_key] - - async def _save_flow_state(self, context: TurnContext) -> None: - """ - Saves the flow state to the user state. - Args: - context: The turn context. - """ - await self._storage.write({self._get_storage_key(context): self.flow_state}) - - async def _initialize_token_client(self, context: TurnContext) -> None: - """ - Initializes the user token client if not already set. - - Args: - context: The turn context. - """ - - # TODO: Change this to caching when the story is implemented, for now we're getting it from TurnContext (new with every request) - self.user_token_client = context.turn_state.get( - context.adapter.USER_TOKEN_CLIENT_KEY - ) - - def _get_storage_key(self, context: TurnContext) -> str: - """ - Gets the storage key for the flow state. - - Args: - context: The turn context. - - Returns: - The storage key. - """ - channel_id = context.activity.channel_id - if not channel_id: - raise ValueError("Channel ID is not set in the activity.") - user_id = ( - context.activity.from_property.id - if context.activity.from_property - else None - ) - if not user_id: - raise ValueError("User ID is not set in the activity.") - - return ( - f"oauth/{self.abs_oauth_connection_name}/{channel_id}/{user_id}/flowState" - ) diff --git a/libraries/microsoft-agents-hosting-core/tests/test_auth_flow.py b/libraries/microsoft-agents-hosting-core/tests/test_auth_flow.py index 77b4693f..ba46be2c 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_auth_flow.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_auth_flow.py @@ -1,8 +1,4 @@ -from datetime import datetime -from typing import Callable - import pytest -from pydantic import BaseModel from microsoft.agents.activity import ( Activity, @@ -12,22 +8,19 @@ TokenExchangeState, ConversationReference, ChannelAccount, - ConversationAccount ) -from microsoft.agents.hosting.core.app.oauth.auth_flow import AuthFlow - -from microsoft.agents.hosting.core.app.oauth.models import ( +from microsoft.agents.hosting.core.oauth import ( + OAuthFlow, FlowErrorTag, - FlowState, - FlowStateTag, + FlowStateTag ) from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase from microsoft.agents.hosting.core.connector.user_token_client_base import UserTokenClientBase -from .tools.oauth_test_utils import TEST_DEFAULTS - +# test constants +from .tools.testing_oauth import * -class TestAuthFlowUtils: +class TestOAuthFlowUtils: def create_user_token_client(self, mocker, get_token_return=None): @@ -45,7 +38,7 @@ def create_user_token_client(self, mocker, get_token_return=None): @pytest.fixture def user_token_client(self, mocker): - return self.create_user_token_client(mocker, get_token_return=TEST_DEFAULTS.RES_TOKEN) + return self.create_user_token_client(mocker, get_token_return=RES_TOKEN) def create_activity(self, mocker, activity_type=ActivityTypes.message, name="a", value=None, text="a"): # def conv_ref(): @@ -56,40 +49,40 @@ def create_activity(self, mocker, activity_type=ActivityTypes.message, name="a", return Activity( type=activity_type, name=name, - from_property=ChannelAccount(id=TEST_DEFAULTS.USER_ID), - channel_id=TEST_DEFAULTS.CHANNEL_ID, + from_property=ChannelAccount(id=USER_ID), + channel_id=CHANNEL_ID, # get_conversation_reference=mocker.Mock(return_value=conv_ref), relates_to=mocker.MagicMock(ConversationReference), value=value, text=text ) - @pytest.fixture(params=TEST_DEFAULTS.ALL()) + @pytest.fixture(params=FLOW_STATES.ALL()) def sample_flow_state(self, request): return request.param.model_copy() - @pytest.fixture(params=TEST_DEFAULTS.FAILED()) + @pytest.fixture(params=FLOW_STATES.FAILED()) def sample_failed_flow_state(self, request): return request.param.model_copy() - @pytest.fixture(params=TEST_DEFAULTS.INACTIVE()) + @pytest.fixture(params=FLOW_STATES.INACTIVE()) def sample_inactive_flow_state(self, request): return request.param.model_copy() - @pytest.fixture(params=TEST_DEFAULTS.ACTIVE()) + @pytest.fixture(params=FLOW_STATES.ACTIVE()) def sample_active_flow_state(self, request): return request.param.model_copy() @pytest.fixture def flow(self, sample_flow_state, user_token_client): - return AuthFlow(sample_flow_state, user_token_client) + return OAuthFlow(sample_flow_state, user_token_client) -class TestAuthFlow(TestAuthFlowUtils): +class TestOAuthFlow(TestOAuthFlowUtils): def test_init_no_user_token_client(self, sample_flow_state): with pytest.raises(ValueError): - AuthFlow(sample_flow_state, None) + OAuthFlow(sample_flow_state, None) @pytest.mark.parametrize("missing_value", [ "abs_oauth_connection_name", @@ -98,42 +91,42 @@ def test_init_no_user_token_client(self, sample_flow_state): "user_id" ]) def test_init_errors(self, missing_value, user_token_client): - flow_state = TEST_DEFAULTS.STARTED_FLOW.model_copy() + flow_state = FLOW_STATES.STARTED_FLOW.model_copy() flow_state.__setattr__(missing_value, None) with pytest.raises(ValueError): - AuthFlow(flow_state, user_token_client) + OAuthFlow(flow_state, user_token_client) flow_state.__setattr__(missing_value, "") with pytest.raises(ValueError): - AuthFlow(flow_state, user_token_client) + OAuthFlow(flow_state, user_token_client) def test_init_with_state(self, sample_flow_state, user_token_client): - flow = AuthFlow(sample_flow_state, user_token_client) + flow = OAuthFlow(sample_flow_state, user_token_client) assert flow.flow_state == sample_flow_state def test_flow_state_prop_copy(self, flow): flow_state = flow.flow_state flow_state.user_id = (flow_state.user_id + "_copy") - assert flow.flow_state.user_id == TEST_DEFAULTS.USER_ID - assert flow_state.user_id == f"{TEST_DEFAULTS.USER_ID}_copy" + assert flow.flow_state.user_id == USER_ID + assert flow_state.user_id == f"{USER_ID}_copy" @pytest.mark.asyncio async def test_get_user_token_success(self, sample_flow_state, user_token_client): # setup - flow = AuthFlow(sample_flow_state, user_token_client) + flow = OAuthFlow(sample_flow_state, user_token_client) expected_final_flow_state = sample_flow_state - expected_final_flow_state.user_token = TEST_DEFAULTS.RES_TOKEN + expected_final_flow_state.user_token = RES_TOKEN # test token_response = await flow.get_user_token() token = token_response.token # verify - assert token == TEST_DEFAULTS.RES_TOKEN + assert token == RES_TOKEN assert flow.flow_state == expected_final_flow_state user_token_client.user_token.get_token.assert_called_once_with( - user_id=TEST_DEFAULTS.USER_ID, - connection_name=TEST_DEFAULTS.ABS_OAUTH_CONNECTION_NAME, - channel_id=TEST_DEFAULTS.CHANNEL_ID, + user_id=USER_ID, + connection_name=ABS_OAUTH_CONNECTION_NAME, + channel_id=CHANNEL_ID, magic_code=None ) @@ -141,7 +134,7 @@ async def test_get_user_token_success(self, sample_flow_state, user_token_client async def test_get_user_token_failure(self, mocker, sample_flow_state): # setup user_token_client = self.create_user_token_client(mocker, get_token_return=None) - flow = AuthFlow(sample_flow_state, user_token_client) + flow = OAuthFlow(sample_flow_state, user_token_client) expected_final_flow_state = flow.flow_state # robrandao: TODO -> what happens if fails and has user_token? # test @@ -151,16 +144,16 @@ async def test_get_user_token_failure(self, mocker, sample_flow_state): assert token_response == TokenResponse() assert flow.flow_state == expected_final_flow_state user_token_client.user_token.get_token.assert_called_once_with( - user_id=TEST_DEFAULTS.USER_ID, - connection_name=TEST_DEFAULTS.ABS_OAUTH_CONNECTION_NAME, - channel_id=TEST_DEFAULTS.CHANNEL_ID, + user_id=USER_ID, + connection_name=ABS_OAUTH_CONNECTION_NAME, + channel_id=CHANNEL_ID, magic_code=None ) @pytest.mark.asyncio async def test_sign_out(self, sample_flow_state, user_token_client): # setup - flow = AuthFlow(sample_flow_state, user_token_client) + flow = OAuthFlow(sample_flow_state, user_token_client) expected_flow_state = sample_flow_state expected_flow_state.user_token = "" expected_flow_state.tag = FlowStateTag.NOT_STARTED @@ -170,19 +163,19 @@ async def test_sign_out(self, sample_flow_state, user_token_client): # verify user_token_client.user_token.sign_out.assert_called_once_with( - user_id=TEST_DEFAULTS.USER_ID, - connection_name=TEST_DEFAULTS.ABS_OAUTH_CONNECTION_NAME, - channel_id=TEST_DEFAULTS.CHANNEL_ID + user_id=USER_ID, + connection_name=ABS_OAUTH_CONNECTION_NAME, + channel_id=CHANNEL_ID ) assert flow.flow_state == expected_flow_state @pytest.mark.asyncio async def test_begin_flow_easy_case(self, mocker, sample_flow_state, user_token_client): # setup - flow = AuthFlow(sample_flow_state, user_token_client) + flow = OAuthFlow(sample_flow_state, user_token_client) activity = mocker.Mock(spec=Activity) expected_flow_state = sample_flow_state - expected_flow_state.user_token = TEST_DEFAULTS.RES_TOKEN + expected_flow_state.user_token = RES_TOKEN # test response = await flow.begin_flow(activity) @@ -195,11 +188,11 @@ async def test_begin_flow_easy_case(self, mocker, sample_flow_state, user_token_ assert response.flow_state == flow_state assert response.sign_in_resource is None # No sign-in resource in this case assert response.flow_error_tag == FlowErrorTag.NONE - assert response.token_response.token == TEST_DEFAULTS.RES_TOKEN + assert response.token_response.token == RES_TOKEN user_token_client.user_token.get_token.assert_called_once_with( - user_id=TEST_DEFAULTS.USER_ID, - connection_name=TEST_DEFAULTS.ABS_OAUTH_CONNECTION_NAME, - channel_id=TEST_DEFAULTS.CHANNEL_ID, + user_id=USER_ID, + connection_name=ABS_OAUTH_CONNECTION_NAME, + channel_id=CHANNEL_ID, # magic_code=None is an implementation detail, and ideally # shouldn't be part of the test magic_code=None @@ -223,7 +216,7 @@ async def test_begin_flow_long_case(self, mocker, sample_flow_state, user_token_ activity = self.create_activity(mocker) # setup - flow = AuthFlow(sample_flow_state, user_token_client) + flow = OAuthFlow(sample_flow_state, user_token_client) expected_flow_state = sample_flow_state expected_flow_state.user_token = "" expected_flow_state.tag = FlowStateTag.BEGIN @@ -248,7 +241,7 @@ async def test_begin_flow_long_case(self, mocker, sample_flow_state, user_token_ async def test_continue_flow_not_active(self, mocker, sample_inactive_flow_state, user_token_client): # setup activity = mocker.Mock() - flow = AuthFlow(sample_inactive_flow_state, user_token_client) + flow = OAuthFlow(sample_inactive_flow_state, user_token_client) expected_flow_state = sample_inactive_flow_state expected_flow_state.tag = FlowStateTag.FAILURE @@ -264,7 +257,7 @@ async def test_continue_flow_not_active(self, mocker, sample_inactive_flow_state async def helper_continue_flow_failure(self, active_flow_state, user_token_client, activity, flow_error_tag): # setup - flow = AuthFlow(active_flow_state, user_token_client) + flow = OAuthFlow(active_flow_state, user_token_client) expected_flow_state = active_flow_state expected_flow_state.tag = FlowStateTag.CONTINUE if active_flow_state.attempts_remaining > 1 else FlowStateTag.FAILURE expected_flow_state.attempts_remaining = active_flow_state.attempts_remaining - 1 @@ -281,10 +274,10 @@ async def helper_continue_flow_failure(self, active_flow_state, user_token_clien async def helper_continue_flow_success(self, active_flow_state, user_token_client, activity): # setup - flow = AuthFlow(active_flow_state, user_token_client) + flow = OAuthFlow(active_flow_state, user_token_client) expected_flow_state = active_flow_state expected_flow_state.tag = FlowStateTag.COMPLETE - expected_flow_state.user_token = TEST_DEFAULTS.RES_TOKEN + expected_flow_state.user_token = RES_TOKEN expected_flow_state.attempts_remaining = active_flow_state.attempts_remaining # test @@ -295,7 +288,7 @@ async def helper_continue_flow_success(self, active_flow_state, user_token_clien # verify assert flow_response.flow_state == flow_state assert expected_flow_state == flow_state - assert flow_response.token_response == TokenResponse(token=TEST_DEFAULTS.RES_TOKEN) + assert flow_response.token_response == TokenResponse(token=RES_TOKEN) assert flow_response.flow_error_tag == FlowErrorTag.NONE @pytest.mark.asyncio @@ -375,7 +368,7 @@ async def test_continue_flow_active_sign_in_token_exchange_error(self, mocker, s @pytest.mark.asyncio async def test_continue_flow_active_sign_in_token_exchange_success(self, mocker, sample_active_flow_state, user_token_client): token_exchange_request = {} - user_token_client.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse(token=TEST_DEFAULTS.RES_TOKEN)) + user_token_client.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse(token=RES_TOKEN)) activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/tokenExchange", value=token_exchange_request) await self.helper_continue_flow_success(sample_active_flow_state, user_token_client, activity) user_token_client.user_token.exchange_token.assert_called_once_with( @@ -389,14 +382,14 @@ async def test_continue_flow_active_sign_in_token_exchange_success(self, mocker, async def test_continue_flow_invalid_invoke_name(self, mocker, sample_active_flow_state, user_token_client): with pytest.raises(ValueError): activity = self.create_activity(mocker, ActivityTypes.invoke, name="other", value={}) - flow = AuthFlow(sample_active_flow_state, user_token_client) + flow = OAuthFlow(sample_active_flow_state, user_token_client) await flow.continue_flow(activity) @pytest.mark.asyncio async def test_continue_flow_invalid_activity_type(self, mocker, sample_active_flow_state, user_token_client): with pytest.raises(ValueError): activity = self.create_activity(mocker, ActivityTypes.command, name="other", value={}) - flow = AuthFlow(sample_active_flow_state, user_token_client) + flow = OAuthFlow(sample_active_flow_state, user_token_client) await flow.continue_flow(activity) # robrandao: TODO -> test begin_or_continue_flow \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py index 0291a413..67cbfa44 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py @@ -1,29 +1,26 @@ -import jwt - import pytest +import jwt + from microsoft.agents.activity import ( ActivityTypes, TokenResponse ) -from microsoft.agents.hosting.core.app.oauth import ( - Authorization, - FlowStorageClient, - FlowErrorTag, - FlowStateTag, -) from microsoft.agents.hosting.core import MemoryStorage from microsoft.agents.hosting.core.storage.storage_test_utils import StorageBaseline from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase from microsoft.agents.hosting.core.connector.user_token_client_base import UserTokenClientBase -from microsoft.agents.hosting.core.app.oauth.auth_flow import AuthFlow - -from .tools.oauth_test_env import ( - TEST_DEFAULTS, - STORAGE_INIT_DATA +from microsoft.agents.hosting.core.app.oauth import Authorization +from microsoft.agents.hosting.core.oauth import ( + OAuthFlow, + FlowStorageClient, + FlowErrorTag, + FlowStateTag ) +# test constants +from .tools.testing_oauth import * from .tools.testing_authorization import ( TestingConnectionManager, create_test_auth_handler diff --git a/libraries/microsoft-agents-hosting-core/tests/models_test.py b/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py similarity index 91% rename from libraries/microsoft-agents-hosting-core/tests/models_test.py rename to libraries/microsoft-agents-hosting-core/tests/test_flow_state.py index e9a03b5f..dabbcc69 100644 --- a/libraries/microsoft-agents-hosting-core/tests/models_test.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py @@ -2,11 +2,12 @@ import pytest -from microsoft.agents.hosting.core.app.oauth.models import FlowState, FlowStateTag +from microsoft.agents.hosting.core.oauth.flow_state import FlowState, FlowStateTag class TestFlowState: def test_refresh_to_failure_expired(self): + """Test that the flow state refreshes to failure when expired.""" flow_state = FlowState( tag=FlowStateTag.CONTINUE, attempts_remaining=1, @@ -16,6 +17,7 @@ def test_refresh_to_failure_expired(self): assert flow_state.tag == FlowStateTag.FAILURE def test_refresh_to_failure_max_attempts(self): + """Test that the flow state refreshes to failure when max attempts reached.""" flow_state = FlowState( tag=FlowStateTag.CONTINUE, attempts_remaining=0, @@ -24,6 +26,7 @@ def test_refresh_to_failure_max_attempts(self): assert flow_state.tag == FlowStateTag.FAILURE def test_refresh_unchanged_continue(self): + """Test that the flow state remains unchanged when refreshed with a valid CONTINUE state""" flow_state = FlowState( tag=FlowStateTag.CONTINUE, attempts_remaining=1, @@ -34,6 +37,7 @@ def test_refresh_unchanged_continue(self): assert flow_state.tag == prev_tag def test_refresh_unchanged_begin(self): + """Test that the flow state remains unchanged when refreshed with a valid BEGIN state""" flow_state = FlowState( tag=FlowStateTag.BEGIN, attempts_remaining=10, diff --git a/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py b/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py index 57a65258..f0171fb6 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py @@ -1,12 +1,8 @@ import pytest -from unittest.mock import sentinel from microsoft.agents.hosting.core.storage import MemoryStorage from microsoft.agents.hosting.core.storage.storage_test_utils import MockStoreItem -from microsoft.agents.hosting.core.app.oauth import ( - FlowState, - FlowStorageClient, -) +from microsoft.agents.hosting.core.oauth import FlowState, FlowStorageClient class TestFlowStorageClient: diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/oauth_test_env.py b/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py similarity index 56% rename from libraries/microsoft-agents-hosting-core/tests/tools/oauth_test_env.py rename to libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py index db40c96d..bf1aad03 100644 --- a/libraries/microsoft-agents-hosting-core/tests/tools/oauth_test_env.py +++ b/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py @@ -1,20 +1,21 @@ from datetime import datetime -from microsoft.agents.hosting.core.app.oauth.models import FlowState, FlowStateTag -class TEST_DEFAULTS: +from microsoft.agents.hosting.core.oauth.flow_state import FlowState, FlowStateTag - MS_APP_ID = "__ms_app_id" - CHANNEL_ID = "__channel_id" - USER_ID = "__user_id" - ABS_OAUTH_CONNECTION_NAME = "__connection_name" - RES_TOKEN = "__res_token" +MS_APP_ID = "__ms_app_id" +CHANNEL_ID = "__channel_id" +USER_ID = "__user_id" +ABS_OAUTH_CONNECTION_NAME = "__connection_name" +RES_TOKEN = "__res_token" - DEF_ARGS = { - "ms_app_id": MS_APP_ID, - "channel_id": CHANNEL_ID, - "user_id": USER_ID, - "abs_oauth_connection_name": ABS_OAUTH_CONNECTION_NAME - } +DEF_ARGS = { + "ms_app_id": MS_APP_ID, + "channel_id": CHANNEL_ID, + "user_id": USER_ID, + "abs_oauth_connection_name": ABS_OAUTH_CONNECTION_NAME +} + +class FLOW_STATES: STARTED_FLOW = FlowState( **DEF_ARGS, @@ -64,55 +65,55 @@ class TEST_DEFAULTS: attempts_remaining=0, expires_at=datetime.now().timestamp() + 1000000 ) - + FAIL_BY_EXP_FLOW = FlowState( **DEF_ARGS, tag=FlowStateTag.FAILURE, attempts_remaining=2, expires_at=0 ) - - @classmethod - def __format(cls, lst): + + @staticmethod + def clone_state_list(lst): return [ flow_state.model_copy() for flow_state in lst ] - - @classmethod - def ALL(cls): - return cls.__format([ - cls.STARTED_FLOW, - cls.STARTED_FLOW_ONE_RETRY, - cls.ACTIVE_FLOW, - cls.ACTIVE_FLOW_ONE_RETRY, - cls.ACTIVE_EXP_FLOW, - cls.COMPLETED_FLOW, - cls.FAIL_BY_ATTEMPTS_FLOW, - cls.FAIL_BY_EXP_FLOW + + @staticmethod + def ALL(): + return FLOW_STATES.clone_state_list([ + FLOW_STATES.STARTED_FLOW, + FLOW_STATES.STARTED_FLOW_ONE_RETRY, + FLOW_STATES.ACTIVE_FLOW, + FLOW_STATES.ACTIVE_FLOW_ONE_RETRY, + FLOW_STATES.ACTIVE_EXP_FLOW, + FLOW_STATES.COMPLETED_FLOW, + FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, + FLOW_STATES.FAIL_BY_EXP_FLOW ]) - - @classmethod - def FAILED(cls): - return cls.__format([ - cls.ACTIVE_EXP_FLOW, - cls.FAIL_BY_ATTEMPTS_FLOW, - cls.FAIL_BY_EXP_FLOW + + @staticmethod + def FAILED(): + return FLOW_STATES.clone_state_list([ + FLOW_STATES.ACTIVE_EXP_FLOW, + FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, + FLOW_STATES.FAIL_BY_EXP_FLOW ]) - - @classmethod - def ACTIVE(cls): - return cls.__format([ - cls.STARTED_FLOW, - cls.STARTED_FLOW_ONE_RETRY, - cls.ACTIVE_FLOW, - cls.ACTIVE_FLOW_ONE_RETRY, + + @staticmethod + def ACTIVE(): + return FLOW_STATES.clone_state_list([ + FLOW_STATES.STARTED_FLOW, + FLOW_STATES.STARTED_FLOW_ONE_RETRY, + FLOW_STATES.ACTIVE_FLOW, + FLOW_STATES.ACTIVE_FLOW_ONE_RETRY, ]) - - @classmethod - def INACTIVE(cls): - return cls.__format([ - cls.ACTIVE_EXP_FLOW, - cls.COMPLETED_FLOW, - cls.FAIL_BY_ATTEMPTS_FLOW, - cls.FAIL_BY_EXP_FLOW + + @staticmethod + def INACTIVE(): + return FLOW_STATES.clone_state_list([ + FLOW_STATES.ACTIVE_EXP_FLOW, + FLOW_STATES.COMPLETED_FLOW, + FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, + FLOW_STATES.FAIL_BY_EXP_FLOW ]) def flow_key(channel_id, user_id, handler_id): @@ -121,12 +122,12 @@ def flow_key(channel_id, user_id, handler_id): STORAGE_SAMPLE_DICT = { "user_id": "123", "session_id": "abc", - flow_key("webchat", "Alice", "graph"): TEST_DEFAULTS.COMPLETED_FLOW.model_copy(), - flow_key("webchat", "Alice", "github"): TEST_DEFAULTS.ACTIVE_FLOW_ONE_RETRY.model_copy(), - flow_key("teams", "Alice", "graph"): TEST_DEFAULTS.STARTED_FLOW.model_copy(), - flow_key("webchat", "Bob", "graph"): TEST_DEFAULTS.ACTIVE_EXP_FLOW.model_copy(), - flow_key("teams", "Bob", "slack"): TEST_DEFAULTS.STARTED_FLOW.model_copy(), - flow_key("webchat", "Chuck", "github"): TEST_DEFAULTS.FAIL_BY_ATTEMPTS_FLOW.model_copy(), + flow_key("webchat", "Alice", "graph"): FLOW_STATES.COMPLETED_FLOW.model_copy(), + flow_key("webchat", "Alice", "github"): FLOW_STATES.ACTIVE_FLOW_ONE_RETRY.model_copy(), + flow_key("teams", "Alice", "graph"): FLOW_STATES.STARTED_FLOW.model_copy(), + flow_key("webchat", "Bob", "graph"): FLOW_STATES.ACTIVE_EXP_FLOW.model_copy(), + flow_key("teams", "Bob", "slack"): FLOW_STATES.STARTED_FLOW.model_copy(), + flow_key("webchat", "Chuck", "github"): FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW.model_copy(), } def STORAGE_INIT_DATA(): From ad5b91fea44801e3b82070cd8e86ab190c285813 Mon Sep 17 00:00:00 2001 From: Rodrigo Brandao Date: Thu, 21 Aug 2025 12:54:02 -0700 Subject: [PATCH 15/32] Passing all Authorization unit tests --- .../agents/hosting/core/app/__init__.py | 2 +- .../hosting/core/app/agent_application.py | 1738 ++++++++--------- .../agents/hosting/core/app/app_options.py | 4 +- .../app/auth/tests/__authorization_test.py | 340 ---- .../hosting/core/app/auth/tests/conftest.py | 0 .../core/app/{auth => oauth}/__init__.py | 0 .../core/app/{auth => oauth}/auth_handler.py | 92 +- .../core/app/{auth => oauth}/authorization.py | 832 ++++---- .../agents/hosting/core/oauth/__init__.py | 5 +- .../agents/hosting/core/oauth/flow_state.py | 10 +- .../hosting/core/oauth/flow_storage_client.py | 20 +- .../agents/hosting/core/oauth/oauth_flow.py | 15 +- .../tests/test_authorization.py | 1113 ++++++----- .../tests/test_flow_state.py | 177 +- .../tests/test_flow_storage_client.py | 301 +-- .../{test_auth_flow.py => test_oauth_flow.py} | 788 ++++---- .../tests/tools/mock_user_token_client.py | 178 +- .../tests/tools/testing_authorization.py | 496 ++--- .../tests/tools/testing_oauth.py | 290 +-- 19 files changed, 3013 insertions(+), 3388 deletions(-) delete mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/tests/__authorization_test.py delete mode 100644 libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/tests/conftest.py rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/{auth => oauth}/__init__.py (100%) rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/{auth => oauth}/auth_handler.py (97%) rename libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/{auth => oauth}/authorization.py (90%) rename libraries/microsoft-agents-hosting-core/tests/{test_auth_flow.py => test_oauth_flow.py} (95%) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/__init__.py index 67d1df98..4089c3fb 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/__init__.py @@ -14,7 +14,7 @@ from .typing_indicator import TypingIndicator # Auth -from .auth import ( +from .oauth import ( Authorization, AuthHandler, AuthorizationHandlers, diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py index 13d3be94..39bed333 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py @@ -1,869 +1,869 @@ -""" -Copyright (c) Microsoft Corporation. All rights reserved. -Licensed under the MIT License. -""" - -from __future__ import annotations -import logging -from copy import copy -from functools import partial - -import re -from typing import ( - Any, - Awaitable, - Callable, - Dict, - Generic, - List, - Optional, - Pattern, - TypeVar, - Union, - cast, -) - -from microsoft.agents.hosting.core.authorization import Connections - -from microsoft.agents.hosting.core import Agent, TurnContext -from microsoft.agents.activity import ( - Activity, - ActivityTypes, - ConversationUpdateTypes, - MessageReactionTypes, - MessageUpdateTypes, - InvokeResponse, - OAuthCard, - Attachment, - CardAction -) - -from .. import CardFactory, MessageFactory -from .app_error import ApplicationError -from .app_options import ApplicationOptions - -# from .auth import AuthManager, OAuth, OAuthOptions -from .route import Route, RouteHandler -from .state import TurnState -from ..channel_service_adapter import ChannelServiceAdapter -from ..oauth import ( - FlowResponse, - FlowState, - FlowStateTag, -) -from .auth import Authorization -from .typing_indicator import TypingIndicator - -logger = logging.getLogger(__name__) - -StateT = TypeVar("StateT", bound=TurnState) -IN_SIGN_IN_KEY = "__InSignInFlow__" - - -class AgentApplication(Agent, Generic[StateT]): - """ - AgentApplication class for routing and processing incoming requests. - - The AgentApplication object replaces the traditional ActivityHandler that - a bot would use. It supports a simpler fluent style of authoring bots - versus the inheritance based approach used by the ActivityHandler class. - - Additionally, it has built-in support for calling into the SDK's AI system - and can be used to create bots that leverage Large Language Models (LLM) - and other AI capabilities. - """ - - typing: TypingIndicator - - _options: ApplicationOptions - _adapter: Optional[ChannelServiceAdapter] = None - _auth: Optional[Authorization] = None - _internal_before_turn: List[Callable[[TurnContext, StateT], Awaitable[bool]]] = [] - _internal_after_turn: List[Callable[[TurnContext, StateT], Awaitable[bool]]] = [] - _routes: List[Route[StateT]] = [] - _error: Optional[Callable[[TurnContext, Exception], Awaitable[None]]] = None - _turn_state_factory: Optional[Callable[[TurnContext], StateT]] = None - - def __init__( - self, - options: ApplicationOptions = None, - *, - connection_manager: Connections = None, - authorization: Authorization = None, - **kwargs, - ) -> None: - """ - Creates a new AgentApplication instance. - """ - self.typing = TypingIndicator() - self._routes = [] - - configuration = kwargs - - logger.debug(f"Initializing AgentApplication with options: {options}") - logger.debug( - f"Initializing AgentApplication with configuration: {configuration}" - ) - - if not options: - # TODO: consolidate configuration story - # Take the options from the kwargs and create an ApplicationOptions instance - option_kwargs = dict( - filter( - lambda x: x[0] in ApplicationOptions.__dataclass_fields__, - kwargs.items(), - ) - ) - options = ApplicationOptions(**option_kwargs) - - self._options = options - - if not self._options.storage: - logger.error( - "ApplicationOptions.storage is required and was not configured.", - stack_info=True, - ) - raise ApplicationError( - """ - The `ApplicationOptions.storage` property is required and was not configured. - """ - ) - - if options.long_running_messages and ( - not options.adapter or not options.bot_app_id - ): - logger.error( - "ApplicationOptions.long_running_messages requires an adapter and bot_app_id.", - stack_info=True, - ) - raise ApplicationError( - """ - The `ApplicationOptions.long_running_messages` property is unavailable because - no adapter or `bot_app_id` was configured. - """ - ) - - if options.adapter: - self._adapter = options.adapter - - self._turn_state_factory = ( - options.turn_state_factory - or kwargs.get("turn_state_factory", None) - or partial(TurnState.with_storage, self._options.storage) - ) - - # TODO: decide how to initialize the Authorization (params vs options vs kwargs) - if authorization: - self._auth = authorization - else: - auth_options = { - key: value - for key, value in configuration.items() - if key not in ["storage", "connection_manager", "handlers"] - } - self._auth = Authorization( - storage=self._options.storage, - connection_manager=connection_manager, - handlers=options.authorization_handlers, - **auth_options, - ) - - @property - def adapter(self) -> ChannelServiceAdapter: - """ - The bot's adapter. - """ - - if not self._adapter: - logger.error( - "AgentApplication.adapter(): self._adapter is not configured.", - stack_info=True, - ) - raise ApplicationError( - """ - The AgentApplication.adapter property is unavailable because it was - not configured when creating the AgentApplication. - """ - ) - - return self._adapter - - @property - def auth(self): - """ - The application's authentication manager - """ - if not self._auth: - logger.error( - "AgentApplication.auth(): self._auth is not configured.", - stack_info=True, - ) - raise ApplicationError( - """ - The `AgentApplication.auth` property is unavailable because - no Auth options were configured. - """ - ) - - return self._auth - - @property - def options(self) -> ApplicationOptions: - """ - The application's configured options. - """ - return self._options - - def activity( - self, - activity_type: Union[str, ActivityTypes, List[Union[str, ActivityTypes]]], - *, - auth_handlers: Optional[List[str]] = None, - ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: - """ - Registers a new activity event listener. This method can be used as either - a decorator or a method. - - ```python - # Use this method as a decorator - @app.activity("event") - async def on_event(context: TurnContext, state: TurnState): - print("hello world!") - return True - ``` - - #### Args: - - `type`: The type of the activity - """ - - def __selector(context: TurnContext): - return activity_type == context.activity.type - - def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: - logger.debug( - f"Registering activity handler for route handler {func.__name__} with type: {activity_type} with auth handlers: {auth_handlers}" - ) - self._routes.append( - Route[StateT](__selector, func, auth_handlers=auth_handlers) - ) - return func - - return __call - - def message( - self, - select: Union[str, Pattern[str], List[Union[str, Pattern[str]]]], - *, - auth_handlers: Optional[List[str]] = None, - ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: - """ - Registers a new message activity event listener. This method can be used as either - a decorator or a method. - - ```python - # Use this method as a decorator - @app.message("hi") - async def on_hi_message(context: TurnContext, state: TurnState): - print("hello!") - return True - - #### Args: - - `select`: a string or regex pattern - """ - - def __selector(context: TurnContext): - if context.activity.type != ActivityTypes.message: - return False - - text = context.activity.text if context.activity.text else "" - if isinstance(select, Pattern): - hits = re.fullmatch(select, text) - return hits is not None - - return text == select - - def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: - logger.debug( - f"Registering message handler for route handler {func.__name__} with select: {select} with auth handlers: {auth_handlers}" - ) - self._routes.append( - Route[StateT](__selector, func, auth_handlers=auth_handlers) - ) - return func - - return __call - - def conversation_update( - self, - type: ConversationUpdateTypes, - *, - auth_handlers: Optional[List[str]] = None, - ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: - """ - Registers a new message activity event listener. This method can be used as either - a decorator or a method. - - ```python - # Use this method as a decorator - @app.conversation_update("channelCreated") - async def on_channel_created(context: TurnContext, state: TurnState): - print("a new channel was created!") - return True - - ``` - - #### Args: - - `type`: a string or regex pattern - """ - - def __selector(context: TurnContext): - if context.activity.type != ActivityTypes.conversation_update: - return False - - if type == "membersAdded": - if isinstance(context.activity.members_added, List): - return len(context.activity.members_added) > 0 - return False - - if type == "membersRemoved": - if isinstance(context.activity.members_removed, List): - return len(context.activity.members_removed) > 0 - return False - - if isinstance(context.activity.channel_data, object): - data = vars(context.activity.channel_data) - return data["event_type"] == type - - return False - - def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: - logger.debug( - f"Registering conversation update handler for route handler {func.__name__} with type: {type} with auth handlers: {auth_handlers}" - ) - self._routes.append( - Route[StateT](__selector, func, auth_handlers=auth_handlers) - ) - return func - - return __call - - def message_reaction( - self, type: MessageReactionTypes, *, auth_handlers: Optional[List[str]] = None - ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: - """ - Registers a new message activity event listener. This method can be used as either - a decorator or a method. - - ```python - # Use this method as a decorator - @app.message_reaction("reactionsAdded") - async def on_reactions_added(context: TurnContext, state: TurnState): - print("reactions was added!") - return True - ``` - - #### Args: - - `type`: a string or regex pattern - """ - - def __selector(context: TurnContext): - if context.activity.type != ActivityTypes.message_reaction: - return False - - if type == "reactionsAdded": - if isinstance(context.activity.reactions_added, List): - return len(context.activity.reactions_added) > 0 - return False - - if type == "reactionsRemoved": - if isinstance(context.activity.reactions_removed, List): - return len(context.activity.reactions_removed) > 0 - return False - - return False - - def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: - logger.debug( - f"Registering message reaction handler for route handler {func.__name__} with type: {type} with auth handlers: {auth_handlers}" - ) - self._routes.append( - Route[StateT](__selector, func, auth_handlers=auth_handlers) - ) - return func - - return __call - - def message_update( - self, type: MessageUpdateTypes, *, auth_handlers: Optional[List[str]] = None - ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: - """ - Registers a new message activity event listener. This method can be used as either - a decorator or a method. - - ```python - # Use this method as a decorator - @app.message_update("editMessage") - async def on_edit_message(context: TurnContext, state: TurnState): - print("message was edited!") - return True - ``` - - #### Args: - - `type`: a string or regex pattern - """ - - def __selector(context: TurnContext): - if type == "editMessage": - if ( - context.activity.type == ActivityTypes.message_update - and isinstance(context.activity.channel_data, dict) - ): - data = context.activity.channel_data - return data["event_type"] == type - return False - - if type == "softDeleteMessage": - if ( - context.activity.type == ActivityTypes.message_delete - and isinstance(context.activity.channel_data, dict) - ): - data = context.activity.channel_data - return data["event_type"] == type - return False - - if type == "undeleteMessage": - if ( - context.activity.type == ActivityTypes.message_update - and isinstance(context.activity.channel_data, dict) - ): - data = context.activity.channel_data - return data["event_type"] == type - return False - return False - - def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: - logger.debug( - f"Registering message update handler for route handler {func.__name__} with type: {type} with auth handlers: {auth_handlers}" - ) - self._routes.append( - Route[StateT](__selector, func, auth_handlers=auth_handlers) - ) - return func - - return __call - - def handoff(self, *, auth_handlers: Optional[List[str]] = None) -> Callable[ - [Callable[[TurnContext, StateT, str], Awaitable[None]]], - Callable[[TurnContext, StateT, str], Awaitable[None]], - ]: - """ - Registers a handler to handoff conversations from one copilot to another. - ```python - # Use this method as a decorator - @app.handoff - async def on_handoff( - context: TurnContext, state: TurnState, continuation: str - ): - print(query) - ``` - """ - - def __selector(context: TurnContext) -> bool: - return ( - context.activity.type == ActivityTypes.invoke - and context.activity.name == "handoff/action" - ) - - def __call( - func: Callable[[TurnContext, StateT, str], Awaitable[None]], - ) -> Callable[[TurnContext, StateT, str], Awaitable[None]]: - async def __handler(context: TurnContext, state: StateT): - if not context.activity.value: - return False - await func(context, state, context.activity.value["continuation"]) - await context.send_activity( - Activity( - type=ActivityTypes.invoke_response, - value=InvokeResponse(status=200), - ) - ) - return True - - logger.debug( - f"Registering handoff handler for route handler {func.__name__} with auth handlers: {auth_handlers}" - ) - - self._routes.append( - Route[StateT](__selector, __handler, True, auth_handlers) - ) - self._routes = sorted(self._routes, key=lambda route: not route.is_invoke) - return func - - return __call - - def on_sign_in_success( - self, func: Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]] - ) -> Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]]: - """ - Registers a new event listener that will be executed when a user successfully signs in. - - ```python - # Use this method as a decorator - @app.on_sign_in_success - async def sign_in_success(context: TurnContext, state: TurnState): - print("hello world!") - return True - ``` - """ - - if self._auth: - logger.debug( - f"Registering sign-in success handler for route handler {func.__name__}" - ) - self._auth.on_sign_in_success(func) - else: - logger.error( - f"Failed to register sign-in success handler for route handler {func.__name__}", - stack_info=True, - ) - raise ApplicationError( - """ - The `AgentApplication.on_sign_in_success` method is unavailable because - no Auth options were configured. - """ - ) - return func - - def on_sign_in_failure( - self, func: Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]] - ) -> Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]]: - """ - Registers a new event listener that will be executed when a user fails to sign in. - - ```python - # Use this method as a decorator - @app.on_sign_in_failure - async def sign_in_failure(context: TurnContext, state: TurnState): - print("hello world!") - return True - ``` - """ - - if self._auth: - logger.debug( - f"Registering sign-in failure handler for route handler {func.__name__}" - ) - self._auth.on_sign_in_failure(func) - else: - logger.error( - f"Failed to register sign-in failure handler for route handler {func.__name__}", - stack_info=True, - ) - raise ApplicationError( - """ - The `AgentApplication.on_sign_in_failure` method is unavailable because - no Auth options were configured. - """ - ) - return func - - def error( - self, func: Callable[[TurnContext, Exception], Awaitable[None]] - ) -> Callable[[TurnContext, Exception], Awaitable[None]]: - """ - Registers an error handler that will be called anytime - the app throws an Exception - - ```python - # Use this method as a decorator - @app.error - async def on_error(context: TurnContext, err: Exception): - print(err.message) - ``` - """ - - logger.debug(f"Registering the error handler {func.__name__} ") - self._error = func - - if self._adapter: - logger.debug( - f"Registering for adapter {self._adapter.__class__.__name__} the error handler {func.__name__} " - ) - self._adapter.on_turn_error = func - - return func - - def turn_state_factory(self, func: Callable[[TurnContext], Awaitable[StateT]]): - """ - Custom Turn State Factory - """ - logger.debug(f"Setting custom turn state factory: {func.__name__}") - self._turn_state_factory = func - return func - - async def _handle_flow_response(self, context: TurnContext, flow_response: FlowResponse) -> None: - - flow_state: FlowState = flow_response.flow_state - in_flow_activity = flow_response.in_flow_activity - - if in_flow_activity: - await context.send_activity(in_flow_activity) - - if flow_state.tag == FlowStateTag.BEGIN: - # Create the OAuth card - o_card: Attachment = CardFactory.oauth_card( - OAuthCard( - text=self.messages_configuration.get("card_title", "Sign in"), - connection_name=self.abs_oauth_connection_name, - buttons=[ - CardAction( - title=self.messages_configuration.get("button_text", "Sign in"), - type=ActionTypes.signin, - value=signing_resource.sign_in_link, - channel_data=None, - ) - ], - token_exchange_resource=signing_resource.token_exchange_resource, - token_post_resource=signing_resource.token_post_resource, - ) - ) - - # Send the card to the user - await context.send_activity(MessageFactory.attachment(o_card)) - elif flow_state.tag == FlowStateTag.FAILURE: - if flow_state.reached_max_retries(): - await context.send_activity( - MessageFactory.text( - self.messages_configuration.get( - "max_retries_reached_messages", - "Sign-in failed. Please try again later.", - ) - ) - ) - elif flow_state.is_expired(): - await context.send_activity( - MessageFactory.text( - self.messages_configuration.get( - "session_expired_messages", - "Sign-in session expired. Please try again.", - ) - )) - else: - logger.warning("Sign-in flow failed for unknown reasons.") - await context.send_activity("Sign-in failed. Please try again.") - - async def _on_turn_auth_intercept(self, context: TurnContext, turn_state: TurnState) -> bool: - - prev_flow_state = await self._auth.get_active_flow_state(context) - if self._auth and prev_flow_state: - - logger.debug("Sign-in flow is active for context: %s", context.activity.id) - - flow_response: FlowResponse = await self._auth.begin_or_continue_flow( - context, turn_state, prev_flow_state.handler_id - ) - - await self._handle_flow_response(context, flow_response) - - new_flow_state: FlowState = flow_response.flow_state - token_response: TokenResponse = new_flow_state.token_response - saved_activity: Activity = new_flow_state.continuation_activity.model_copy() - - if token_response and token_response.token: - new_context = copy(context) - new_context.activity = saved_activity - logger.info( - "Resending continuation activity %s", saved_activity.text - ) - await self.on_turn(new_context) - turn_state.delete_value(Authorization.SIGN_IN_STATE_KEY) - await turn_state.save(context) - return True - - return False - - async def on_turn(self, context: TurnContext): - logger.debug( - f"AgentApplication.on_turn(): Processing turn for context: {context.activity.id}" - ) - await self._start_long_running_call(context, self._on_turn) - - async def _on_turn(self, context: TurnContext): - # robrandao: TODO - try: - await self._start_typing(context) - - self._remove_mentions(context) - - logger.debug("Initializing turn state") - turn_state = await self._initialize_state(context) - - if await self._on_turn_auth_intercept(context, turn_state): - return - - logger.debug("Running before turn middleware") - if not await self._run_before_turn_middleware(context, turn_state): - return - - logger.debug("Running file downloads") - await self._handle_file_downloads(context, turn_state) - - logger.debug("Running activity handlers") - await self._on_activity(context, turn_state) - - logger.debug("Running after turn middleware") - if not await self._run_after_turn_middleware(context, turn_state): - await turn_state.save(context) - return - except ApplicationError as err: - logger.error( - f"An application error occurred in the AgentApplication: {err}", - exc_info=True, - ) - await self._on_error(context, err) - finally: - logger.debug("Stopping typing indicator") - self.typing.stop() - - async def _start_typing(self, context: TurnContext): - if self._options.start_typing_timer: - logger.debug("Starting typing indicator for context") - await self.typing.start(context) - - def _remove_mentions(self, context: TurnContext): - if ( - self.options.remove_recipient_mention - and context.activity.type == ActivityTypes.message - ): - context.activity.text = context.remove_recipient_mention(context.activity) - - @staticmethod - def parse_env_vars_configuration(vars: Dict[str, Any]) -> dict: - """ - Parses environment variables and returns a dictionary with the relevant configuration. - """ - result = {} - for key, value in vars.items(): - levels = key.split("__") - current_level = result - last_level = None - for next_level in levels: - if next_level not in current_level: - current_level[next_level] = {} - last_level = current_level - current_level = current_level[next_level] - logger.debug(f"Using environment variable '{key}'") - last_level[levels[-1]] = value - - return { - "AGENT_APPLICATION": result["AGENT_APPLICATION"], - "COPILOT_STUDIO_AGENT": result["COPILOT_STUDIO_AGENT"], - "CONNECTIONS": result["CONNECTIONS"], - "CONNECTIONS_MAP": result["CONNECTIONS_MAP"], - } - - async def _initialize_state(self, context: TurnContext) -> StateT: - if self._turn_state_factory: - logger.debug("Using custom turn state factory") - turn_state = self._turn_state_factory() - else: - logger.debug("Using default turn state factory") - turn_state = TurnState.with_storage(self._options.storage) - await turn_state.load(context, self._options.storage) - - turn_state = cast(StateT, turn_state) - - logger.debug("Loading turn state from storage") - await turn_state.load(context, self._options.storage) - turn_state.temp.input = context.activity.text - return turn_state - - async def _run_before_turn_middleware(self, context: TurnContext, state: StateT): - for before_turn in self._internal_before_turn: - is_ok = await before_turn(context, state) - if not is_ok: - await state.save(context, self._options.storage) - return False - return True - - async def _handle_file_downloads(self, context: TurnContext, state: StateT): - if self._options.file_downloaders and len(self._options.file_downloaders) > 0: - input_files = state.temp.input_files if state.temp.input_files else [] - for file_downloader in self._options.file_downloaders: - logger.info( - f"Using file downloader: {file_downloader.__class__.__name__}" - ) - files = await file_downloader.download_files(context) - input_files.extend(files) - state.temp.input_files = input_files - - def _contains_non_text_attachments(self, context: TurnContext): - non_text_attachments = filter( - lambda a: not a.content_type.startswith("text/html"), - context.activity.attachments, - ) - return len(list(non_text_attachments)) > 0 - - async def _run_after_turn_middleware(self, context: TurnContext, state: StateT): - for after_turn in self._internal_after_turn: - is_ok = await after_turn(context, state) - if not is_ok: - await state.save(context, self._options.storage) - return False - return True - - async def _on_activity(self, context: TurnContext, state: StateT): - for route in self._routes: - if route.selector(context): - if not route.auth_handlers: - await route.handler(context, state) - else: - sign_in_complete = False - for auth_handler_id in route.auth_handlers: - flow_response: FlowResponse = await self._auth.begin_or_continue_flow( - context, state, auth_handler_id - ) - await self._handle_flow_response(context, flow_response.in_flow_activity) - sign_in_complete = flow_response.flow_state.tag == FlowStateTag.COMPLETE - if not sign_in_complete: - break - - if sign_in_complete: - await route.handler(context, state) - return - logger.warning( - f"No route found for activity type: {context.activity.type} with text: {context.activity.text}" - ) - - async def _start_long_running_call( - self, context: TurnContext, func: Callable[[TurnContext], Awaitable] - ): - if ( - self._adapter - and ActivityTypes.message == context.activity.type - and self._options.long_running_messages - ): - logger.debug( - f"Starting long running call for context: {context.activity.id} with function: {func.__name__}" - ) - return await self._adapter.continue_conversation( - reference=context.get_conversation_reference(context.activity), - callback=func, - bot_app_id=self.options.bot_app_id, - ) - - return await func(context) - - async def _on_error(self, context: TurnContext, err: ApplicationError) -> None: - if self._error: - logger.info( - f"Calling error handler {self._error.__name__} for error: {err}" - ) - return await self._error(context, err) - - logger.error( - f"An error occurred in the AgentApplication: {err}", - exc_info=True, - ) - logger.error(err) - raise err +""" +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. +""" + +from __future__ import annotations +import logging +from copy import copy +from functools import partial + +import re +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Generic, + List, + Optional, + Pattern, + TypeVar, + Union, + cast, +) + +from microsoft.agents.hosting.core.authorization import Connections + +from microsoft.agents.hosting.core import Agent, TurnContext +from microsoft.agents.activity import ( + Activity, + ActivityTypes, + ConversationUpdateTypes, + MessageReactionTypes, + MessageUpdateTypes, + InvokeResponse, + OAuthCard, + Attachment, + CardAction +) + +from .. import CardFactory, MessageFactory +from .app_error import ApplicationError +from .app_options import ApplicationOptions + +# from .auth import AuthManager, OAuth, OAuthOptions +from .route import Route, RouteHandler +from .state import TurnState +from ..channel_service_adapter import ChannelServiceAdapter +from ..oauth import ( + FlowResponse, + FlowState, + FlowStateTag, +) +from .oauth import Authorization +from .typing_indicator import TypingIndicator + +logger = logging.getLogger(__name__) + +StateT = TypeVar("StateT", bound=TurnState) +IN_SIGN_IN_KEY = "__InSignInFlow__" + + +class AgentApplication(Agent, Generic[StateT]): + """ + AgentApplication class for routing and processing incoming requests. + + The AgentApplication object replaces the traditional ActivityHandler that + a bot would use. It supports a simpler fluent style of authoring bots + versus the inheritance based approach used by the ActivityHandler class. + + Additionally, it has built-in support for calling into the SDK's AI system + and can be used to create bots that leverage Large Language Models (LLM) + and other AI capabilities. + """ + + typing: TypingIndicator + + _options: ApplicationOptions + _adapter: Optional[ChannelServiceAdapter] = None + _auth: Optional[Authorization] = None + _internal_before_turn: List[Callable[[TurnContext, StateT], Awaitable[bool]]] = [] + _internal_after_turn: List[Callable[[TurnContext, StateT], Awaitable[bool]]] = [] + _routes: List[Route[StateT]] = [] + _error: Optional[Callable[[TurnContext, Exception], Awaitable[None]]] = None + _turn_state_factory: Optional[Callable[[TurnContext], StateT]] = None + + def __init__( + self, + options: ApplicationOptions = None, + *, + connection_manager: Connections = None, + authorization: Authorization = None, + **kwargs, + ) -> None: + """ + Creates a new AgentApplication instance. + """ + self.typing = TypingIndicator() + self._routes = [] + + configuration = kwargs + + logger.debug(f"Initializing AgentApplication with options: {options}") + logger.debug( + f"Initializing AgentApplication with configuration: {configuration}" + ) + + if not options: + # TODO: consolidate configuration story + # Take the options from the kwargs and create an ApplicationOptions instance + option_kwargs = dict( + filter( + lambda x: x[0] in ApplicationOptions.__dataclass_fields__, + kwargs.items(), + ) + ) + options = ApplicationOptions(**option_kwargs) + + self._options = options + + if not self._options.storage: + logger.error( + "ApplicationOptions.storage is required and was not configured.", + stack_info=True, + ) + raise ApplicationError( + """ + The `ApplicationOptions.storage` property is required and was not configured. + """ + ) + + if options.long_running_messages and ( + not options.adapter or not options.bot_app_id + ): + logger.error( + "ApplicationOptions.long_running_messages requires an adapter and bot_app_id.", + stack_info=True, + ) + raise ApplicationError( + """ + The `ApplicationOptions.long_running_messages` property is unavailable because + no adapter or `bot_app_id` was configured. + """ + ) + + if options.adapter: + self._adapter = options.adapter + + self._turn_state_factory = ( + options.turn_state_factory + or kwargs.get("turn_state_factory", None) + or partial(TurnState.with_storage, self._options.storage) + ) + + # TODO: decide how to initialize the Authorization (params vs options vs kwargs) + if authorization: + self._auth = authorization + else: + auth_options = { + key: value + for key, value in configuration.items() + if key not in ["storage", "connection_manager", "handlers"] + } + self._auth = Authorization( + storage=self._options.storage, + connection_manager=connection_manager, + handlers=options.authorization_handlers, + **auth_options, + ) + + @property + def adapter(self) -> ChannelServiceAdapter: + """ + The bot's adapter. + """ + + if not self._adapter: + logger.error( + "AgentApplication.adapter(): self._adapter is not configured.", + stack_info=True, + ) + raise ApplicationError( + """ + The AgentApplication.adapter property is unavailable because it was + not configured when creating the AgentApplication. + """ + ) + + return self._adapter + + @property + def auth(self): + """ + The application's authentication manager + """ + if not self._auth: + logger.error( + "AgentApplication.auth(): self._auth is not configured.", + stack_info=True, + ) + raise ApplicationError( + """ + The `AgentApplication.auth` property is unavailable because + no Auth options were configured. + """ + ) + + return self._auth + + @property + def options(self) -> ApplicationOptions: + """ + The application's configured options. + """ + return self._options + + def activity( + self, + activity_type: Union[str, ActivityTypes, List[Union[str, ActivityTypes]]], + *, + auth_handlers: Optional[List[str]] = None, + ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: + """ + Registers a new activity event listener. This method can be used as either + a decorator or a method. + + ```python + # Use this method as a decorator + @app.activity("event") + async def on_event(context: TurnContext, state: TurnState): + print("hello world!") + return True + ``` + + #### Args: + - `type`: The type of the activity + """ + + def __selector(context: TurnContext): + return activity_type == context.activity.type + + def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: + logger.debug( + f"Registering activity handler for route handler {func.__name__} with type: {activity_type} with auth handlers: {auth_handlers}" + ) + self._routes.append( + Route[StateT](__selector, func, auth_handlers=auth_handlers) + ) + return func + + return __call + + def message( + self, + select: Union[str, Pattern[str], List[Union[str, Pattern[str]]]], + *, + auth_handlers: Optional[List[str]] = None, + ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: + """ + Registers a new message activity event listener. This method can be used as either + a decorator or a method. + + ```python + # Use this method as a decorator + @app.message("hi") + async def on_hi_message(context: TurnContext, state: TurnState): + print("hello!") + return True + + #### Args: + - `select`: a string or regex pattern + """ + + def __selector(context: TurnContext): + if context.activity.type != ActivityTypes.message: + return False + + text = context.activity.text if context.activity.text else "" + if isinstance(select, Pattern): + hits = re.fullmatch(select, text) + return hits is not None + + return text == select + + def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: + logger.debug( + f"Registering message handler for route handler {func.__name__} with select: {select} with auth handlers: {auth_handlers}" + ) + self._routes.append( + Route[StateT](__selector, func, auth_handlers=auth_handlers) + ) + return func + + return __call + + def conversation_update( + self, + type: ConversationUpdateTypes, + *, + auth_handlers: Optional[List[str]] = None, + ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: + """ + Registers a new message activity event listener. This method can be used as either + a decorator or a method. + + ```python + # Use this method as a decorator + @app.conversation_update("channelCreated") + async def on_channel_created(context: TurnContext, state: TurnState): + print("a new channel was created!") + return True + + ``` + + #### Args: + - `type`: a string or regex pattern + """ + + def __selector(context: TurnContext): + if context.activity.type != ActivityTypes.conversation_update: + return False + + if type == "membersAdded": + if isinstance(context.activity.members_added, List): + return len(context.activity.members_added) > 0 + return False + + if type == "membersRemoved": + if isinstance(context.activity.members_removed, List): + return len(context.activity.members_removed) > 0 + return False + + if isinstance(context.activity.channel_data, object): + data = vars(context.activity.channel_data) + return data["event_type"] == type + + return False + + def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: + logger.debug( + f"Registering conversation update handler for route handler {func.__name__} with type: {type} with auth handlers: {auth_handlers}" + ) + self._routes.append( + Route[StateT](__selector, func, auth_handlers=auth_handlers) + ) + return func + + return __call + + def message_reaction( + self, type: MessageReactionTypes, *, auth_handlers: Optional[List[str]] = None + ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: + """ + Registers a new message activity event listener. This method can be used as either + a decorator or a method. + + ```python + # Use this method as a decorator + @app.message_reaction("reactionsAdded") + async def on_reactions_added(context: TurnContext, state: TurnState): + print("reactions was added!") + return True + ``` + + #### Args: + - `type`: a string or regex pattern + """ + + def __selector(context: TurnContext): + if context.activity.type != ActivityTypes.message_reaction: + return False + + if type == "reactionsAdded": + if isinstance(context.activity.reactions_added, List): + return len(context.activity.reactions_added) > 0 + return False + + if type == "reactionsRemoved": + if isinstance(context.activity.reactions_removed, List): + return len(context.activity.reactions_removed) > 0 + return False + + return False + + def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: + logger.debug( + f"Registering message reaction handler for route handler {func.__name__} with type: {type} with auth handlers: {auth_handlers}" + ) + self._routes.append( + Route[StateT](__selector, func, auth_handlers=auth_handlers) + ) + return func + + return __call + + def message_update( + self, type: MessageUpdateTypes, *, auth_handlers: Optional[List[str]] = None + ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: + """ + Registers a new message activity event listener. This method can be used as either + a decorator or a method. + + ```python + # Use this method as a decorator + @app.message_update("editMessage") + async def on_edit_message(context: TurnContext, state: TurnState): + print("message was edited!") + return True + ``` + + #### Args: + - `type`: a string or regex pattern + """ + + def __selector(context: TurnContext): + if type == "editMessage": + if ( + context.activity.type == ActivityTypes.message_update + and isinstance(context.activity.channel_data, dict) + ): + data = context.activity.channel_data + return data["event_type"] == type + return False + + if type == "softDeleteMessage": + if ( + context.activity.type == ActivityTypes.message_delete + and isinstance(context.activity.channel_data, dict) + ): + data = context.activity.channel_data + return data["event_type"] == type + return False + + if type == "undeleteMessage": + if ( + context.activity.type == ActivityTypes.message_update + and isinstance(context.activity.channel_data, dict) + ): + data = context.activity.channel_data + return data["event_type"] == type + return False + return False + + def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: + logger.debug( + f"Registering message update handler for route handler {func.__name__} with type: {type} with auth handlers: {auth_handlers}" + ) + self._routes.append( + Route[StateT](__selector, func, auth_handlers=auth_handlers) + ) + return func + + return __call + + def handoff(self, *, auth_handlers: Optional[List[str]] = None) -> Callable[ + [Callable[[TurnContext, StateT, str], Awaitable[None]]], + Callable[[TurnContext, StateT, str], Awaitable[None]], + ]: + """ + Registers a handler to handoff conversations from one copilot to another. + ```python + # Use this method as a decorator + @app.handoff + async def on_handoff( + context: TurnContext, state: TurnState, continuation: str + ): + print(query) + ``` + """ + + def __selector(context: TurnContext) -> bool: + return ( + context.activity.type == ActivityTypes.invoke + and context.activity.name == "handoff/action" + ) + + def __call( + func: Callable[[TurnContext, StateT, str], Awaitable[None]], + ) -> Callable[[TurnContext, StateT, str], Awaitable[None]]: + async def __handler(context: TurnContext, state: StateT): + if not context.activity.value: + return False + await func(context, state, context.activity.value["continuation"]) + await context.send_activity( + Activity( + type=ActivityTypes.invoke_response, + value=InvokeResponse(status=200), + ) + ) + return True + + logger.debug( + f"Registering handoff handler for route handler {func.__name__} with auth handlers: {auth_handlers}" + ) + + self._routes.append( + Route[StateT](__selector, __handler, True, auth_handlers) + ) + self._routes = sorted(self._routes, key=lambda route: not route.is_invoke) + return func + + return __call + + def on_sign_in_success( + self, func: Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]] + ) -> Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]]: + """ + Registers a new event listener that will be executed when a user successfully signs in. + + ```python + # Use this method as a decorator + @app.on_sign_in_success + async def sign_in_success(context: TurnContext, state: TurnState): + print("hello world!") + return True + ``` + """ + + if self._auth: + logger.debug( + f"Registering sign-in success handler for route handler {func.__name__}" + ) + self._auth.on_sign_in_success(func) + else: + logger.error( + f"Failed to register sign-in success handler for route handler {func.__name__}", + stack_info=True, + ) + raise ApplicationError( + """ + The `AgentApplication.on_sign_in_success` method is unavailable because + no Auth options were configured. + """ + ) + return func + + def on_sign_in_failure( + self, func: Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]] + ) -> Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]]: + """ + Registers a new event listener that will be executed when a user fails to sign in. + + ```python + # Use this method as a decorator + @app.on_sign_in_failure + async def sign_in_failure(context: TurnContext, state: TurnState): + print("hello world!") + return True + ``` + """ + + if self._auth: + logger.debug( + f"Registering sign-in failure handler for route handler {func.__name__}" + ) + self._auth.on_sign_in_failure(func) + else: + logger.error( + f"Failed to register sign-in failure handler for route handler {func.__name__}", + stack_info=True, + ) + raise ApplicationError( + """ + The `AgentApplication.on_sign_in_failure` method is unavailable because + no Auth options were configured. + """ + ) + return func + + def error( + self, func: Callable[[TurnContext, Exception], Awaitable[None]] + ) -> Callable[[TurnContext, Exception], Awaitable[None]]: + """ + Registers an error handler that will be called anytime + the app throws an Exception + + ```python + # Use this method as a decorator + @app.error + async def on_error(context: TurnContext, err: Exception): + print(err.message) + ``` + """ + + logger.debug(f"Registering the error handler {func.__name__} ") + self._error = func + + if self._adapter: + logger.debug( + f"Registering for adapter {self._adapter.__class__.__name__} the error handler {func.__name__} " + ) + self._adapter.on_turn_error = func + + return func + + def turn_state_factory(self, func: Callable[[TurnContext], Awaitable[StateT]]): + """ + Custom Turn State Factory + """ + logger.debug(f"Setting custom turn state factory: {func.__name__}") + self._turn_state_factory = func + return func + + async def _handle_flow_response(self, context: TurnContext, flow_response: FlowResponse) -> None: + + flow_state: FlowState = flow_response.flow_state + in_flow_activity = flow_response.in_flow_activity + + if in_flow_activity: + await context.send_activity(in_flow_activity) + + if flow_state.tag == FlowStateTag.BEGIN: + # Create the OAuth card + o_card: Attachment = CardFactory.oauth_card( + OAuthCard( + text=self.messages_configuration.get("card_title", "Sign in"), + connection_name=flow_state.connection, + buttons=[ + CardAction( + title=self.messages_configuration.get("button_text", "Sign in"), + type=ActionTypes.signin, + value=signing_resource.sign_in_link, + channel_data=None, + ) + ], + token_exchange_resource=signing_resource.token_exchange_resource, + token_post_resource=signing_resource.token_post_resource, + ) + ) + + # Send the card to the user + await context.send_activity(MessageFactory.attachment(o_card)) + elif flow_state.tag == FlowStateTag.FAILURE: + if flow_state.reached_max_retries(): + await context.send_activity( + MessageFactory.text( + self.messages_configuration.get( + "max_retries_reached_messages", + "Sign-in failed. Please try again later.", + ) + ) + ) + elif flow_state.is_expired(): + await context.send_activity( + MessageFactory.text( + self.messages_configuration.get( + "session_expired_messages", + "Sign-in session expired. Please try again.", + ) + )) + else: + logger.warning("Sign-in flow failed for unknown reasons.") + await context.send_activity("Sign-in failed. Please try again.") + + async def _on_turn_auth_intercept(self, context: TurnContext, turn_state: TurnState) -> bool: + + prev_flow_state = await self._auth.get_active_flow_state(context) + if self._auth and prev_flow_state: + + logger.debug("Sign-in flow is active for context: %s", context.activity.id) + + flow_response: FlowResponse = await self._auth.begin_or_continue_flow( + context, turn_state, prev_flow_state.handler_id + ) + + await self._handle_flow_response(context, flow_response) + + new_flow_state: FlowState = flow_response.flow_state + token_response: TokenResponse = new_flow_state.token_response + saved_activity: Activity = new_flow_state.continuation_activity.model_copy() + + if token_response and token_response.token: + new_context = copy(context) + new_context.activity = saved_activity + logger.info( + "Resending continuation activity %s", saved_activity.text + ) + await self.on_turn(new_context) + turn_state.delete_value(Authorization.SIGN_IN_STATE_KEY) # robrandao: TODOTODO + await turn_state.save(context) + return True + + return False + + async def on_turn(self, context: TurnContext): + logger.debug( + f"AgentApplication.on_turn(): Processing turn for context: {context.activity.id}" + ) + await self._start_long_running_call(context, self._on_turn) + + async def _on_turn(self, context: TurnContext): + # robrandao: TODO + try: + await self._start_typing(context) + + self._remove_mentions(context) + + logger.debug("Initializing turn state") + turn_state = await self._initialize_state(context) + + if await self._on_turn_auth_intercept(context, turn_state): + return + + logger.debug("Running before turn middleware") + if not await self._run_before_turn_middleware(context, turn_state): + return + + logger.debug("Running file downloads") + await self._handle_file_downloads(context, turn_state) + + logger.debug("Running activity handlers") + await self._on_activity(context, turn_state) + + logger.debug("Running after turn middleware") + if not await self._run_after_turn_middleware(context, turn_state): + await turn_state.save(context) + return + except ApplicationError as err: + logger.error( + f"An application error occurred in the AgentApplication: {err}", + exc_info=True, + ) + await self._on_error(context, err) + finally: + logger.debug("Stopping typing indicator") + self.typing.stop() + + async def _start_typing(self, context: TurnContext): + if self._options.start_typing_timer: + logger.debug("Starting typing indicator for context") + await self.typing.start(context) + + def _remove_mentions(self, context: TurnContext): + if ( + self.options.remove_recipient_mention + and context.activity.type == ActivityTypes.message + ): + context.activity.text = context.remove_recipient_mention(context.activity) + + @staticmethod + def parse_env_vars_configuration(vars: Dict[str, Any]) -> dict: + """ + Parses environment variables and returns a dictionary with the relevant configuration. + """ + result = {} + for key, value in vars.items(): + levels = key.split("__") + current_level = result + last_level = None + for next_level in levels: + if next_level not in current_level: + current_level[next_level] = {} + last_level = current_level + current_level = current_level[next_level] + logger.debug(f"Using environment variable '{key}'") + last_level[levels[-1]] = value + + return { + "AGENT_APPLICATION": result["AGENT_APPLICATION"], + "COPILOT_STUDIO_AGENT": result["COPILOT_STUDIO_AGENT"], + "CONNECTIONS": result["CONNECTIONS"], + "CONNECTIONS_MAP": result["CONNECTIONS_MAP"], + } + + async def _initialize_state(self, context: TurnContext) -> StateT: + if self._turn_state_factory: + logger.debug("Using custom turn state factory") + turn_state = self._turn_state_factory() + else: + logger.debug("Using default turn state factory") + turn_state = TurnState.with_storage(self._options.storage) + await turn_state.load(context, self._options.storage) + + turn_state = cast(StateT, turn_state) + + logger.debug("Loading turn state from storage") + await turn_state.load(context, self._options.storage) + turn_state.temp.input = context.activity.text + return turn_state + + async def _run_before_turn_middleware(self, context: TurnContext, state: StateT): + for before_turn in self._internal_before_turn: + is_ok = await before_turn(context, state) + if not is_ok: + await state.save(context, self._options.storage) + return False + return True + + async def _handle_file_downloads(self, context: TurnContext, state: StateT): + if self._options.file_downloaders and len(self._options.file_downloaders) > 0: + input_files = state.temp.input_files if state.temp.input_files else [] + for file_downloader in self._options.file_downloaders: + logger.info( + f"Using file downloader: {file_downloader.__class__.__name__}" + ) + files = await file_downloader.download_files(context) + input_files.extend(files) + state.temp.input_files = input_files + + def _contains_non_text_attachments(self, context: TurnContext): + non_text_attachments = filter( + lambda a: not a.content_type.startswith("text/html"), + context.activity.attachments, + ) + return len(list(non_text_attachments)) > 0 + + async def _run_after_turn_middleware(self, context: TurnContext, state: StateT): + for after_turn in self._internal_after_turn: + is_ok = await after_turn(context, state) + if not is_ok: + await state.save(context, self._options.storage) + return False + return True + + async def _on_activity(self, context: TurnContext, state: StateT): + for route in self._routes: + if route.selector(context): + if not route.auth_handlers: + await route.handler(context, state) + else: + sign_in_complete = False + for auth_handler_id in route.auth_handlers: + flow_response: FlowResponse = await self._auth.begin_or_continue_flow( + context, state, auth_handler_id + ) + await self._handle_flow_response(context, flow_response.in_flow_activity) + sign_in_complete = flow_response.flow_state.tag == FlowStateTag.COMPLETE + if not sign_in_complete: + break + + if sign_in_complete: + await route.handler(context, state) + return + logger.warning( + f"No route found for activity type: {context.activity.type} with text: {context.activity.text}" + ) + + async def _start_long_running_call( + self, context: TurnContext, func: Callable[[TurnContext], Awaitable] + ): + if ( + self._adapter + and ActivityTypes.message == context.activity.type + and self._options.long_running_messages + ): + logger.debug( + f"Starting long running call for context: {context.activity.id} with function: {func.__name__}" + ) + return await self._adapter.continue_conversation( + reference=context.get_conversation_reference(context.activity), + callback=func, + bot_app_id=self.options.bot_app_id, + ) + + return await func(context) + + async def _on_error(self, context: TurnContext, err: ApplicationError) -> None: + if self._error: + logger.info( + f"Calling error handler {self._error.__name__} for error: {err}" + ) + return await self._error(context, err) + + logger.error( + f"An error occurred in the AgentApplication: {err}", + exc_info=True, + ) + logger.error(err) + raise err diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/app_options.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/app_options.py index c8d125cb..e0d871de 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/app_options.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/app_options.py @@ -9,7 +9,7 @@ from logging import Logger from typing import Callable, List, Optional -from microsoft.agents.hosting.core.app.oauth.authorization import AuthorizationHandlers +from microsoft.agents.hosting.core.app.oauth import AuthHandler from microsoft.agents.hosting.core.storage import Storage # from .auth import AuthOptions @@ -84,7 +84,7 @@ class ApplicationOptions: If not provided, the default `TurnState` will be used. """ - authorization_handlers: Optional[AuthorizationHandlers] = None + authorization_handlers: Optional[dict[str, AuthHandler]] = None """ Optional. Authorization handler for OAuth flows. If not provided, no OAuth flows will be supported. diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/tests/__authorization_test.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/tests/__authorization_test.py deleted file mode 100644 index 63e68f3e..00000000 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/tests/__authorization_test.py +++ /dev/null @@ -1,340 +0,0 @@ -import datetime - -import pytest -from pytest_lazyfixture import lazy_fixture - -from microsoft.agents.activity import ( - TokenResponse -) -from microsoft.agents.hosting.core import ( - Authorization, - MemoryStorage, - FlowStorageClient, - FlowState, - FlowErrorTag, - FlowStateTag, - FlowResponse -) -from microsoft.agents.hosting.core.app.oauth.auth_flow import ( - AuthFlow -) -from microsoft.agents.hosting.core.storage.storage_test_utils import StorageBaseline - -def mock_flow(mocker, flow_states: list[FlowState]): - flow = mocker.Mock(spec=AuthFlow) - flow.begin_or_continue_flow = mocker.AsyncMock( - side_effect=flow_states - ) - return flow - -STORAGE_SAMPLE_DICT = { - "user_id": "123", - "session_id": "abc", - "auth/channel_id/user_id/expired": FlowState( - id="expired", - expires=expired_time, - attempts_remaining=1, - tag=FlowStateTag.CONTINUE - ), - "auth/teams_id/Bob/no_retries": FlowState( - id="no_retries", - expires=valid_time, - attempts_remaining=0, - tag=FlowStateTag.FAILURE - ), - "auth/channel/Alice/begin": FlowState( - id="begin", - expired=valid_time, - attempts_remaining=3, - tag=FlowStateTag.BEGIN - ), - "auth/channel/Alice/continue": FlowState( - id="continue", - expires=valid_time, - attempts_remaining=2 - tag=FlowStateTag.CONTINUE - ), - "auth/channel/Alice/expired_and_retries": FlowState( - id="expired_and_retries" - expires=expired_time, - attempts_remaining=0 - tag=FlowStateTag.FAILURE - ), - "auth/channel/Alice/not_started": FlowState( - id="not_started", - tag=FlowStateTag.NOT_STARTED - ) -} - -class TestAuthorization: - - def build_context(self, mocker, channel_id, from_property_id): - turn_context = mocker.Mock() - turn_context.activity.channel_id = channel_id - turn_context.activity.from_property.id = from_property_id - return turn_context - - @pytest.fixture - - @pytest.fixture - def context(self, mocker): - return self.build_context(mocker, "__channel_id", "__user_id") - - @pytest.fixture - def valid_time(self): - return datetime.datetime.now() + 10000 - - @pytest.fixture - def expired_time(self): - return datetime.datetime.now() - - @pytest.fixture - def m_storage(self, mocker): - return mocker.Mock(spec=MemoryStorage) - - @pytest.fixture - def m_connection_manager(self, mocker): - return mocker.Mock(spec=ConnectionManager) - - @pytest.fixture - def auth_handler_ids(self): - return ["expired", "no_retries", "begin", "continue", "expired_and_retries", "not_started"] - - @pytest.fixture - def auth_handlers(self, mocker, auth_handler_ids): - return { - auth_handler_id: create_test_auth_handler(f"test-{auth_handler_id}") for auth_handler_id in auth_handler_ids - } - - @pytest.fixture - def storage(self, valid_time, expired_time): - return MemoryStorage(STORAGE_SAMPLE_DICT) - - @pytest.fixture - def connection_manager(self): - pass - - @pytest.fixture - def auth_handlers(self): - pass - - @pytest.fixture - def auth(self, storage, connection_manager, auth_handlers): - return Authorization( - storage, - connection_manager, - auth_handlers, - auto_signin=True - ) - - @pytest.fixture - def storage(self, mocker): - return MemoryStorage({ - - }) - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "auth, context, auth_handler_id", - [ - ("auth", lazy_fixture("context"), ""), - ("auth", None, "handler"), - ("auth", None, "") - ("auth", lazy_fixture("context", "missing_handler")) - ], - indirect=["auth"] - ) - async def test_open_flow_value_error(self, auth, context, auth_handler_id): - with pytest.raises(ValueError): - async with auth.open_flow(context, auth_handler_id): - pass - - # async def test_open_flow_storage_readonly_storage_access(self, mocker, context, m_storage, m_connection_manager, m_auth_handlers): - # # setup - # m_storage.read.return_value = FlowState() - # auth = Authorization( - # m_storage, - # m_connection_manager, - # m_auth_handlers - # ) - - # # code - # async with auth.open_flow(context, "handler", readonly=True) as flow: - # actual_init_flow_state = flow.flow_state - - # # verify - # assert actual_init_flow_state is m_storage.read.return_value - # assert not m_storage.write.called - # assert not m_storage.delete.called - - # async def test_open_flow_storage_unchanged_not_readonly_storage_access(self, context, m_storage, m_connection_manager, m_auth_handlers): - # # setup - # m_storage.read.return_value = FlowState() - # auth = Authorization( - # m_storage, - # m_connection_manager, - # m_auth_handlers - # ) - - # # code - # async with auth.open_flow(context, "handler", readonly=False) as flow: - # # if no change is made to the flow state, then storage should not be updated - # actual_init_flow_state = flow.flow_state - - # # verify - # assert actual_init_flow_state is m_storage.read.return_value - # assert not m_storage.write.called - # assert not m_storage.delete.called - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "mocker, connection_manager, channel_id, from_property_id, auth_handler_id", - [ - (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel_id", "user_id", "expired"), - (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "teams_id", "Bob", "no_retries"), - (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "begin"), - (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "continue"), - (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "expired_and_retries"), - (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "not_started"), - ] - ) - async def test_open_flow_readonly_run(self, mocker, connection_manager, channel_id, from_property_id, auth_handler_id): - # setup - storage = MemoryStorage(STORAGE_SAMPLE_DICT) - baseline = StorageBaseline(STORAGE_SAMPLE_DICT) - auth = Authorization( - storage, - connection_manager, - auth_handlers - ) - context = self.build_context(mocker, channel_id, from_property_id) - storage_client = FlowStorageClient(context, storage) - key = storage_client.key(auth_handler_id) - expected_init_flow_state = storage.read(key, FlowState) - - # code - async with auth.open_flow(context, "handler", readonly=True) as flow: - actual_init_flow_state = flow.flow_state.copy() - flow.flow_state.id = "garbage" - flow.flow_state.tag = FlowStateTag.FAILURE - flow.flow_state.expires = 0 - flow.flow_state.attempts_remaining = -1 - actual_final_flow_state = await storage.read([key], FlowState)[key] - - # verify - expected_final_flow_state = baseline.read(key, FlowState) - assert actual_init_flow_state == expected_init_flow_state - assert actual_final_flow_state == expected_final_flow_state - assert await baseline.equals(storage) - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "mocker, connection_manager, channel_id, from_property_id, auth_handler_id", - [ - (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel_id", "user_id", "expired"), - (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "teams_id", "Bob", "no_retries"), - (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "begin"), - (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "continue"), - (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "expired_and_retries"), - (lazy_fixture("mocker"), lazy_fixture("connection_manager"), "channel", "Alice", "not_started"), - ] - ) - async def test_open_flow_storage_run(self, mocker, connection_manager, channel_id, from_property_id, auth_handler_id): - # setup - storage = MemoryStorage(STORAGE_SAMPLE_DICT) - baseline = StorageBaseline(STORAGE_SAMPLE_DICT) - auth = Authorization( - storage, - connection_manager, - auth_handlers - ) - context = self.build_context(mocker, channel_id, from_property_id) - storage_client = FlowStorageClient(context, storage) - key = storage_client.key(auth_handler_id) - expected_init_flow_state = storage.read(key, FlowState) - - # code - async with auth.open_flow(context, "handler") as flow: - actual_init_flow_state = flow.flow_state.copy() - flow.flow_state.id = "garbage" - flow.flow_state.tag = FlowStateTag.FAILURE - flow.flow_state.expires = 0 - flow.flow_state.attempts_remaining = -1 - - # verify - baseline.write({ - "auth/channel/Alice/continue": flow.flow_state - }) - expected_final_flow_state = baseline.read(key, FlowState) - assert await baseline.equals(storage) - assert actual_init_flow_state == expected_init_flow_state - assert flow.flow_state == expected_final_flow_state - - @pytest.mark.asyncio - async def test_get_token(self, mocker, m_storage): - m_storage.read.return_value = FlowState( - id="auth_handler", - tag=FlowStateTag.ACTIVE, - expires=3600, - attempts_remaining=3 - ) - expected = TokenResponse( - access_token="access_token", - refresh_token="refresh_token", - expires_in=3600 - ) - mock_flow = mocker.AsyncMock() - mock_flow.get_user_token.return_value = expected - mocker.patch.object("OAuthFlow", "get_token", return_value=expected) - mocker.patch.object("OAuthFlow", "__init__", return_value=mock_flow) - - assert await auth.get_token("auth_handler") is expected - assert mock_flow.get_user_token.called_once() - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "auth, context, auth_handler_id", - [ - (lazy_fixture("auth"), lazy_fixture("context"), "missing-handler"), - (lazy_fixture("auth"), lazy_fixture("context"), ""), - (lazy_fixture("auth"), None, "handler") - ] - ) - async def test_get_token_error(self, auth, context, auth_handler_id): - with pytest.raises(ValueError): - await auth.get_token(context, auth_handler_id) - - @pytest.fixture - def valid_token_response(self): - return TokenResponse( - connection_name="connection", - token="token" - ) - - @pytest.fixture - def invalid_exchange_token(self): - token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") - return TokenResponse( - connection_name="connection" - token=token - ) - - @pytest.mark.asyncio - @pytest.mark.parametrize - async def test_exchange_token(self, mocker, auth): - - mocker.patch.object("OAuthFlow", - get_user_token=mocker.AsyncMock(return_value=TokenResponse( - access_token="access_token", - refresh_token="refresh_token", - expires_in=3600 - )) - ) - - - - - - @pytest.mark.asyncio - async def test_exchange_token(self): - pass \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/tests/conftest.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/tests/conftest.py deleted file mode 100644 index e69de29b..00000000 diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py similarity index 100% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/__init__.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/auth_handler.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py similarity index 97% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/auth_handler.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py index 8ad53bce..b7afd9b1 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/auth_handler.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py @@ -1,46 +1,46 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import logging -from typing import Dict - -logger = logging.getLogger(__name__) - -class AuthHandler: - """ - Interface defining an authorization handler for OAuth flows. - """ - - def __init__( - self, - name: str = None, - title: str = None, - text: str = None, - abs_oauth_connection_name: str = None, - obo_connection_name: str = None, - **kwargs, - ): - """ - Initializes a new instance of AuthHandler. - - Args: - name: The name of the OAuth connection. - auto: Whether to automatically start the OAuth flow. - title: Title for the OAuth card. - text: Text for the OAuth button. - """ - self.name = name or kwargs.get("NAME") - self.title = title or kwargs.get("TITLE") - self.text = text or kwargs.get("TEXT") - self.abs_oauth_connection_name = abs_oauth_connection_name or kwargs.get( - "AZUREBOTOAUTHCONNECTIONNAME" - ) - self.obo_connection_name = obo_connection_name or kwargs.get( - "OBOCONNECTIONNAME" - ) - logger.debug( - f"AuthHandler initialized: name={self.name}, title={self.title}, text={self.text} abs_connection_name={self.abs_oauth_connection_name} obo_connection_name={self.obo_connection_name}" - ) - -# # Type alias for authorization handlers dictionary -AuthorizationHandlers = Dict[str, AuthHandler] +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import logging +from typing import Dict + +logger = logging.getLogger(__name__) + +class AuthHandler: + """ + Interface defining an authorization handler for OAuth flows. + """ + + def __init__( + self, + name: str = None, + title: str = None, + text: str = None, + abs_oauth_connection_name: str = None, + obo_connection_name: str = None, + **kwargs, + ): + """ + Initializes a new instance of AuthHandler. + + Args: + name: The name of the OAuth connection. + auto: Whether to automatically start the OAuth flow. + title: Title for the OAuth card. + text: Text for the OAuth button. + """ + self.name = name or kwargs.get("NAME") + self.title = title or kwargs.get("TITLE") + self.text = text or kwargs.get("TEXT") + self.abs_oauth_connection_name = abs_oauth_connection_name or kwargs.get( + "AZUREBOTOAUTHCONNECTIONNAME" + ) + self.obo_connection_name = obo_connection_name or kwargs.get( + "OBOCONNECTIONNAME" + ) + logger.debug( + f"AuthHandler initialized: name={self.name}, title={self.title}, text={self.text} abs_connection_name={self.abs_oauth_connection_name} obo_connection_name={self.obo_connection_name}" + ) + +# # Type alias for authorization handlers dictionary +AuthorizationHandlers = Dict[str, AuthHandler] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py similarity index 90% rename from libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/authorization.py rename to libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index 10f5b17d..04375c91 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/auth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -1,415 +1,419 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -from __future__ import annotations -import logging -import jwt -from typing import Dict, Optional, Callable, Awaitable, AsyncIterator -from collections.abc import Iterable -from contextlib import asynccontextmanager - -from microsoft.agents.hosting.core.authorization import ( - Connections, - AccessTokenProviderBase, -) -from microsoft.agents.hosting.core.storage import Storage -from microsoft.agents.activity import TokenResponse -from microsoft.agents.hosting.core.connector.client import UserTokenClient - -from ...turn_context import TurnContext -from ...oauth import ( - OAuthFlow, - FlowResponse, - FlowState, - FlowStateTag, - FlowStorageClient -) -from ..state.turn_state import TurnState -from .auth_handler import AuthHandler - -logger = logging.getLogger(__name__) - - -class Authorization: - """ - Class responsible for managing authorization and OAuth flows. - Handles multiple OAuth providers and manages the complete authentication lifecycle. - """ - - def __init__( - self, - storage: Storage, - connection_manager: Connections, - auth_handlers: dict[str, AuthHandler] = None, - auto_signin: bool = None, - **kwargs, - ): - """ - Creates a new instance of Authorization. - - Args: - storage: The storage system to use for state management. - auth_handlers: Configuration for OAuth providers. - - Raises: - ValueError: If storage is None or no auth handlers are provided. - """ - if not storage: - raise ValueError("Storage is required for Authorization") - # if not auth_handlers: - # raise ValueError("At least one AuthHandler must be provided") - - # user_state = UserState(storage) - - self.__storage = storage - self.__connection_manager = connection_manager - - auth_configuration: Dict = kwargs.get("AGENTAPPLICATION", {}).get( - "USERAUTHORIZATION", {} - ) - - # self.__auto_signin = ( - # auto_signin - # if auto_signin is not None - # else auth_configuration.get("AUTOSIGNIN", False) - # ) - - handlers_config: Dict[str, Dict] = auth_configuration.get("HANDLERS") - if not auth_handlers and handlers_config: - auth_handlers = { - handler_name: AuthHandler( - name=handler_name, **config.get("SETTINGS", {}) - ) - for handler_name, config in handlers_config.items() - } - - self.__auth_handlers = auth_handlers or {} - self.__sign_in_handler: Optional[ - Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] - ] = None - self.__sign_in_failed_handler: Optional[ - Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] - ] = None - - # # Configure each auth handler - # for auth_handler in self.__auth_handlers.values(): - # # Create OAuth flow with configuration - # messages_config = {} - # if auth_handler.title: - # ["card_title"] = auth_handler.title - # if auth_handler.text: - # messages_config["button_text"] = auth_handler.text - - # logger.debug(f"Configuring OAuth flow for handler: {auth_handler.name}") - # auth_handler.flow = AuthFlow( - # storage=storage, - # abs_oauth_connection_name=auth_handler.abs_oauth_connection_name, - # messages_configuration=messages_config if messages_config else None, - # ) - - def __ids_from_context(self, context: TurnContext) -> tuple[str, str]: - """Checks and returns IDs necessary to load a new or existing flow. - - Raises a ValueError if channel ID or user ID are missing. - """ - if ( - not context.activity.channel_id or - not context.activity.from_property or - not context.activity.from_property.id - ): - raise ValueError("Channel ID and User ID are required") - - return context.activity.channel_id, context.activity.from_property.id - - async def __load_flow( - self, - context: TurnContext, - auth_handler_id: str = "" - ) -> tuple[OAuthFlow, FlowStorageClient, FlowState]: - """Loads the OAuth flow for a specific auth handler. - - Args: - context: The context object for the current turn. - auth_handler_id: The ID of the auth handler to use. - - Returns: - The OAuthFlow returned corresponds to the flow associated with the - chosen handler, and the channel and user info found in the context. - - The FlowStorageClient corresponds to the channel and user info. - The FlowState returned is the flow state for the given channel/user/handler - triple at the time of creating the flow. - """ - user_token_client: UserTokenClient = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) # robrandao: TODO - - # resolve handler id - auth_handler: AuthHandler = self.resolve_handler(auth_handler_id) - auth_handler_id = auth_handler.id - - channel_id, user_id = self.__ids_from_context(context) - - # try to load existing state - flow_storage_client = FlowStorageClient(channel_id, user_id, self.__storage) - flow_state: FlowState = await flow_storage_client.read(auth_handler_id) - - if not flow_state: - flow_state = FlowState( - channel_id=channel_id, - user_id=user_id, - auth_handler_id=auth_handler_id, - abs_oauth_connection_name=auth_handler.abs_oauth_connection_name - ) - - flow = OAuthFlow(flow_state, user_token_client) - return flow, flow_storage_client, flow_state - - @asynccontextmanager - async def open_flow(self, context: TurnContext, auth_handler_id: str = "") -> AsyncIterator[OAuthFlow]: - """Loads an OAuth flow and saves changes the changes to storage if any are made. - - Args: - context: The context object for the current turn. - auth_handler_id: ID of the auth handler to use. - - Yields: - OAuthFlow: - The OAuthFlow instance loaded from storage or newly created - if not yet present in storage. - """ - if not context: - raise ValueError("context is required") - - flow, flow_storage_client, init_flow_state = self.__load_flow(context, auth_handler_id) - yield flow - - # persist state - new_flow_state = flow.flow_state - if new_flow_state != init_flow_state: - flow_storage_client.write(new_flow_state) - - async def get_token( - self, context: TurnContext, auth_handler_id: str - ) -> TokenResponse: - """ - Gets the token for a specific auth handler. - - Args: - context: The context object for the current turn. - auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. - - Returns: - The token response from the OAuth provider. - """ - async with self.open_flow(context, auth_handler_id) as flow: - return await flow.get_user_token(context) - - async def exchange_token( - self, - context: TurnContext, - scopes: list[str], - auth_handler_id: Optional[str] = None, - ) -> TokenResponse: - """ - Exchanges a token for another token with different scopes. - - Args: - context: The context object for the current turn. - scopes: The scopes to request for the new token. - auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. - - Returns: - The token response from the OAuth provider. - """ - async with self.open_flow(context, auth_handler_id) as flow: - token_response = await flow.get_user_token() - - if token_response and self.__is_exchangeable(token_response.token): - return await self.__handle_obo(token_response.token, scopes, auth_handler_id) - - return TokenResponse() - - # auth_handler = self.resolver_handler(auth_handler_id) - # if not auth_handler.flow: - # logger.error("OAuth flow is not configured for the auth handler") - # raise ValueError("OAuth flow is not configured for the auth handler") - - # token_response = await auth_handler.flow.get_user_token(context) - - # if self.__is_exchangeable(token_response.token if token_response else None): - # return await self.__handle_obo(token_response.token, scopes, auth_handler_id) - - # return token_response - - def __is_exchangeable(self, token: str) -> bool: - """ - Checks if a token is exchangeable (has api:// audience). - - Args: - token: The token to check. - - Returns: - True if the token is exchangeable, False otherwise. - """ - try: - # Decode without verification to check the audience - payload = jwt.decode(token, options={"verify_signature": False}) - aud = payload.get("aud") - return isinstance(aud, str) and aud.startswith("api://") - except Exception: - logger.exception("Failed to decode token to check audience") - return False - - async def __handle_obo( - self, token: str, scopes: list[str], handler_id: str = None - ) -> TokenResponse: - """ - Handles On-Behalf-Of token exchange. - - Args: - context: The context object for the current turn. - token: The original token. - scopes: The scopes to request. - - Returns: - The new token response. - - """ - auth_handler = self.resolve_handler(handler_id) - token_provider: AccessTokenProviderBase = self.__connection_manager.get_connection( - auth_handler.obo_connection_name - ) - - logger.info("Attempting to exchange token on behalf of user") - new_token = await token_provider.aquire_token_on_behalf_of( - scopes=scopes, - user_assertion=token, - ) - return TokenResponse( - token=new_token, - scopes=scopes, # Expiration can be set based on the token provider's response - ) - - async def get_active_flow_state(self, context: TurnContext) -> Optional[FlowState]: - """Gets the first active flow state for the current context.""" - channel_id, user_id = self.__ids_from_context(context) - flow_storage_client = FlowStorageClient(channel_id, user_id, self.__storage) - for auth_handler_id in self.__auth_handlers.keys(): - flow_state = await flow_storage_client.read(auth_handler_id) - if flow_state.is_active(): - return flow_state - return None - - async def begin_or_continue_flow( - self, - context: TurnContext, - turn_state: TurnState, - auth_handler_id: str, - sec_route: bool = True, - ) -> FlowResponse: - """ - Begins or continues an OAuth flow. - - Args: - context: The context object for the current turn. - state: The state object for the current turn. - auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. - - Returns: - The token response from the OAuth provider. - """ - # robrandao: TODO -> is_started_from_route and sec_route - - async with self.open_flow(context, auth_handler_id) as flow: - flow_response: FlowResponse = await flow.begin_or_continue_flow(context) - - flow_state: FlowState = flow_response.flow_state - - if flow_state.tag == FlowStateTag.COMPLETE: - self.__sign_in_success_handler(context, turn_state, flow_state.handler.id) - elif flow_state.tag == FlowStateTag.FAILURE: - self.__sign_in_failure_handler(context, turn_state, flow_state.handler.id, err) - - return flow_response - - def resolve_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: - """ - Resolves the auth handler to use based on the provided ID. - - Args: - auth_handler_id: Optional ID of the auth handler to resolve, defaults to first handler. - - Returns: - The resolved auth handler. - """ - if auth_handler_id: - if auth_handler_id not in self.__auth_handlers: - logger.error(f"Auth handler '{auth_handler_id}' not found") - raise ValueError(f"Auth handler '{auth_handler_id}' not found") - return self.__auth_handlers[auth_handler_id] - - # Return the first handler if no ID specified - return next(iter(self.__auth_handlers.values)) - - async def __sign_out( - self, - context: TurnContext, - auth_handler_ids: Iterable[str], - ) -> None: - """Signs out from the specified auth handlers. - - Args: - context: The context object for the current turn. - auth_handler_ids: List of auth handler IDs to sign out from. - - Deletes the associated flow states from storage. - """ - for auth_handler_id in auth_handler_ids: - flow, flow_storage_client, initial_flow_state = self.__load_flow(context, auth_handler_id) - if initial_flow_state: - logger.info(f"Signing out from handler: {auth_handler_id}") - await flow.sign_out() - flow_storage_client.delete(auth_handler_id) - - async def sign_out( - self, - context: TurnContext, - auth_handler_id: Optional[str] = None, - ) -> None: - """ - Signs out the current user. - This method clears the user's token and resets the OAuth state. - - Args: - context: The context object for the current turn. - auth_handler_id: Optional ID of the auth handler to use for sign out. If None, - signs out from all the handlers. - - Deletes the associated flow state(s) from storage. - """ - if auth_handler_id: - self.__sign_out(context, [auth_handler_id]) - else: - self.__sign_out(context, self.__auth_handlers.keys()) - - def on_sign_in_success( - self, - handler: Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]], - ) -> None: - """ - Sets a handler to be called when sign-in is successfully completed. - - Args: - handler: The handler function to call on successful sign-in. - """ - self.__sign_in_success_handler = handler - - def on_sign_in_failure( - self, - handler: Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]], - ) -> None: - """ - Sets a handler to be called when sign-in fails. - Args: - handler: The handler function to call on sign-in failure. - """ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from __future__ import annotations +import logging +import jwt +from typing import Dict, Optional, Callable, Awaitable, AsyncIterator +from collections.abc import Iterable +from contextlib import asynccontextmanager + +from microsoft.agents.hosting.core.authorization import ( + Connections, + AccessTokenProviderBase, +) +from microsoft.agents.hosting.core.storage import Storage +from microsoft.agents.activity import TokenResponse +from microsoft.agents.hosting.core.connector.client import UserTokenClient + +from ...turn_context import TurnContext +from ...oauth import ( + OAuthFlow, + FlowResponse, + FlowState, + FlowStateTag, + FlowStorageClient +) +from ..state.turn_state import TurnState +from .auth_handler import AuthHandler + +logger = logging.getLogger(__name__) + + +class Authorization: + """ + Class responsible for managing authorization and OAuth flows. + Handles multiple OAuth providers and manages the complete authentication lifecycle. + """ + + def __init__( + self, + storage: Storage, + connection_manager: Connections, + auth_handlers: dict[str, AuthHandler] = None, + auto_signin: bool = None, + **kwargs, + ): + """ + Creates a new instance of Authorization. + + Args: + storage: The storage system to use for state management. + auth_handlers: Configuration for OAuth providers. + + Raises: + ValueError: If storage is None or no auth handlers are provided. + """ + if not storage: + raise ValueError("Storage is required for Authorization") + # if not auth_handlers: + # raise ValueError("At least one AuthHandler must be provided") + + # user_state = UserState(storage) + + self.__storage = storage + self.__connection_manager = connection_manager + + auth_configuration: Dict = kwargs.get("AGENTAPPLICATION", {}).get( + "USERAUTHORIZATION", {} + ) + + # self.__auto_signin = ( + # auto_signin + # if auto_signin is not None + # else auth_configuration.get("AUTOSIGNIN", False) + # ) + + handlers_config: Dict[str, Dict] = auth_configuration.get("HANDLERS") + if not auth_handlers and handlers_config: + auth_handlers = { + handler_name: AuthHandler( + name=handler_name, **config.get("SETTINGS", {}) + ) + for handler_name, config in handlers_config.items() + } + + self.__auth_handlers = auth_handlers or {} + self.__sign_in_success_handler: Optional[ + Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] + ] = None + self.__sign_in_failure_handler: Optional[ + Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] + ] = None + + # # Configure each auth handler + # for auth_handler in self.__auth_handlers.values(): + # # Create OAuth flow with configuration + # messages_config = {} + # if auth_handler.title: + # ["card_title"] = auth_handler.title + # if auth_handler.text: + # messages_config["button_text"] = auth_handler.text + + # logger.debug(f"Configuring OAuth flow for handler: {auth_handler.name}") + # auth_handler.flow = AuthFlow( + # storage=storage, + # abs_oauth_connection_name=auth_handler.abs_oauth_connection_name, + # messages_configuration=messages_config if messages_config else None, + # ) + + def __ids_from_context(self, context: TurnContext) -> tuple[str, str]: + """Checks and returns IDs necessary to load a new or existing flow. + + Raises a ValueError if channel ID or user ID are missing. + """ + if ( + not context.activity.channel_id or + not context.activity.from_property or + not context.activity.from_property.id + ): + raise ValueError("Channel ID and User ID are required") + + return context.activity.channel_id, context.activity.from_property.id + + async def __load_flow( + self, + context: TurnContext, + auth_handler_id: str = "" + ) -> tuple[OAuthFlow, FlowStorageClient, FlowState]: + """Loads the OAuth flow for a specific auth handler. + + Args: + context: The context object for the current turn. + auth_handler_id: The ID of the auth handler to use. + + Returns: + The OAuthFlow returned corresponds to the flow associated with the + chosen handler, and the channel and user info found in the context. + + The FlowStorageClient corresponds to the channel and user info. + The FlowState returned is the flow state for the given channel/user/handler + triple at the time of creating the flow. + """ + user_token_client: UserTokenClient = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) + + # resolve handler id + auth_handler: AuthHandler = self.resolve_handler(auth_handler_id) + auth_handler_id = auth_handler.name + + channel_id, user_id = self.__ids_from_context(context) + + ms_app_id = context.turn_state.get(context.adapter.AGENT_IDENTITY_KEY).claims["aud"] + + # try to load existing state + flow_storage_client = FlowStorageClient(channel_id, user_id, self.__storage) + flow_state: FlowState = await flow_storage_client.read(auth_handler_id) + + if not flow_state: + flow_state = FlowState( + channel_id=channel_id, + user_id=user_id, + auth_handler_id=auth_handler_id, + connection=auth_handler.abs_oauth_connection_name, + ms_app_id=ms_app_id + ) + await flow_storage_client.write(flow_state) + + flow = OAuthFlow(flow_state, user_token_client) + return flow, flow_storage_client, flow_state + + @asynccontextmanager + async def open_flow(self, context: TurnContext, auth_handler_id: str = "") -> AsyncIterator[OAuthFlow]: + """Loads an OAuth flow and saves changes the changes to storage if any are made. + + Args: + context: The context object for the current turn. + auth_handler_id: ID of the auth handler to use. + + Yields: + OAuthFlow: + The OAuthFlow instance loaded from storage or newly created + if not yet present in storage. + """ + if not context: + raise ValueError("context is required") + + flow, flow_storage_client, init_flow_state = await self.__load_flow(context, auth_handler_id) + yield flow + + # persist state + new_flow_state = flow.flow_state + if new_flow_state != init_flow_state: + await flow_storage_client.write(new_flow_state) + + async def get_token( + self, context: TurnContext, auth_handler_id: str + ) -> TokenResponse: + """ + Gets the token for a specific auth handler. + + Args: + context: The context object for the current turn. + auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. + + Returns: + The token response from the OAuth provider. + """ + async with self.open_flow(context, auth_handler_id) as flow: + return await flow.get_user_token(context) + + async def exchange_token( + self, + context: TurnContext, + scopes: list[str], + auth_handler_id: Optional[str] = None, + ) -> TokenResponse: + """ + Exchanges a token for another token with different scopes. + + Args: + context: The context object for the current turn. + scopes: The scopes to request for the new token. + auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. + + Returns: + The token response from the OAuth provider. + """ + async with self.open_flow(context, auth_handler_id) as flow: + token_response = await flow.get_user_token() + + if token_response and self.__is_exchangeable(token_response.token): + return await self.__handle_obo(token_response.token, scopes, auth_handler_id) + + return TokenResponse() + + # auth_handler = self.resolver_handler(auth_handler_id) + # if not auth_handler.flow: + # logger.error("OAuth flow is not configured for the auth handler") + # raise ValueError("OAuth flow is not configured for the auth handler") + + # token_response = await auth_handler.flow.get_user_token(context) + + # if self.__is_exchangeable(token_response.token if token_response else None): + # return await self.__handle_obo(token_response.token, scopes, auth_handler_id) + + # return token_response + + def __is_exchangeable(self, token: str) -> bool: + """ + Checks if a token is exchangeable (has api:// audience). + + Args: + token: The token to check. + + Returns: + True if the token is exchangeable, False otherwise. + """ + try: + # Decode without verification to check the audience + payload = jwt.decode(token, options={"verify_signature": False}) + aud = payload.get("aud") + return isinstance(aud, str) and aud.startswith("api://") + except Exception: + logger.exception("Failed to decode token to check audience") + return False + + async def __handle_obo( + self, token: str, scopes: list[str], handler_id: str = None + ) -> TokenResponse: + """ + Handles On-Behalf-Of token exchange. + + Args: + context: The context object for the current turn. + token: The original token. + scopes: The scopes to request. + + Returns: + The new token response. + + """ + auth_handler = self.resolve_handler(handler_id) + token_provider: AccessTokenProviderBase = self.__connection_manager.get_connection( + auth_handler.obo_connection_name + ) + + logger.info("Attempting to exchange token on behalf of user") + new_token = await token_provider.aquire_token_on_behalf_of( + scopes=scopes, + user_assertion=token, + ) + return TokenResponse( + token=new_token, + scopes=scopes, # Expiration can be set based on the token provider's response + ) + + async def get_active_flow_state(self, context: TurnContext) -> Optional[FlowState]: + """Gets the first active flow state for the current context.""" + channel_id, user_id = self.__ids_from_context(context) + flow_storage_client = FlowStorageClient(channel_id, user_id, self.__storage) + for auth_handler_id in self.__auth_handlers.keys(): + flow_state = await flow_storage_client.read(auth_handler_id) + if flow_state and flow_state.is_active(): + return flow_state + return None + + async def begin_or_continue_flow( + self, + context: TurnContext, + turn_state: TurnState, + auth_handler_id: str, + sec_route: bool = True, + ) -> FlowResponse: + """ + Begins or continues an OAuth flow. + + Args: + context: The context object for the current turn. + state: The state object for the current turn. + auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. + + Returns: + The token response from the OAuth provider. + """ + # robrandao: TODO -> is_started_from_route and sec_route + + async with self.open_flow(context, auth_handler_id) as flow: + flow_response: FlowResponse = await flow.begin_or_continue_flow(context) + + flow_state: FlowState = flow_response.flow_state + + if flow_state.tag == FlowStateTag.COMPLETE: + self.__sign_in_success_handler(context, turn_state, flow_state.auth_handler_id) + elif flow_state.tag == FlowStateTag.FAILURE: + self.__sign_in_failure_handler(context, turn_state, flow_state.auth_handler_id, flow_response.flow_error_tag) + + return flow_response + + def resolve_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: + """ + Resolves the auth handler to use based on the provided ID. + + Args: + auth_handler_id: Optional ID of the auth handler to resolve, defaults to first handler. + + Returns: + The resolved auth handler. + """ + if auth_handler_id: + if auth_handler_id not in self.__auth_handlers: + logger.error(f"Auth handler '{auth_handler_id}' not found") + raise ValueError(f"Auth handler '{auth_handler_id}' not found") + return self.__auth_handlers[auth_handler_id] + + # Return the first handler if no ID specified + return next(iter(self.__auth_handlers.values())) + + async def __sign_out( + self, + context: TurnContext, + auth_handler_ids: Iterable[str], + ) -> None: + """Signs out from the specified auth handlers. + + Args: + context: The context object for the current turn. + auth_handler_ids: List of auth handler IDs to sign out from. + + Deletes the associated flow states from storage. + """ + for auth_handler_id in auth_handler_ids: + flow, flow_storage_client, initial_flow_state = await self.__load_flow(context, auth_handler_id) + if initial_flow_state: + logger.info(f"Signing out from handler: {auth_handler_id}") + await flow.sign_out() + await flow_storage_client.delete(auth_handler_id) + + async def sign_out( + self, + context: TurnContext, + auth_handler_id: Optional[str] = None, + ) -> None: + """ + Signs out the current user. + This method clears the user's token and resets the OAuth state. + + Args: + context: The context object for the current turn. + auth_handler_id: Optional ID of the auth handler to use for sign out. If None, + signs out from all the handlers. + + Deletes the associated flow state(s) from storage. + """ + if auth_handler_id: + await self.__sign_out(context, [auth_handler_id]) + else: + await self.__sign_out(context, self.__auth_handlers.keys()) + + def on_sign_in_success( + self, + handler: Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]], + ) -> None: + """ + Sets a handler to be called when sign-in is successfully completed. + + Args: + handler: The handler function to call on successful sign-in. + """ + self.__sign_in_success_handler = handler + + def on_sign_in_failure( + self, + handler: Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]], + ) -> None: + """ + Sets a handler to be called when sign-in fails. + Args: + handler: The handler function to call on sign-in failure. + """ self.__sign_in_failure_handler = handler \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py index e2db50a2..c9730300 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py @@ -1,11 +1,10 @@ from .flow_state import ( FlowState, FlowStateTag, - FlowErrorTag, - FlowResponse + FlowErrorTag ) from .flow_storage_client import FlowStorageClient -from .oauth_flow import OAuthFlow +from .oauth_flow import OAuthFlow, FlowResponse __all__ = [ "FlowState", diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py index de3a3c6e..dc580537 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py @@ -35,13 +35,15 @@ class FlowErrorTag(Enum): class FlowState(BaseModel, StoreItem): """Represents the state of an OAuthFlow""" - flow_id: str = "" # robrandao: TODO user_token: str = "" - expires_at: float = 0 + channel_id: str = "" user_id: str = "" ms_app_id: str = "" - abs_oauth_connection_name: Optional[str] = None + connection: str = "" + auth_handler_id: str = "" + + expires_at: float = 0 continuation_activity: Optional[Activity] = None attempts_remaining: int = 0 tag: FlowStateTag = FlowStateTag.NOT_STARTED @@ -52,7 +54,7 @@ def store_item_to_json(self) -> dict: @staticmethod def from_json_to_store_item(json_data: dict) -> "FlowState": return FlowState.model_validate(json_data) - + def is_expired(self) -> bool: return datetime.now().timestamp() >= self.expires_at diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py index 4a385cf0..4ca41432 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py @@ -38,22 +38,24 @@ def base_key(self) -> str: """Returns the prefix used for flow state storage isolation.""" return self.__base_key - def key(self, flow_id: str) -> str: + def key(self, auth_handler_id: str) -> str: """Creates a storage key for a specific sign-in handler.""" - return f"{self.__base_key}{flow_id}" + return f"{self.__base_key}{auth_handler_id}" - async def read(self, flow_id: str) -> Optional[FlowState]: + async def read(self, auth_handler_id: str) -> Optional[FlowState]: """Reads the flow state for a specific authentication handler.""" - key: str = self.key(flow_id) - data = await self.__storage.read([key], FlowState) - return FlowState.model_validate(data.get(key)) # robrandao: TODO -> verify contract + key: str = self.key(auth_handler_id) + data = await self.__storage.read([key], target_cls=FlowState) + if key not in data: + return None + return FlowState.model_validate(data.get(key)) async def write(self, value: FlowState) -> None: """Saves the flow state for a specific authentication handler.""" - key: str = self.key(value.flow_id) + key: str = self.key(value.auth_handler_id) await self.__storage.write({key: value}) - async def delete(self, flow_id: str) -> None: + async def delete(self, auth_handler_id: str) -> None: """Deletes the flow state for a specific authentication handler.""" - key: str = self.key(flow_id) + key: str = self.key(auth_handler_id) await self.__storage.delete([key]) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py index de167c9a..37e904b5 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py @@ -59,17 +59,17 @@ def __init__( max_attempts: The maximum number of attempts for the flow set when starting a flow (default: 3). """ - if not self.flow_state or not user_token_client: + if not flow_state or not user_token_client: raise ValueError("OAuthFlow.__init__(): flow_state and user_token_client are required") - if (not flow_state.abs_oauth_connection_name or + if (not flow_state.connection or not flow_state.ms_app_id or not flow_state.channel_id or not flow_state.user_id): - raise ValueError("OAuthFlow.__init__: flow_state must have ms_app_id, channel_id, user_id, abs_oauth_connection_name defined") + raise ValueError("OAuthFlow.__init__: flow_state must have ms_app_id, channel_id, user_id, connection defined") self.__flow_state = flow_state.model_copy() - self.__abs_oauth_connection_name = self.__flow_state.abs_oauth_connection_name + self.__abs_oauth_connection_name = self.__flow_state.connection self.__ms_app_id = self.__flow_state.ms_app_id self.__channel_id = self.__flow_state.channel_id self.__user_id = self.__flow_state.user_id @@ -82,11 +82,6 @@ def __init__( @property def flow_state(self) -> FlowState: return self.__flow_state.model_copy() - - # async def __initialize_token_client(self, context: TurnContext) -> None: - # # robrandao: TODO is this safe - # # use cached value later - # self.__user_token_client = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) async def get_user_token(self, magic_code: str = None) -> TokenResponse: """Get the user token based on the context. @@ -162,7 +157,7 @@ async def begin_flow(self, activity: Activity) -> FlowResponse: connection_name=self.__abs_oauth_connection_name, conversation=activity.get_conversation_reference(), relates_to=activity.relates_to, - ms_app_id=self.__ms_app_id # robrandao: TODO + ms_app_id=self.__ms_app_id ) sign_in_resource = await self.__user_token_client.agent_sign_in.get_sign_in_resource( diff --git a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py index 67cbfa44..0bd5f91e 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py @@ -1,561 +1,552 @@ -import pytest - -import jwt - -from microsoft.agents.activity import ( - ActivityTypes, - TokenResponse -) -from microsoft.agents.hosting.core import MemoryStorage -from microsoft.agents.hosting.core.storage.storage_test_utils import StorageBaseline -from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase -from microsoft.agents.hosting.core.connector.user_token_client_base import UserTokenClientBase - -from microsoft.agents.hosting.core.app.oauth import Authorization -from microsoft.agents.hosting.core.oauth import ( - OAuthFlow, - FlowStorageClient, - FlowErrorTag, - FlowStateTag -) - -# test constants -from .tools.testing_oauth import * -from .tools.testing_authorization import ( - TestingConnectionManager, - create_test_auth_handler -) - -class TestUtils: - - def create_context(self, - mocker, - channel_id="__channel_id", - user_id="__user_id", - abs_oauth_connection_name="graph", - user_token_client=None): - - if not user_token_client: - user_token_client = self.create_mock_user_token_client(mocker) - - turn_context = mocker.Mock() - turn_context.activity.channel_id = channel_id - turn_context.activity.from_property.id = user_id - turn_context.adapter.USER_TOKEN_CLIENT_KEY = "__user_token_client" - turn_context.turn_state = { - "__user_token_client": user_token_client - } - return turn_context - - def create_mock_user_token_client( - self, - mocker, - token=None, - ): - mock_user_token_client_class = mocker.Mock(spec=UserTokenClientBase) - mock_user_token_client_class.user_token = mocker.Mock(spec=UserTokenBase) - mock_user_token_client_class.user_token.get_token = mocker.AsyncMock( - return_value=TokenResponse(token=token) - ) - mock_user_token_client_class.user_token.sign_out = mocker.AsyncMock() - return mock_user_token_client_class - - @pytest.fixture - def mock_user_token_client_class(self, mocker): - return self.create_mock_user_token_client(mocker) - - @pytest.fixture - def mock_auth_flow_class(self, mocker): - mock_flow_class = mocker.Mock(spec=AuthFlow) - - mocker.patch.object(AuthFlow, "__init__", return_value=mock_flow_class) - mock_flow_class.get_user_token = mocker.AsyncMock() - mock_flow_class.sign_out = mocker.AsyncMock() - - return mock_flow_class - - @pytest.fixture - def turn_context(self, mocker): - return self.create_context(mocker, "__channel_id", "__user_id", "__connection") - - def create_user_token_client(self, mocker, get_token_return=None): - - user_token_client = mocker.Mock(spec=UserTokenClientBase) - user_token_client.user_token = mocker.Mock(spec=UserTokenBase) - user_token_client.user_token.get_token = mocker.AsyncMock() - user_token_client.user_token.sign_out = mocker.AsyncMock() - - return_value = TokenResponse() - if get_token_return: - return_value = TokenResponse(token=get_token_return) - user_token_client.user_token.get_token.return_value = return_value - - return user_token_client - - @pytest.fixture - def user_token_client(self, mocker): - return self.create_user_token_client(mocker, get_token_return=TEST_DEFAULTS.RES_TOKEN) - - @pytest.fixture - def auth_handlers(self): - handlers = {} - for key in STORAGE_INIT_DATA().keys(): - if key.startswith("auth/"): - auth_handler_name = key[key.rindex("/")+1:] - handlers[auth_handler_name] = create_test_auth_handler(auth_handler_name, True) - return handlers - - @pytest.fixture - def connection_manager(self): - return TestingConnectionManager() - - @pytest.fixture - def auth(self, connection_manager, storage, auth_handlers): - return Authorization(connection_manager, storage, auth_handlers) - -class TestAuthorizationUtils(TestUtils): - - def create_user_token_client(self, mocker, get_token_return=None): - - user_token_client = mocker.Mock(spec=UserTokenClientBase) - user_token_client.user_token = mocker.Mock(spec=UserTokenBase) - user_token_client.user_token.get_token = mocker.AsyncMock() - user_token_client.user_token.sign_out = mocker.AsyncMock() - - return_value = TokenResponse() - if get_token_return: - return_value = TokenResponse(token=get_token_return) - user_token_client.user_token.get_token.return_value = return_value - - return user_token_client - - def create_storage(self): - return MemoryStorage(STORAGE_INIT_DATA()) - - @pytest.fixture - def storage(self): - return self.create_storage() - - @pytest.fixture - def baseline_storage(self): - return StorageBaseline(STORAGE_INIT_DATA()) - - def patch_flow(self, mocker, flow_response=None, token=None,): - mocker.patch.object(AuthFlow, "get_user_token", return_value=TokenResponse(token=token)) - mocker.patch.object(AuthFlow, "sign_out") - mocker.patch.object(AuthFlow, "begin_or_continue_flow", return_value=flow_response) - -class TestAuthorization(TestAuthorizationUtils): - - def test_init_configuration_variants(self,storage, connection_manager, auth_handlers): - AGENTAPPLICATION = { - "USERAUTHORIZATION": { - "HANDLERS": { - handler_name: { - "SETTINGS": { - "title": handler.title, - "text": handler.text, - "abs_oauth_connection_name": handler.abs_oauth_connection_name, - "obo_connection_name": handler.obo_connection_name - } - } for handler_name, handler in auth_handlers.items() - } - } - } - auth_with_config_obj = Authorization( - storage, - connection_manager, - auth_handlers=None, - AGENTAPPLICATION=AGENTAPPLICATION - ) - auth_with_handlers_list = Authorization( - storage, - connection_manager, - auth_handlers=auth_handlers - ) - for auth_handler_name in auth_handlers.keys(): - auth_handler_a = auth_with_config_obj.resolve_handler(auth_handler_name) - auth_handler_b = auth_with_handlers_list.resolve_handler(auth_handler_name) - assert auth_handler_a == auth_handler_b - - @pytest.mark.asyncio - @pytest.mark.parametrize("auth_handler_id, channel_id, user_id", - [ - ["", "webchat", "Alice"], - ["handler", "teams", "Bob"] - ]) - async def test_open_flow_value_error( - self, - mocker, - auth, - auth_handler_id, - channel_id, - user_id - ): - context = self.create_context(mocker, channel_id, user_id) - with pytest.raises(ValueError): - async with auth.open_flow(context, auth_handler_id): - pass - - @pytest.mark.asyncio - @pytest.mark.parametrize("auth_handler_id, channel_id, user_id", - [ - ["", "webchat", "Alice"], - ["handler", "teams", "Bob"] - ]) - async def test_open_flow_readonly( - self, - mocker, - storage, - connection_client, - auth_handlers, - auth_handler_id, - channel_id, - user_id - ): - # setup - context = self.create_context(mocker, channel_id, user_id) - auth = Authorization(storage, connection_client, auth_handlers) - flow_storage_client = FlowStorageClient(channel_id, user_id, storage) - - # test - async with auth.open_flow(context, auth_handler_id) as flow: - expected_flow_state = flow.flow_state - - # verify - actual_flow_state = await flow_storage_client.read(auth_handler_id) - assert actual_flow_state == expected_flow_state - - @pytest.mark.asyncio - async def test_open_flow_not_in_storage( - self, - mocker, - storage, - connection_manager, - auth_handlers - ): - # setup - context = self.create_context(mocker, "__channel_id", "__user_id") - auth = Authorization(storage, connection_manager, auth_handlers) - flow_storage_client = FlowStorageClient("__channel_id", "__user_id", storage) - - # test - async with auth.open_flow(context, "__auth_handler_id") as flow: - assert flow is not None - assert isinstance(flow, AuthFlow) - flow_state = await flow_storage_client.read("__auth_handler_id") - - # verify - assert flow_state.channel_id == "__channel_id" - assert flow_state.user_id == "__user_id" - assert flow_state.auth_handler_id == "__auth_handler_id" - assert flow_state.tag == FlowStateTag.NOT_STARTED - - @pytest.mark.asyncio - async def test_open_flow_success_modified_complete_flow( - self, - mocker, - storage, - connection_client, - auth_handlers, - auth_handler_id, - channel_id, - user_id - ): - # setup - channel_id = "teams" - user_id = "Alice" - auth_handler_id = "graph" - - self.create_user_token_client( - mocker, - get_token_return=TokenResponse(token=TEST_DEFAULTS.RES_TOKEN) - ) - - context = self.create_context(mocker, channel_id, user_id) - context.activity.type = ActivityTypes.message - context.activity.text = "123456" - - auth = Authorization(storage, connection_client, auth_handlers) - flow_storage_client = FlowStorageClient(channel_id, user_id, storage) - - # test - async with auth.open_flow(context, auth_handler_id) as flow: - expected_flow_state = flow.flow_state - expected_flow_state.tag = FlowStateTag.COMPLETE - expected_flow_state.user_token = TEST_DEFAULTS.RES_TOKEN - - flow_response = await flow.begin_or_continue_flow(context.activity) - res_flow_state = flow_response.flow_state - - # verify - actual_flow_state = await flow_storage_client.read(auth_handler_id) - expected_flow_state.expires_at = actual_flow_state.expires_at # we won't check this for now - - assert res_flow_state == expected_flow_state - assert actual_flow_state == expected_flow_state - - @pytest.mark.asyncio - async def test_open_flow_success_modified_failure( - self, - mocker, - baseline_storage, - storage, - connection_client, - auth_handlers, - auth_handler_id, - channel_id, - user_id - ): - # setup - channel_id = "webchat" - user_id = "Bob" - auth_handler_id = "graph" - - context = self.create_context(mocker, channel_id, user_id) - - auth = Authorization(storage, connection_client, auth_handlers) - flow_storage_client = FlowStorageClient(channel_id, user_id, storage) - - # test - async with auth.open_flow(context, auth_handler_id) as flow: - expected_flow_state = flow.flow_state - expected_flow_state.tag = FlowStateTag.FAILURE - expected_flow_state.attempts_remaining = 0 - - flow_response = await flow.begin_or_continue_flow(context.activity) - res_flow_state = flow_response.flow_state - - # verify - actual_flow_state = await flow_storage_client.read(auth_handler_id) - expected_flow_state.expires_at = actual_flow_state.expires_at # we won't check this for now - - assert flow_response.flow_error_tag == FlowErrorTag.MAGIC_FORMAT - assert res_flow_state == expected_flow_state - assert actual_flow_state == expected_flow_state - - baseline_storage.write(res_flow_state.model_copy()) - assert await baseline_storage.equals(storage) - - @pytest.mark.asyncio - async def test_open_flow_success_modified_signout( - self, - mocker, - storage, - connection_client, - auth_handlers, - auth_handler_id, - channel_id, - user_id - ): - # setup - channel_id = "webchat" - user_id = "Alice" - auth_handler_id = "graph" - - context = self.create_context(mocker, channel_id, user_id) - - auth = Authorization(storage, connection_client, auth_handlers) - flow_storage_client = FlowStorageClient(channel_id, user_id, storage) - - # test - async with auth.open_flow(context, auth_handler_id) as flow: - expected_flow_state = flow.flow_state - expected_flow_state.tag = FlowStateTag.FAILURE - expected_flow_state.user_token = "" - - flow_response = await flow.sign_out() - res_flow_state = flow_response.flow_state - - # verify - actual_flow_state = await flow_storage_client.read(auth_handler_id) - expected_flow_state.expires_at = actual_flow_state.expires_at # we won't check this for now - - assert flow_response.flow_error_tag == FlowErrorTag.NONE - assert res_flow_state == expected_flow_state - assert actual_flow_state == expected_flow_state - - @pytest.mark.asyncio - async def test_get_token_success( - self, - mocker, - auth - ): - mock_user_token_client_class = self.create_user_token_client( - mocker, - get_token_return=TokenResponse(token="token") - ) - context = self.create_context(mocker, "__channel_id", "__user_id") - assert await auth.get_token(context, "auth_handler") == TokenResponse(token="token") - mock_user_token_client_class.get_user_token.called_once() - - @pytest.mark.asyncio - async def test_get_token_empty_response( - self, - mocker, - auth - ): - mock_user_token_client_class = self.create_user_token_client( - mocker, - get_token_return=TokenResponse() - ) - context = self.create_context(mocker, "__channel_id", "__user_id") - assert await auth.get_token(context, "auth_handler") == TokenResponse() - mock_user_token_client_class.get_user_token.called_once() - - @pytest.mark.asyncio - async def test_get_token_error( - self, - turn_context, - storage, - connection_manager, - auth_handlers - ): - auth = Authorization(storage, connection_manager, auth_handlers) - with pytest.raises(ValueError): - await auth.get_token(turn_context, "missing-handler") - - @pytest.mark.asyncio - async def test_exchange_token_no_token( - self, - turn_context, - mock_auth_flow_class, - mocker, - auth - ): - mock_auth_flow_class.get_user_token = mocker.AsyncMock( - return_value=TokenResponse() - ) - res = await auth.exchange_token(turn_context, ["scope"], "github") - assert res == TokenResponse() - - @pytest.mark.asyncio - async def test_exchange_token_not_exchangeable( - self, - mock_auth_flow_class, - turn_context, - mocker, - auth, - token - ): - token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") - mock_auth_flow_class.get_user_token = mocker.AsyncMock( - return_value=TokenResponse(connection_name="github", token=token) - ) - res = await auth.exchange_token(turn_context, ["scope"], "github") - assert res == TokenResponse() - - @pytest.mark.asyncio - async def test_exchange_token_valid_exchangeable( - self, - mock_auth_flow_class, - turn_context, - mocker, - auth, - token - ): - token = jwt.encode({"aud": "valid://botframework.test.api"}, "") - mock_auth_flow_class.get_user_token = mocker.AsyncMock( - return_value=TokenResponse(connection_name="github", token=token) - ) - res = await auth.exchange_token(turn_context, ["scope"], "github") - assert res == TokenResponse(scopes=["scope"], token=token, connection_name="github") - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "channel_id, user_id, expected_flow_state", - [ - [] - ] - ) - async def test_get_active_flow_state(self, mocker, auth, channel_id, user_id, expected_flow_state): - context = self.create_context(mocker, channel_id, user_id) - actual_flow_state = await auth.get_active_flow_state(context) - assert actual_flow_state == expected_flow_state - - @pytest.mark.asyncio - async def test_get_active_flow_state_missing(self, mocker, auth): - context = self.create_context(mocker, "__channel_id", "__user_id") - res = await auth.get_active_flow_state(context) - assert res is None - - @pytest.mark.asyncio - async def begin_or_continue_flow( - self, - mocker, - turn_context, - storage, - baseline_storage, - connection_manager, - auth_handlers - ): - pass - - @pytest.mark.parametrize("auth_handler_id", ["handler", "connection"]) - def test_resolve_handler_specified(self, auth, auth_handlers, auth_handler_id): - assert auth.resolve_handler(auth_handler_id) == auth_handlers[auth_handler_id] - - def test_resolve_handler_error(self, auth): - with pytest.raises(ValueError): - auth.resolve_handler("missing-handler") - - def test_resolve_handler_first(self, auth, auth_handlers_list): - assert auth.resolve_handler() == auth_handlers_list[0] - - @pytest.mark.asyncio - async def test_sign_out_individual( - self, - mocker, - mock_user_token_client_class, - mock_auth_flow_class, - storage, - baseline_storage, - connection_manager, - auth_handlers - ): - # setup - storage_client = FlowStorageClient("teams", "Alice", storage) - context = self.create_context(mocker, "teams", "Alice") - - auth = Authorization(storage, connection_manager, auth_handlers) - await auth.sign_out(context, "graph") - - await baseline_storage.delete([storage_client.key("graph")]) - - # verify storage - assert await baseline_storage.equals(storage) - - # verify flow - mock_auth_flow_class.sign_out.assert_called_once_with("graph") - mock_user_token_client_class.user_token.sign_out.assert_called_once() - - @pytest.mark.asyncio - async def test_sign_out_all( - self, - mocker, - mock_user_token_client_class, - mock_auth_flow_class, - turn_context, - storage, - baseline_storage, - connection_manager, - auth_handlers - ): - # setup - storage_client = FlowStorageClient("webchat", "Alice", storage) - - auth = Authorization(storage, connection_manager, auth_handlers) - context = self.create_context(mocker, "webchat", "Alice") - await auth.sign_out(context) - - await baseline_storage.delete([storage_client.key("handler"), storage_client.key("connection")]) - - # verify storage - assert await baseline_storage.equals(storage) - - # verify flow - mock_auth_flow_class.sign_out.assert_called_once_with("handler") - mock_auth_flow_class.sign_out.assert_called_once_with("connection") - - - # robrandao: TODO -> handlers \ No newline at end of file +import pytest + +import jwt + +from microsoft.agents.activity import ( + ActivityTypes, + TokenResponse +) +from microsoft.agents.hosting.core import MemoryStorage +from microsoft.agents.hosting.core.storage.storage_test_utils import StorageBaseline +from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase +from microsoft.agents.hosting.core.connector.user_token_client_base import UserTokenClientBase + +from microsoft.agents.hosting.core.app.oauth import Authorization +from microsoft.agents.hosting.core.oauth import ( + FlowStorageClient, + FlowErrorTag, + FlowStateTag, + FlowResponse, + OAuthFlow +) + +# test constants +from .tools.testing_oauth import * +from .tools.testing_authorization import ( + TestingConnectionManager as MockConnectionManager, + create_test_auth_handler +) + +class TestUtils: + + def create_context(self, + mocker, + channel_id="__channel_id", + user_id="__user_id", + user_token_client=None): + + if not user_token_client: + user_token_client = self.create_mock_user_token_client(mocker) + + turn_context = mocker.Mock() + turn_context.activity.channel_id = channel_id + turn_context.activity.from_property.id = user_id + turn_context.activity.type = ActivityTypes.message + turn_context.adapter.USER_TOKEN_CLIENT_KEY = "__user_token_client" + turn_context.adapter.AGENT_IDENTITY_KEY = "__agent_identity_key" + agent_identity = mocker.Mock() + agent_identity.claims = {"aud": MS_APP_ID} + turn_context.turn_state = { + "__user_token_client": user_token_client, + "__agent_identity_key": agent_identity, + } + return turn_context + + def create_mock_user_token_client( + self, + mocker, + token=RES_TOKEN, + ): + mock_user_token_client_class = mocker.Mock(spec=UserTokenClientBase) + mock_user_token_client_class.user_token = mocker.Mock(spec=UserTokenBase) + mock_user_token_client_class.user_token.get_token = mocker.AsyncMock( + return_value=TokenResponse() if not token else TokenResponse(token=token) + ) + mock_user_token_client_class.user_token.sign_out = mocker.AsyncMock() + return mock_user_token_client_class + + @pytest.fixture + def mock_user_token_client_class(self, mocker): + return self.create_mock_user_token_client(mocker) + + def create_mock_oauth_flow_class(self, mocker, token_response): + mock_oauth_flow_class = mocker.Mock(spec=OAuthFlow) + # mock_oauth_flow_class.get_user_token = mocker.AsyncMock(return_value=token_response) + # mock_oauth_flow_class.sign_out = mocker.AsyncMock() + mocker.patch.object(OAuthFlow, "get_user_token", return_value=token_response) + mocker.patch.object(OAuthFlow, "sign_out") + return mock_oauth_flow_class + + @pytest.fixture + def mock_oauth_flow_class(self, mocker): + return self.create_mock_oauth_flow_class(mocker, TokenResponse(token=RES_TOKEN)) + # mock_flow_class = mocker.Mock(spec=OAuthFlow) + + # # mocker.patch.object(OAuthFlow, "__init__", return_value=mock_flow_class) + # mock_flow_class.get_user_token = mocker.AsyncMock(return_value=TokenResponse(token=RES_TOKEN)) + # mock_flow_class.sign_out = mocker.AsyncMock() + # mocker.patch.object(OAuthFlow, "get_user_token") + + # return mock_flow_class + + @pytest.fixture + def turn_context(self, mocker): + return self.create_context(mocker, "__channel_id", "__user_id", "__connection") + + def create_user_token_client(self, mocker, get_token_return=""): + + user_token_client = mocker.Mock(spec=UserTokenClientBase) + user_token_client.user_token = mocker.Mock(spec=UserTokenBase) + user_token_client.user_token.get_token = mocker.AsyncMock() + user_token_client.user_token.sign_out = mocker.AsyncMock() + + return_value = TokenResponse() + if isinstance(get_token_return, TokenResponse): + return_value = get_token_return + elif get_token_return: + return_value = TokenResponse(token=get_token_return) + user_token_client.user_token.get_token.return_value = return_value + + return user_token_client + + @pytest.fixture + def user_token_client(self, mocker): + return self.create_user_token_client(mocker, get_token_return=RES_TOKEN) + + @pytest.fixture + def auth_handlers(self): + handlers = {} + for key in STORAGE_INIT_DATA().keys(): + if key.startswith("auth/"): + auth_handler_name = key[key.rindex("/")+1:] + handlers[auth_handler_name] = create_test_auth_handler(auth_handler_name, True) + return handlers + + @pytest.fixture + def connection_manager(self): + return MockConnectionManager() + + @pytest.fixture + def auth(self, connection_manager, storage, auth_handlers): + return Authorization(storage, connection_manager, auth_handlers) + +class TestAuthorizationUtils(TestUtils): + + def create_storage(self): + return MemoryStorage(STORAGE_INIT_DATA()) + + @pytest.fixture + def storage(self): + return self.create_storage() + + @pytest.fixture + def baseline_storage(self): + return StorageBaseline(STORAGE_INIT_DATA()) + + def patch_flow(self, mocker, flow_response=None, token=None,): + mocker.patch.object(OAuthFlow, "get_user_token", return_value=TokenResponse(token=token)) + mocker.patch.object(OAuthFlow, "sign_out") + mocker.patch.object(OAuthFlow, "begin_or_continue_flow", return_value=flow_response) + +class TestAuthorization(TestAuthorizationUtils): + + def test_init_configuration_variants(self,storage, connection_manager, auth_handlers): + """Test initialization of authorization with different configuration variants.""" + AGENTAPPLICATION = { + "USERAUTHORIZATION": { + "HANDLERS": { + handler_name: { + "SETTINGS": { + "title": handler.title, + "text": handler.text, + "abs_oauth_connection_name": handler.abs_oauth_connection_name, + "obo_connection_name": handler.obo_connection_name + } + } for handler_name, handler in auth_handlers.items() + } + } + } + auth_with_config_obj = Authorization( + storage, + connection_manager, + auth_handlers=None, + AGENTAPPLICATION=AGENTAPPLICATION + ) + auth_with_handlers_list = Authorization( + storage, + connection_manager, + auth_handlers=auth_handlers + ) + for auth_handler_name in auth_handlers.keys(): + auth_handler_a = auth_with_config_obj.resolve_handler(auth_handler_name) + auth_handler_b = auth_with_handlers_list.resolve_handler(auth_handler_name) + + assert auth_handler_a.name == auth_handler_b.name + assert auth_handler_a.title == auth_handler_b.title + assert auth_handler_a.text == auth_handler_b.text + assert auth_handler_a.abs_oauth_connection_name == auth_handler_b.abs_oauth_connection_name + assert auth_handler_a.obo_connection_name == auth_handler_b.obo_connection_name + + @pytest.mark.asyncio + @pytest.mark.parametrize("auth_handler_id, channel_id, user_id", + [ + ["missing", "webchat", "Alice"], + ["handler", "teams", "Bob"] + ]) + async def test_open_flow_value_error( + self, + mocker, + auth, + auth_handler_id, + channel_id, + user_id + ): + """Test opening a flow with a missing auth handler.""" + context = self.create_context(mocker, channel_id, user_id) + with pytest.raises(ValueError): + async with auth.open_flow(context, auth_handler_id): + pass + + @pytest.mark.asyncio + @pytest.mark.parametrize("auth_handler_id, channel_id, user_id", + [ + ["", "webchat", "Alice"], + ["graph", "teams", "Bob"], + ["slack", "webchat", "Chuck"] + ]) + async def test_open_flow_readonly( + self, + mocker, + storage, + connection_manager, + auth_handlers, + auth_handler_id, + channel_id, + user_id + ): + """Test opening a flow and not modifying it.""" + # setup + context = self.create_context(mocker, channel_id, user_id) + auth = Authorization(storage, connection_manager, auth_handlers) + flow_storage_client = FlowStorageClient(channel_id, user_id, storage) + + # test + async with auth.open_flow(context, auth_handler_id) as flow: + expected_flow_state = flow.flow_state + + # verify + actual_flow_state = await flow_storage_client.read(auth.resolve_handler(auth_handler_id).name) + assert actual_flow_state == expected_flow_state + + @pytest.mark.asyncio + async def test_open_flow_success_modified_complete_flow( + self, + mocker, + storage, + connection_manager, + mock_user_token_client_class, + auth_handlers + ): + # setup + channel_id = "teams" + user_id = "Alice" + auth_handler_id = "graph" + + self.create_user_token_client( + mocker, + get_token_return=RES_TOKEN + ) + + context = self.create_context(mocker, channel_id, user_id) + context.activity.type = ActivityTypes.message + context.activity.text = "123456" + + auth = Authorization(storage, connection_manager, auth_handlers) + flow_storage_client = FlowStorageClient(channel_id, user_id, storage) + + # test + async with auth.open_flow(context, auth_handler_id) as flow: + expected_flow_state = flow.flow_state + expected_flow_state.tag = FlowStateTag.COMPLETE + expected_flow_state.user_token = RES_TOKEN + + flow_response = await flow.begin_or_continue_flow(context.activity) + res_flow_state = flow_response.flow_state + + # verify + actual_flow_state = await flow_storage_client.read(auth_handler_id) + expected_flow_state.expires_at = res_flow_state.expires_at # we won't check this for now + + assert res_flow_state == expected_flow_state + assert actual_flow_state == expected_flow_state + + @pytest.mark.asyncio + async def test_open_flow_success_modified_failure( + self, + mocker, + storage, + connection_manager, + auth_handlers, + ): + # setup + channel_id = "teams" + user_id = "Bob" + auth_handler_id = "slack" + + context = self.create_context(mocker, channel_id, user_id) + context.activity.text = "invalid_magic_code" + + auth = Authorization(storage, connection_manager, auth_handlers) + flow_storage_client = FlowStorageClient(channel_id, user_id, storage) + + # test + async with auth.open_flow(context, auth_handler_id) as flow: + expected_flow_state = flow.flow_state + expected_flow_state.tag = FlowStateTag.FAILURE + expected_flow_state.attempts_remaining = 0 + + flow_response = await flow.begin_or_continue_flow(context.activity) + res_flow_state = flow_response.flow_state + + # verify + actual_flow_state = await flow_storage_client.read(auth_handler_id) + expected_flow_state.expires_at = actual_flow_state.expires_at # we won't check this for now + + assert flow_response.flow_error_tag == FlowErrorTag.MAGIC_FORMAT + assert res_flow_state == expected_flow_state + assert actual_flow_state == expected_flow_state + + @pytest.mark.asyncio + async def test_open_flow_success_modified_signout( + self, + mocker, + storage, + connection_manager, + auth_handlers + ): + # setup + channel_id = "webchat" + user_id = "Alice" + auth_handler_id = "graph" + + context = self.create_context(mocker, channel_id, user_id) + + auth = Authorization(storage, connection_manager, auth_handlers) + flow_storage_client = FlowStorageClient(channel_id, user_id, storage) + + # test + async with auth.open_flow(context, auth_handler_id) as flow: + expected_flow_state = flow.flow_state + expected_flow_state.tag = FlowStateTag.NOT_STARTED + expected_flow_state.user_token = "" + + await flow.sign_out() + + # verify + actual_flow_state = await flow_storage_client.read(auth_handler_id) + expected_flow_state.expires_at = actual_flow_state.expires_at # we won't check this for now + assert actual_flow_state == expected_flow_state + + @pytest.mark.asyncio + async def test_get_token_success( + self, + mocker, + auth + ): + mock_user_token_client_class = self.create_user_token_client( + mocker, + get_token_return=TokenResponse(token="token") + ) + context = self.create_context(mocker, "__channel_id", "__user_id", user_token_client=mock_user_token_client_class) + assert await auth.get_token(context, "slack") == TokenResponse(token="token") + mock_user_token_client_class.user_token.get_token.assert_called_once() + + @pytest.mark.asyncio + async def test_get_token_empty_response( + self, + mocker, + auth + ): + mock_user_token_client_class = self.create_user_token_client( + mocker, + get_token_return=TokenResponse() + ) + context = self.create_context(mocker, "__channel_id", "__user_id", user_token_client=mock_user_token_client_class) + assert await auth.get_token(context, "graph") == TokenResponse() + mock_user_token_client_class.user_token.get_token.assert_called_once() + + @pytest.mark.asyncio + async def test_get_token_error( + self, + turn_context, + storage, + connection_manager, + auth_handlers + ): + auth = Authorization(storage, connection_manager, auth_handlers) + with pytest.raises(ValueError): + await auth.get_token(turn_context, "missing-handler") + + @pytest.mark.asyncio + async def test_exchange_token_no_token( + self, + mocker, + turn_context, + auth + ): + self.create_mock_oauth_flow_class(mocker, TokenResponse()) + res = await auth.exchange_token(turn_context, ["scope"], "github") + assert res == TokenResponse() + + @pytest.mark.asyncio + async def test_exchange_token_not_exchangeable( + self, + mocker, + turn_context, + auth + ): + token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") + self.create_mock_oauth_flow_class(mocker, TokenResponse(connection_name="github", token=token)) + res = await auth.exchange_token(turn_context, ["scope"], "github") + assert res == TokenResponse() + + @pytest.mark.asyncio + async def test_exchange_token_valid_exchangeable( + self, + turn_context, + mocker, + auth + ): + token = jwt.encode({"aud": "api://botframework.test.api"}, "") + self.create_mock_oauth_flow_class(mocker, TokenResponse(connection_name="github", token=token)) + mock_user_token_client_class = self.create_mock_user_token_client(mocker, token=token) + mock_user_token_client_class.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse(scopes=["scope"], token=token, connection_name="github")) + res = await auth.exchange_token(turn_context, ["scope"], "github") + assert res == TokenResponse(token="github-obo-connection-obo-token") + + @pytest.mark.asyncio + async def test_get_active_flow_state(self, mocker, auth): + context = self.create_context(mocker, "webchat", "Alice") + actual_flow_state = await auth.get_active_flow_state(context) + assert actual_flow_state == STORAGE_SAMPLE_DICT[flow_key("webchat", "Alice", "github")] + + @pytest.mark.asyncio + async def test_get_active_flow_state_missing(self, mocker, auth): + context = self.create_context(mocker, "__channel_id", "__user_id") + res = await auth.get_active_flow_state(context) + assert res is None + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_success( + self, + mocker, + auth + ): + # robrandao: TODO -> lower priority -> more testing here + # setup + mocker.patch.object(OAuthFlow, "begin_or_continue_flow", return_value=FlowResponse( + token_response=TokenResponse(token="token"), + flow_state=FlowState(tag=FlowStateTag.COMPLETE, auth_handler_id="github") + )) + context = self.create_context(mocker, "webchat", "Alice") + + context.dummy_val = None + def on_sign_in_success(context, turn_state, auth_handler_id): + context.dummy_val = auth_handler_id + def on_sign_in_failure(context, turn_state, auth_handler_id, err): + context.dummy_val = str(err) + + # test + auth.on_sign_in_success(on_sign_in_success) + auth.on_sign_in_failure(on_sign_in_failure) + flow_response = await auth.begin_or_continue_flow(context, None, "github") + assert context.dummy_val == "github" + assert flow_response.token_response == TokenResponse(token="token") + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_failure( + self, + mocker, + mock_oauth_flow_class, + auth + ): + # robrandao: TODO -> lower priority -> more testing here + # setup + mocker.patch.object(OAuthFlow, "begin_or_continue_flow", return_value=FlowResponse( + token_response=TokenResponse(token="token"), + flow_state=FlowState(tag=FlowStateTag.FAILURE, auth_handler_id="github"), + flow_state_error=FlowErrorTag.MAGIC_FORMAT + )) + context = self.create_context(mocker, "webchat", "Alice") + + context.dummy_val = None + def on_sign_in_success(context, turn_state, auth_handler_id): + context.dummy_val = auth_handler_id + def on_sign_in_failure(context, turn_state, auth_handler_id, err): + context.dummy_val = str(err) + + # test + auth.on_sign_in_success(on_sign_in_success) + auth.on_sign_in_failure(on_sign_in_failure) + flow_response = await auth.begin_or_continue_flow(context, None, "github") + assert context.dummy_val == "FlowErrorTag.NONE" + assert flow_response.token_response == TokenResponse(token="token") + + @pytest.mark.parametrize("auth_handler_id", ["graph", "github"]) + def test_resolve_handler_specified(self, auth, auth_handlers, auth_handler_id): + assert auth.resolve_handler(auth_handler_id) == auth_handlers[auth_handler_id] + + def test_resolve_handler_error(self, auth): + with pytest.raises(ValueError): + auth.resolve_handler("missing-handler") + + def test_resolve_handler_first(self, auth, auth_handlers): + assert auth.resolve_handler() == next(iter(auth_handlers.values())) + + @pytest.mark.asyncio + async def test_sign_out_individual( + self, + mocker, + mock_user_token_client_class, + mock_oauth_flow_class, + storage, + baseline_storage, + connection_manager, + auth_handlers + ): + # setup + storage_client = FlowStorageClient("teams", "Alice", storage) + context = self.create_context(mocker, "teams", "Alice") + auth = Authorization(storage, connection_manager, auth_handlers) + + # test + await auth.sign_out(context, "graph") + + # verify + assert await storage.read([storage_client.key("graph")], target_cls=FlowState) == {} + OAuthFlow.sign_out.assert_called_once() + + @pytest.mark.asyncio + async def test_sign_out_all( + self, + mocker, + mock_user_token_client_class, + mock_oauth_flow_class, + turn_context, + storage, + baseline_storage, + connection_manager, + auth_handlers + ): + # setup + storage_client = FlowStorageClient("webchat", "Alice", storage) + + auth = Authorization(storage, connection_manager, auth_handlers) + context = self.create_context(mocker, "webchat", "Alice") + await auth.sign_out(context) + + # verify + assert await storage.read([storage_client.key("graph")], target_cls=FlowState) == {} + assert await storage.read([storage_client.key("github")], target_cls=FlowState) == {} + assert await storage.read([storage_client.key("slack")], target_cls=FlowState) == {} + OAuthFlow.sign_out.assert_called() # ignore red squiggly -> mocked diff --git a/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py b/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py index dabbcc69..0721b63f 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py @@ -1,110 +1,69 @@ -from datetime import datetime - -import pytest - -from microsoft.agents.hosting.core.oauth.flow_state import FlowState, FlowStateTag - -class TestFlowState: - - def test_refresh_to_failure_expired(self): - """Test that the flow state refreshes to failure when expired.""" - flow_state = FlowState( - tag=FlowStateTag.CONTINUE, - attempts_remaining=1, - expires_at=datetime.now().timestamp() - ) - flow_state.refresh() - assert flow_state.tag == FlowStateTag.FAILURE - - def test_refresh_to_failure_max_attempts(self): - """Test that the flow state refreshes to failure when max attempts reached.""" - flow_state = FlowState( - tag=FlowStateTag.CONTINUE, - attempts_remaining=0, - ) - flow_state.refresh() - assert flow_state.tag == FlowStateTag.FAILURE - - def test_refresh_unchanged_continue(self): - """Test that the flow state remains unchanged when refreshed with a valid CONTINUE state""" - flow_state = FlowState( - tag=FlowStateTag.CONTINUE, - attempts_remaining=1, - expires_at=datetime.now().timestamp() + 10000 - ) - prev_tag = flow_state.tag - flow_state.refresh() - assert flow_state.tag == prev_tag - - def test_refresh_unchanged_begin(self): - """Test that the flow state remains unchanged when refreshed with a valid BEGIN state""" - flow_state = FlowState( - tag=FlowStateTag.BEGIN, - attempts_remaining=10, - expires_at=datetime.now().timestamp() + 30000 - ) - prev_tag = flow_state.tag - flow_state.refresh() - assert flow_state.tag == prev_tag - - @pytest.mark.parametrize( - "flow_state, expected", - [ - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expires_at=datetime.now().timestamp()), - True), - (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expires_at=datetime.now().timestamp()), - True), - (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expires_at=datetime.now().timestamp()-100), - True), - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expires_at=datetime.now().timestamp()+1000), - False), - (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=-1, expires_at=datetime.now().timestamp()+1000), - False), - ] - ) - def test_is_expired(self, flow_state, expected): - assert flow_state.is_expired() == expected - - @pytest.mark.parametrize( - "flow_state, expected", - [ - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expires_at=datetime.now().timestamp()), - True), - (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expires_at=datetime.now().timestamp()), - False), - (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expires_at=datetime.now().timestamp()-100), - True), - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expires_at=datetime.now().timestamp()-100), - False), - (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=-1, expires_at=datetime.now().timestamp()), - True), - ] - ) - def test_reached_max_attempts(self, flow_state, expected): - assert flow_state.reached_max_attempts() == expected - - @pytest.mark.parametrize( - "flow_state, expected", - [ - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expires_at=datetime.now().timestamp()), - False), - (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expires_at=datetime.now().timestamp()), - False), - (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expires_at=datetime.now().timestamp()-100), - False), - (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=1, expires_at=datetime.now().timestamp()-100), - False), - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=2, expires_at=datetime.now().timestamp()+1000), - True), - (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=0, expires_at=datetime.now().timestamp()+1000), - False), - (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=-1, expires_at=datetime.now().timestamp()+1000), - False), - (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=1, expires_at=datetime.now().timestamp()+1000), - False), - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expires_at=datetime.now().timestamp()+1000), - True) - ] - ) - def test_is_active(self, flow_state, expected): +from datetime import datetime + +import pytest + +from microsoft.agents.hosting.core.oauth.flow_state import FlowState, FlowStateTag + +class TestFlowState: + + @pytest.mark.parametrize( + "flow_state, expected", + [ + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expires_at=datetime.now().timestamp()), + True), + (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expires_at=datetime.now().timestamp()), + True), + (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expires_at=datetime.now().timestamp()-100), + True), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expires_at=datetime.now().timestamp()+1000), + False), + (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=-1, expires_at=datetime.now().timestamp()+1000), + False), + ] + ) + def test_is_expired(self, flow_state, expected): + assert flow_state.is_expired() == expected + + @pytest.mark.parametrize( + "flow_state, expected", + [ + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expires_at=datetime.now().timestamp()), + True), + (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expires_at=datetime.now().timestamp()), + False), + (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expires_at=datetime.now().timestamp()-100), + True), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expires_at=datetime.now().timestamp()-100), + False), + (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=-1, expires_at=datetime.now().timestamp()), + True), + ] + ) + def test_reached_max_attempts(self, flow_state, expected): + assert flow_state.reached_max_attempts() == expected + + @pytest.mark.parametrize( + "flow_state, expected", + [ + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expires_at=datetime.now().timestamp()), + False), + (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expires_at=datetime.now().timestamp()), + False), + (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expires_at=datetime.now().timestamp()-100), + False), + (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=1, expires_at=datetime.now().timestamp()-100), + False), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=2, expires_at=datetime.now().timestamp()+1000), + True), + (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=0, expires_at=datetime.now().timestamp()+1000), + False), + (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=-1, expires_at=datetime.now().timestamp()+1000), + False), + (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=1, expires_at=datetime.now().timestamp()+1000), + False), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expires_at=datetime.now().timestamp()+1000), + True) + ] + ) + def test_is_active(self, flow_state, expected): assert flow_state.is_active() == expected \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py b/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py index f0171fb6..925d88eb 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py @@ -1,146 +1,155 @@ -import pytest - -from microsoft.agents.hosting.core.storage import MemoryStorage -from microsoft.agents.hosting.core.storage.storage_test_utils import MockStoreItem -from microsoft.agents.hosting.core.oauth import FlowState, FlowStorageClient - -class TestFlowStorageClient: - - @pytest.fixture - def channel_id(self): - return "__channel_id" - - @pytest.fixture - def user_id(self): - return "__user_id" - - @pytest.fixture - def storage(self): - return MemoryStorage() - - @pytest.fixture - def client(self, channel_id, user_id, storage): - return FlowStorageClient(channel_id, user_id, storage) - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "channel_id, from_property_id", - [ - ("channel_id", "from_property_id"), - ("teams_id", "Bob"), - ("channel", "Alice"), - ], - ) - async def test_init_base_key(self, mocker, channel_id, user_id): - client = FlowStorageClient(channel_id, user_id, mocker.Mock()) - assert client.base_key == f"auth/{channel_id}/{user_id}/" - - @pytest.mark.asyncio - async def test_init_fails_without_user_id(self, channel_id, storage): - with pytest.raises(ValueError): - FlowStorageClient(channel_id, "", storage) - - @pytest.mark.asyncio - async def test_init_fails_without_channel_id(self, user_id, storage): - with pytest.raises(ValueError): - FlowStorageClient("", user_id, storage) - - @pytest.mark.parametrize( - "auth_handler_id, expected", - [ - ("handler", "auth/__channel_id/__user_id/handler"), - ("auth_handler", "auth/__channel_id/__user_id/auth_handler"), - ] - ) - def test_key(self, client, auth_handler_id, expected): - assert client.key(auth_handler_id) == expected - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "auth_handler_id", ["handler", "auth_handler"] - ) - async def test_read(self, mocker, user_id, channel_id, auth_handler_id): - storage = mocker.AsyncMock() - key = f"auth/{channel_id}/{user_id}/{auth_handler_id}" - storage.read.return_value = {key: FlowState()} - client = FlowStorageClient(channel_id, user_id, storage) - res = await client.read(auth_handler_id) - assert res is storage.read.return_value[key] - storage.read.assert_called_once_with([client.key(auth_handler_id)], FlowState) - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "auth_handler_id", ["handler", "auth_handler"] - ) - async def test_write(self, mocker, channel_id, user_id, auth_handler_id): - storage = mocker.AsyncMock() - storage.write.return_value = None - client = FlowStorageClient(channel_id, user_id, storage) - flow_state = mocker.Mock(spec=FlowState) - flow_state.flow_id = auth_handler_id - await client.write(flow_state) - storage.write.assert_called_once_with({ client.key(auth_handler_id): flow_state }) - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "auth_handler_id", ["handler", "auth_handler"] - ) - async def test_delete(self, mocker, channel_id, user_id, auth_handler_id): - storage = mocker.AsyncMock() - storage.delete.return_value = None - client = FlowStorageClient(channel_id, user_id, storage) - await client.delete(auth_handler_id) - storage.delete.assert_called_once_with([client.key(auth_handler_id)]) - - @pytest.mark.asyncio - async def test_integration_with_memory_storage(self, channel_id, user_id): - - flow_state_alpha = FlowState(flow_id="handler", flow_started=True) - flow_state_beta = FlowState(flow_id="auth_handler", flow_started=True, user_token="token") - - storage = MemoryStorage({ - "some_data": MockStoreItem({"value": "test"}), - f"auth/{channel_id}/{user_id}/handler": flow_state_alpha, - f"auth/{channel_id}/{user_id}/auth_handler": flow_state_beta, - }) - baseline = MemoryStorage({ - "some_data": MockStoreItem({"value": "test"}), - f"auth/{channel_id}/{user_id}/handler": flow_state_alpha, - "fauth/{channel_id}/{user_id}/auth_handler": flow_state_beta, - }) - - # helpers - async def read_check(*args, **kwargs): - res_storage = await storage.read(*args, **kwargs) - res_baseline = await baseline.read(*args, **kwargs) - assert res_storage == res_baseline - - async def write_both(*args, **kwargs): - await storage.write(*args, **kwargs) - await baseline.write(*args, **kwargs) - - async def delete_both(*args, **kwargs): - await storage.delete(*args, **kwargs) - await baseline.delete(*args, **kwargs) - - client = FlowStorageClient(channel_id, user_id, storage) - - new_flow_state_alpha = FlowState(flow_id="handler") - flow_state_chi = FlowState(flow_id="chi") - - await client.write(new_flow_state_alpha) - await client.write(flow_state_chi) - await baseline.write({f"auth/{channel_id}/{user_id}/handler": new_flow_state_alpha.model_copy()}) - await baseline.write({f"auth/{channel_id}/{user_id}/chi": flow_state_chi.model_copy()}) - - await write_both({f"auth/{channel_id}/{user_id}/handler": new_flow_state_alpha.model_copy()}) - await write_both({f"auth/{channel_id}/{user_id}/auth_handler": flow_state_beta.model_copy()}) - await write_both({"other_data": MockStoreItem({"value": "more"})}) - - await delete_both(["some_data"]) - - await read_check([f"auth/{channel_id}/{user_id}/handler"], target_cls=FlowState) - await read_check([f"auth/{channel_id}/{user_id}/auth_handler"], target_cls=FlowState) - await read_check([f"auth/{channel_id}/{user_id}/chi"], target_cls=FlowState) - await read_check(["other_data"], target_cls=MockStoreItem) - await read_check(["some_data"], target_cls=MockStoreItem) +import pytest + +from microsoft.agents.hosting.core.storage import MemoryStorage +from microsoft.agents.hosting.core.storage.storage_test_utils import MockStoreItem +from microsoft.agents.hosting.core.oauth import FlowState, FlowStorageClient + +class TestFlowStorageClient: + + @pytest.fixture + def channel_id(self): + return "__channel_id" + + @pytest.fixture + def user_id(self): + return "__user_id" + + @pytest.fixture + def storage(self): + return MemoryStorage() + + @pytest.fixture + def client(self, channel_id, user_id, storage): + return FlowStorageClient(channel_id, user_id, storage) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "channel_id, user_id", + [ + ("channel_id", "user_id"), + ("teams_id", "Bob"), + ("channel", "Alice"), + ], + ) + async def test_init_base_key(self, mocker, channel_id, user_id): + client = FlowStorageClient(channel_id, user_id, mocker.Mock()) + assert client.base_key == f"auth/{channel_id}/{user_id}/" + + @pytest.mark.asyncio + async def test_init_fails_without_user_id(self, channel_id, storage): + with pytest.raises(ValueError): + FlowStorageClient(channel_id, "", storage) + + @pytest.mark.asyncio + async def test_init_fails_without_channel_id(self, user_id, storage): + with pytest.raises(ValueError): + FlowStorageClient("", user_id, storage) + + @pytest.mark.parametrize( + "auth_handler_id, expected", + [ + ("handler", "auth/__channel_id/__user_id/handler"), + ("auth_handler", "auth/__channel_id/__user_id/auth_handler"), + ] + ) + def test_key(self, client, auth_handler_id, expected): + assert client.key(auth_handler_id) == expected + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "auth_handler_id", ["handler", "auth_handler"] + ) + async def test_read(self, mocker, user_id, channel_id, auth_handler_id): + storage = mocker.AsyncMock() + key = f"auth/{channel_id}/{user_id}/{auth_handler_id}" + storage.read.return_value = {key: FlowState()} + client = FlowStorageClient(channel_id, user_id, storage) + res = await client.read(auth_handler_id) + assert res is storage.read.return_value[key] + storage.read.assert_called_once_with([client.key(auth_handler_id)], target_cls=FlowState) + + @pytest.mark.asyncio + async def test_read_missing(self, mocker): + storage = mocker.AsyncMock() + storage.read.return_value = {} + client = FlowStorageClient("__channel_id", "__user_id", storage) + res = await client.read("non_existent_handler") + assert res is None + storage.read.assert_called_once_with([client.key("non_existent_handler")], target_cls=FlowState) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "auth_handler_id", ["handler", "auth_handler"] + ) + async def test_write(self, mocker, channel_id, user_id, auth_handler_id): + storage = mocker.AsyncMock() + storage.write.return_value = None + client = FlowStorageClient(channel_id, user_id, storage) + flow_state = mocker.Mock(spec=FlowState) + flow_state.auth_handler_id = auth_handler_id + await client.write(flow_state) + storage.write.assert_called_once_with({ client.key(auth_handler_id): flow_state }) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "auth_handler_id", ["handler", "auth_handler"] + ) + async def test_delete(self, mocker, channel_id, user_id, auth_handler_id): + storage = mocker.AsyncMock() + storage.delete.return_value = None + client = FlowStorageClient(channel_id, user_id, storage) + await client.delete(auth_handler_id) + storage.delete.assert_called_once_with([client.key(auth_handler_id)]) + + @pytest.mark.asyncio + async def test_integration_with_memory_storage(self, channel_id, user_id): + + flow_state_alpha = FlowState(auth_handler_id="handler", flow_started=True) + flow_state_beta = FlowState(auth_handler_id="auth_handler", flow_started=True, user_token="token") + + storage = MemoryStorage({ + "some_data": MockStoreItem({"value": "test"}), + f"auth/{channel_id}/{user_id}/handler": flow_state_alpha, + f"auth/{channel_id}/{user_id}/auth_handler": flow_state_beta, + }) + baseline = MemoryStorage({ + "some_data": MockStoreItem({"value": "test"}), + f"auth/{channel_id}/{user_id}/handler": flow_state_alpha, + f"fauth/{channel_id}/{user_id}/auth_handler": flow_state_beta, + }) + + # helpers + async def read_check(*args, **kwargs): + res_storage = await storage.read(*args, **kwargs) + res_baseline = await baseline.read(*args, **kwargs) + assert res_storage == res_baseline + + async def write_both(*args, **kwargs): + await storage.write(*args, **kwargs) + await baseline.write(*args, **kwargs) + + async def delete_both(*args, **kwargs): + await storage.delete(*args, **kwargs) + await baseline.delete(*args, **kwargs) + + client = FlowStorageClient(channel_id, user_id, storage) + + new_flow_state_alpha = FlowState(auth_handler_id="handler") + flow_state_chi = FlowState(auth_handler_id="chi") + + await client.write(new_flow_state_alpha) + await client.write(flow_state_chi) + await baseline.write({f"auth/{channel_id}/{user_id}/handler": new_flow_state_alpha.model_copy()}) + await baseline.write({f"auth/{channel_id}/{user_id}/chi": flow_state_chi.model_copy()}) + + await write_both({f"auth/{channel_id}/{user_id}/handler": new_flow_state_alpha.model_copy()}) + await write_both({f"auth/{channel_id}/{user_id}/auth_handler": flow_state_beta.model_copy()}) + await write_both({"other_data": MockStoreItem({"value": "more"})}) + + await delete_both(["some_data"]) + + await read_check([f"auth/{channel_id}/{user_id}/handler"], target_cls=FlowState) + await read_check([f"auth/{channel_id}/{user_id}/auth_handler"], target_cls=FlowState) + await read_check([f"auth/{channel_id}/{user_id}/chi"], target_cls=FlowState) + await read_check(["other_data"], target_cls=MockStoreItem) + await read_check(["some_data"], target_cls=MockStoreItem) diff --git a/libraries/microsoft-agents-hosting-core/tests/test_auth_flow.py b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py similarity index 95% rename from libraries/microsoft-agents-hosting-core/tests/test_auth_flow.py rename to libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py index ba46be2c..0a33c28b 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_auth_flow.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py @@ -1,395 +1,393 @@ -import pytest - -from microsoft.agents.activity import ( - Activity, - ActivityTypes, - TokenResponse, - SignInResource, - TokenExchangeState, - ConversationReference, - ChannelAccount, -) -from microsoft.agents.hosting.core.oauth import ( - OAuthFlow, - FlowErrorTag, - FlowStateTag -) -from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase -from microsoft.agents.hosting.core.connector.user_token_client_base import UserTokenClientBase - -# test constants -from .tools.testing_oauth import * - -class TestOAuthFlowUtils: - - def create_user_token_client(self, mocker, get_token_return=None): - - user_token_client = mocker.Mock(spec=UserTokenClientBase) - user_token_client.user_token = mocker.Mock(spec=UserTokenBase) - user_token_client.user_token.get_token = mocker.AsyncMock() - user_token_client.user_token.sign_out = mocker.AsyncMock() - - return_value = TokenResponse() - if get_token_return: - return_value = TokenResponse(token=get_token_return) - user_token_client.user_token.get_token.return_value = return_value - - return user_token_client - - @pytest.fixture - def user_token_client(self, mocker): - return self.create_user_token_client(mocker, get_token_return=RES_TOKEN) - - def create_activity(self, mocker, activity_type=ActivityTypes.message, name="a", value=None, text="a"): - # def conv_ref(): - # return mocker.MagicMock(spec=ConversationReference) - mock_conversation_ref = mocker.MagicMock(ConversationReference) - mocker.patch.object(Activity, "get_conversation_reference", return_value=mocker.MagicMock(ConversationReference)) - # mocker.patch.object(ConversationReference, "create", return_value=conv_ref()) - return Activity( - type=activity_type, - name=name, - from_property=ChannelAccount(id=USER_ID), - channel_id=CHANNEL_ID, - # get_conversation_reference=mocker.Mock(return_value=conv_ref), - relates_to=mocker.MagicMock(ConversationReference), - value=value, - text=text - ) - - @pytest.fixture(params=FLOW_STATES.ALL()) - def sample_flow_state(self, request): - return request.param.model_copy() - - @pytest.fixture(params=FLOW_STATES.FAILED()) - def sample_failed_flow_state(self, request): - return request.param.model_copy() - - @pytest.fixture(params=FLOW_STATES.INACTIVE()) - def sample_inactive_flow_state(self, request): - return request.param.model_copy() - - @pytest.fixture(params=FLOW_STATES.ACTIVE()) - def sample_active_flow_state(self, request): - return request.param.model_copy() - - @pytest.fixture - def flow(self, sample_flow_state, user_token_client): - return OAuthFlow(sample_flow_state, user_token_client) - - -class TestOAuthFlow(TestOAuthFlowUtils): - - def test_init_no_user_token_client(self, sample_flow_state): - with pytest.raises(ValueError): - OAuthFlow(sample_flow_state, None) - - @pytest.mark.parametrize("missing_value", [ - "abs_oauth_connection_name", - "ms_app_id", - "channel_id", - "user_id" - ]) - def test_init_errors(self, missing_value, user_token_client): - flow_state = FLOW_STATES.STARTED_FLOW.model_copy() - flow_state.__setattr__(missing_value, None) - with pytest.raises(ValueError): - OAuthFlow(flow_state, user_token_client) - flow_state.__setattr__(missing_value, "") - with pytest.raises(ValueError): - OAuthFlow(flow_state, user_token_client) - - def test_init_with_state(self, sample_flow_state, user_token_client): - flow = OAuthFlow(sample_flow_state, user_token_client) - assert flow.flow_state == sample_flow_state - - def test_flow_state_prop_copy(self, flow): - flow_state = flow.flow_state - flow_state.user_id = (flow_state.user_id + "_copy") - assert flow.flow_state.user_id == USER_ID - assert flow_state.user_id == f"{USER_ID}_copy" - - @pytest.mark.asyncio - async def test_get_user_token_success(self, sample_flow_state, user_token_client): - # setup - flow = OAuthFlow(sample_flow_state, user_token_client) - expected_final_flow_state = sample_flow_state - expected_final_flow_state.user_token = RES_TOKEN - - # test - token_response = await flow.get_user_token() - token = token_response.token - - # verify - assert token == RES_TOKEN - assert flow.flow_state == expected_final_flow_state - user_token_client.user_token.get_token.assert_called_once_with( - user_id=USER_ID, - connection_name=ABS_OAUTH_CONNECTION_NAME, - channel_id=CHANNEL_ID, - magic_code=None - ) - - @pytest.mark.asyncio - async def test_get_user_token_failure(self, mocker, sample_flow_state): - # setup - user_token_client = self.create_user_token_client(mocker, get_token_return=None) - flow = OAuthFlow(sample_flow_state, user_token_client) - expected_final_flow_state = flow.flow_state # robrandao: TODO -> what happens if fails and has user_token? - - # test - token_response = await flow.get_user_token() - - # verify - assert token_response == TokenResponse() - assert flow.flow_state == expected_final_flow_state - user_token_client.user_token.get_token.assert_called_once_with( - user_id=USER_ID, - connection_name=ABS_OAUTH_CONNECTION_NAME, - channel_id=CHANNEL_ID, - magic_code=None - ) - - @pytest.mark.asyncio - async def test_sign_out(self, sample_flow_state, user_token_client): - # setup - flow = OAuthFlow(sample_flow_state, user_token_client) - expected_flow_state = sample_flow_state - expected_flow_state.user_token = "" - expected_flow_state.tag = FlowStateTag.NOT_STARTED - - # test - await flow.sign_out() - - # verify - user_token_client.user_token.sign_out.assert_called_once_with( - user_id=USER_ID, - connection_name=ABS_OAUTH_CONNECTION_NAME, - channel_id=CHANNEL_ID - ) - assert flow.flow_state == expected_flow_state - - @pytest.mark.asyncio - async def test_begin_flow_easy_case(self, mocker, sample_flow_state, user_token_client): - # setup - flow = OAuthFlow(sample_flow_state, user_token_client) - activity = mocker.Mock(spec=Activity) - expected_flow_state = sample_flow_state - expected_flow_state.user_token = RES_TOKEN - - # test - response = await flow.begin_flow(activity) - - # verify - flow_state = flow.flow_state - assert flow_state == expected_flow_state - # assert flow_state.flow_started is False # robrandao: TODO? - - assert response.flow_state == flow_state - assert response.sign_in_resource is None # No sign-in resource in this case - assert response.flow_error_tag == FlowErrorTag.NONE - assert response.token_response.token == RES_TOKEN - user_token_client.user_token.get_token.assert_called_once_with( - user_id=USER_ID, - connection_name=ABS_OAUTH_CONNECTION_NAME, - channel_id=CHANNEL_ID, - # magic_code=None is an implementation detail, and ideally - # shouldn't be part of the test - magic_code=None - ) - - @pytest.mark.asyncio - async def test_begin_flow_long_case(self, mocker, sample_flow_state, user_token_client): - # mock - # tes = mocker.Mock(TokenExchangeState) - # tes.get_encoded_state = mocker.Mock(return_value="encoded_state") - mocker.patch.object(TokenExchangeState, "get_encoded_state", return_value="encoded_state") - dummy_sign_in_resource = SignInResource( - sign_in_link="https://example.com/signin", - token_exchange_state=mocker.Mock( - TokenExchangeState, get_encoded_state=mocker.Mock(return_value="encoded_state") - ) - ) - user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) - user_token_client.agent_sign_in.get_sign_in_resource = mocker.AsyncMock( - return_value=dummy_sign_in_resource) - activity = self.create_activity(mocker) - - # setup - flow = OAuthFlow(sample_flow_state, user_token_client) - expected_flow_state = sample_flow_state - expected_flow_state.user_token = "" - expected_flow_state.tag = FlowStateTag.BEGIN - expected_flow_state.attempts_remaining = 3 - - # test - response = await flow.begin_flow(activity) - - # verify flow_state - flow_state = flow.flow_state - expected_flow_state.expires_at = flow_state.expires_at # robrandao: TODO -> ignore this for now - assert flow_state == response.flow_state - assert flow_state == expected_flow_state - - # verify FlowResponse - assert response.sign_in_resource == dummy_sign_in_resource - assert response.flow_error_tag == FlowErrorTag.NONE - assert not response.token_response - # robrandao: TODO more assertions on sign_in_resource - - @pytest.mark.asyncio - async def test_continue_flow_not_active(self, mocker, sample_inactive_flow_state, user_token_client): - # setup - activity = mocker.Mock() - flow = OAuthFlow(sample_inactive_flow_state, user_token_client) - expected_flow_state = sample_inactive_flow_state - expected_flow_state.tag = FlowStateTag.FAILURE - - # test - flow_response = await flow.continue_flow(activity) - flow_state = flow.flow_state - - # verify - # robrandao: TODO -> revise - assert flow_state == expected_flow_state - assert flow_response.flow_state == flow_state - assert not flow_response.token_response - - async def helper_continue_flow_failure(self, active_flow_state, user_token_client, activity, flow_error_tag): - # setup - flow = OAuthFlow(active_flow_state, user_token_client) - expected_flow_state = active_flow_state - expected_flow_state.tag = FlowStateTag.CONTINUE if active_flow_state.attempts_remaining > 1 else FlowStateTag.FAILURE - expected_flow_state.attempts_remaining = active_flow_state.attempts_remaining - 1 - - # test - flow_response = await flow.continue_flow(activity) - flow_state = flow.flow_state - - # verify - assert flow_response.flow_state == flow_state - assert expected_flow_state == flow_state - assert flow_response.token_response == TokenResponse() - assert flow_response.flow_error_tag == flow_error_tag - - async def helper_continue_flow_success(self, active_flow_state, user_token_client, activity): - # setup - flow = OAuthFlow(active_flow_state, user_token_client) - expected_flow_state = active_flow_state - expected_flow_state.tag = FlowStateTag.COMPLETE - expected_flow_state.user_token = RES_TOKEN - expected_flow_state.attempts_remaining = active_flow_state.attempts_remaining - - # test - flow_response = await flow.continue_flow(activity) - flow_state = flow.flow_state - expected_flow_state.expires_at = flow_state.expires_at # robrandao: TODO -> ignore this for now - - # verify - assert flow_response.flow_state == flow_state - assert expected_flow_state == flow_state - assert flow_response.token_response == TokenResponse(token=RES_TOKEN) - assert flow_response.flow_error_tag == FlowErrorTag.NONE - - @pytest.mark.asyncio - @pytest.mark.parametrize("magic_code", ["magic", "123", "", "1239453"]) - async def test_continue_flow_active_message_magic_format_error(self, mocker, sample_active_flow_state, user_token_client, magic_code): - # setup - activity = self.create_activity(mocker, ActivityTypes.message, text=magic_code) - await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.MAGIC_FORMAT) - user_token_client.assert_not_called() - - @pytest.mark.asyncio - async def test_continue_flow_active_message_magic_code_error(self, mocker, sample_active_flow_state, user_token_client): - # setup - user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) - activity = self.create_activity(mocker, ActivityTypes.message, text="123456") - await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.MAGIC_CODE_INCORRECT) - user_token_client.user_token.get_token.assert_called_once_with( - user_id=sample_active_flow_state.user_id, - connection_name=sample_active_flow_state.abs_oauth_connection_name, - channel_id=sample_active_flow_state.channel_id, - magic_code="123456" - ) - - @pytest.mark.asyncio - async def test_continue_flow_active_message_success(self, mocker, sample_active_flow_state, user_token_client): - # setup - activity = self.create_activity(mocker, ActivityTypes.message, text="123456") - await self.helper_continue_flow_success(sample_active_flow_state, user_token_client, activity) - user_token_client.user_token.get_token.assert_called_once_with( - user_id=sample_active_flow_state.user_id, - connection_name=sample_active_flow_state.abs_oauth_connection_name, - channel_id=sample_active_flow_state.channel_id, - magic_code="123456" - ) - - @pytest.mark.asyncio - async def test_continue_flow_active_sign_in_verify_state_error(self, mocker, sample_active_flow_state, user_token_client): - # setup - user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) - activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/verifyState", value={ - "state": "magic_code" - }) - await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.OTHER) - user_token_client.user_token.get_token.assert_called_once_with( - user_id=sample_active_flow_state.user_id, - connection_name=sample_active_flow_state.abs_oauth_connection_name, - channel_id=sample_active_flow_state.channel_id, - magic_code="magic_code" - ) - - @pytest.mark.asyncio - async def test_continue_flow_active_sign_in_verify_success(self, mocker, sample_active_flow_state, user_token_client): - activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/verifyState", value={ - "state": "magic_code" - }) - await self.helper_continue_flow_success(sample_active_flow_state, user_token_client, activity) - user_token_client.user_token.get_token.assert_called_once_with( - user_id=sample_active_flow_state.user_id, - connection_name=sample_active_flow_state.abs_oauth_connection_name, - channel_id=sample_active_flow_state.channel_id, - magic_code="magic_code" - ) - - @pytest.mark.asyncio - async def test_continue_flow_active_sign_in_token_exchange_error(self, mocker, sample_active_flow_state, user_token_client): - token_exchange_request = {} - user_token_client.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse()) - activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/tokenExchange", value=token_exchange_request) - await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.OTHER) - user_token_client.user_token.exchange_token.assert_called_once_with( - user_id=sample_active_flow_state.user_id, - connection_name=sample_active_flow_state.abs_oauth_connection_name, - channel_id=sample_active_flow_state.channel_id, - body=token_exchange_request - ) - - @pytest.mark.asyncio - async def test_continue_flow_active_sign_in_token_exchange_success(self, mocker, sample_active_flow_state, user_token_client): - token_exchange_request = {} - user_token_client.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse(token=RES_TOKEN)) - activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/tokenExchange", value=token_exchange_request) - await self.helper_continue_flow_success(sample_active_flow_state, user_token_client, activity) - user_token_client.user_token.exchange_token.assert_called_once_with( - user_id=sample_active_flow_state.user_id, - connection_name=sample_active_flow_state.abs_oauth_connection_name, - channel_id=sample_active_flow_state.channel_id, - body=token_exchange_request - ) - - @pytest.mark.asyncio - async def test_continue_flow_invalid_invoke_name(self, mocker, sample_active_flow_state, user_token_client): - with pytest.raises(ValueError): - activity = self.create_activity(mocker, ActivityTypes.invoke, name="other", value={}) - flow = OAuthFlow(sample_active_flow_state, user_token_client) - await flow.continue_flow(activity) - - @pytest.mark.asyncio - async def test_continue_flow_invalid_activity_type(self, mocker, sample_active_flow_state, user_token_client): - with pytest.raises(ValueError): - activity = self.create_activity(mocker, ActivityTypes.command, name="other", value={}) - flow = OAuthFlow(sample_active_flow_state, user_token_client) - await flow.continue_flow(activity) - - # robrandao: TODO -> test begin_or_continue_flow \ No newline at end of file +import pytest + +from microsoft.agents.activity import ( + Activity, + ActivityTypes, + TokenResponse, + SignInResource, + TokenExchangeState, + ConversationReference, + ChannelAccount, +) +from microsoft.agents.hosting.core.oauth import ( + OAuthFlow, + FlowErrorTag, + FlowStateTag +) +from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase +from microsoft.agents.hosting.core.connector.user_token_client_base import UserTokenClientBase + +# test constants +from .tools.testing_oauth import * + +class TestOAuthFlowUtils: + + def create_user_token_client(self, mocker, get_token_return=None): + + user_token_client = mocker.Mock(spec=UserTokenClientBase) + user_token_client.user_token = mocker.Mock(spec=UserTokenBase) + user_token_client.user_token.get_token = mocker.AsyncMock() + user_token_client.user_token.sign_out = mocker.AsyncMock() + + return_value = TokenResponse() + if get_token_return: + return_value = TokenResponse(token=get_token_return) + user_token_client.user_token.get_token.return_value = return_value + + return user_token_client + + @pytest.fixture + def user_token_client(self, mocker): + return self.create_user_token_client(mocker, get_token_return=RES_TOKEN) + + def create_activity(self, mocker, activity_type=ActivityTypes.message, name="a", value=None, text="a"): + # def conv_ref(): + # return mocker.MagicMock(spec=ConversationReference) + mock_conversation_ref = mocker.MagicMock(ConversationReference) + mocker.patch.object(Activity, "get_conversation_reference", return_value=mocker.MagicMock(ConversationReference)) + # mocker.patch.object(ConversationReference, "create", return_value=conv_ref()) + return Activity( + type=activity_type, + name=name, + from_property=ChannelAccount(id=USER_ID), + channel_id=CHANNEL_ID, + # get_conversation_reference=mocker.Mock(return_value=conv_ref), + relates_to=mocker.MagicMock(ConversationReference), + value=value, + text=text + ) + + @pytest.fixture(params=FLOW_STATES.ALL()) + def sample_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture(params=FLOW_STATES.FAILED()) + def sample_failed_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture(params=FLOW_STATES.INACTIVE()) + def sample_inactive_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture(params=FLOW_STATES.ACTIVE()) + def sample_active_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture + def flow(self, sample_flow_state, user_token_client): + return OAuthFlow(sample_flow_state, user_token_client) + + +class TestOAuthFlow(TestOAuthFlowUtils): + + def test_init_no_user_token_client(self, sample_flow_state): + with pytest.raises(ValueError): + OAuthFlow(sample_flow_state, None) + + @pytest.mark.parametrize("missing_value", [ + "connection", + "ms_app_id", + "channel_id", + "user_id" + ]) + def test_init_errors(self, missing_value, user_token_client): + flow_state = FLOW_STATES.STARTED_FLOW.model_copy() + flow_state.__setattr__(missing_value, None) + with pytest.raises(ValueError): + OAuthFlow(flow_state, user_token_client) + flow_state.__setattr__(missing_value, "") + with pytest.raises(ValueError): + OAuthFlow(flow_state, user_token_client) + + def test_init_with_state(self, sample_flow_state, user_token_client): + flow = OAuthFlow(sample_flow_state, user_token_client) + assert flow.flow_state == sample_flow_state + + def test_flow_state_prop_copy(self, flow): + flow_state = flow.flow_state + flow_state.user_id = (flow_state.user_id + "_copy") + assert flow.flow_state.user_id == USER_ID + assert flow_state.user_id == f"{USER_ID}_copy" + + @pytest.mark.asyncio + async def test_get_user_token_success(self, sample_flow_state, user_token_client): + # setup + flow = OAuthFlow(sample_flow_state, user_token_client) + expected_final_flow_state = sample_flow_state + expected_final_flow_state.user_token = RES_TOKEN + + # test + token_response = await flow.get_user_token() + token = token_response.token + + # verify + assert token == RES_TOKEN + assert flow.flow_state == expected_final_flow_state + user_token_client.user_token.get_token.assert_called_once_with( + user_id=USER_ID, + connection_name=ABS_OAUTH_CONNECTION_NAME, + channel_id=CHANNEL_ID, + magic_code=None + ) + + @pytest.mark.asyncio + async def test_get_user_token_failure(self, mocker, sample_flow_state): + # setup + user_token_client = self.create_user_token_client(mocker, get_token_return=None) + flow = OAuthFlow(sample_flow_state, user_token_client) + expected_final_flow_state = flow.flow_state # robrandao: TODO -> what happens if fails and has user_token? + + # test + token_response = await flow.get_user_token() + + # verify + assert token_response == TokenResponse() + assert flow.flow_state == expected_final_flow_state + user_token_client.user_token.get_token.assert_called_once_with( + user_id=USER_ID, + connection_name=ABS_OAUTH_CONNECTION_NAME, + channel_id=CHANNEL_ID, + magic_code=None + ) + + @pytest.mark.asyncio + async def test_sign_out(self, sample_flow_state, user_token_client): + # setup + flow = OAuthFlow(sample_flow_state, user_token_client) + expected_flow_state = sample_flow_state + expected_flow_state.user_token = "" + expected_flow_state.tag = FlowStateTag.NOT_STARTED + + # test + await flow.sign_out() + + # verify + user_token_client.user_token.sign_out.assert_called_once_with( + user_id=USER_ID, + connection_name=ABS_OAUTH_CONNECTION_NAME, + channel_id=CHANNEL_ID + ) + assert flow.flow_state == expected_flow_state + + @pytest.mark.asyncio + async def test_begin_flow_easy_case(self, mocker, sample_flow_state, user_token_client): + # setup + flow = OAuthFlow(sample_flow_state, user_token_client) + activity = mocker.Mock(spec=Activity) + expected_flow_state = sample_flow_state + expected_flow_state.user_token = RES_TOKEN + + # test + response = await flow.begin_flow(activity) + + # verify + flow_state = flow.flow_state + assert flow_state == expected_flow_state + + assert response.flow_state == flow_state + assert response.sign_in_resource is None # No sign-in resource in this case + assert response.flow_error_tag == FlowErrorTag.NONE + assert response.token_response.token == RES_TOKEN + user_token_client.user_token.get_token.assert_called_once_with( + user_id=USER_ID, + connection_name=ABS_OAUTH_CONNECTION_NAME, + channel_id=CHANNEL_ID, + # magic_code=None is an implementation detail, and ideally + # shouldn't be part of the test + magic_code=None + ) + + @pytest.mark.asyncio + async def test_begin_flow_long_case(self, mocker, sample_flow_state, user_token_client): + # mock + # tes = mocker.Mock(TokenExchangeState) + # tes.get_encoded_state = mocker.Mock(return_value="encoded_state") + mocker.patch.object(TokenExchangeState, "get_encoded_state", return_value="encoded_state") + dummy_sign_in_resource = SignInResource( + sign_in_link="https://example.com/signin", + token_exchange_state=mocker.Mock( + TokenExchangeState, get_encoded_state=mocker.Mock(return_value="encoded_state") + ) + ) + user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) + user_token_client.agent_sign_in.get_sign_in_resource = mocker.AsyncMock( + return_value=dummy_sign_in_resource) + activity = self.create_activity(mocker) + + # setup + flow = OAuthFlow(sample_flow_state, user_token_client) + expected_flow_state = sample_flow_state + expected_flow_state.user_token = "" + expected_flow_state.tag = FlowStateTag.BEGIN + expected_flow_state.attempts_remaining = 3 + + # test + response = await flow.begin_flow(activity) + + # verify flow_state + flow_state = flow.flow_state + expected_flow_state.expires_at = flow_state.expires_at # robrandao: TODO -> ignore this for now + assert flow_state == response.flow_state + assert flow_state == expected_flow_state + + # verify FlowResponse + assert response.sign_in_resource == dummy_sign_in_resource + assert response.flow_error_tag == FlowErrorTag.NONE + assert not response.token_response + # robrandao: TODO more assertions on sign_in_resource + + @pytest.mark.asyncio + async def test_continue_flow_not_active(self, mocker, sample_inactive_flow_state, user_token_client): + # setup + activity = mocker.Mock() + flow = OAuthFlow(sample_inactive_flow_state, user_token_client) + expected_flow_state = sample_inactive_flow_state + expected_flow_state.tag = FlowStateTag.FAILURE + + # test + flow_response = await flow.continue_flow(activity) + flow_state = flow.flow_state + + # verify + assert flow_state == expected_flow_state + assert flow_response.flow_state == flow_state + assert not flow_response.token_response + + async def helper_continue_flow_failure(self, active_flow_state, user_token_client, activity, flow_error_tag): + # setup + flow = OAuthFlow(active_flow_state, user_token_client) + expected_flow_state = active_flow_state + expected_flow_state.tag = FlowStateTag.CONTINUE if active_flow_state.attempts_remaining > 1 else FlowStateTag.FAILURE + expected_flow_state.attempts_remaining = active_flow_state.attempts_remaining - 1 + + # test + flow_response = await flow.continue_flow(activity) + flow_state = flow.flow_state + + # verify + assert flow_response.flow_state == flow_state + assert expected_flow_state == flow_state + assert flow_response.token_response == TokenResponse() + assert flow_response.flow_error_tag == flow_error_tag + + async def helper_continue_flow_success(self, active_flow_state, user_token_client, activity): + # setup + flow = OAuthFlow(active_flow_state, user_token_client) + expected_flow_state = active_flow_state + expected_flow_state.tag = FlowStateTag.COMPLETE + expected_flow_state.user_token = RES_TOKEN + expected_flow_state.attempts_remaining = active_flow_state.attempts_remaining + + # test + flow_response = await flow.continue_flow(activity) + flow_state = flow.flow_state + expected_flow_state.expires_at = flow_state.expires_at # robrandao: TODO -> ignore this for now + + # verify + assert flow_response.flow_state == flow_state + assert expected_flow_state == flow_state + assert flow_response.token_response == TokenResponse(token=RES_TOKEN) + assert flow_response.flow_error_tag == FlowErrorTag.NONE + + @pytest.mark.asyncio + @pytest.mark.parametrize("magic_code", ["magic", "123", "", "1239453"]) + async def test_continue_flow_active_message_magic_format_error(self, mocker, sample_active_flow_state, user_token_client, magic_code): + # setup + activity = self.create_activity(mocker, ActivityTypes.message, text=magic_code) + await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.MAGIC_FORMAT) + user_token_client.assert_not_called() + + @pytest.mark.asyncio + async def test_continue_flow_active_message_magic_code_error(self, mocker, sample_active_flow_state, user_token_client): + # setup + user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) + activity = self.create_activity(mocker, ActivityTypes.message, text="123456") + await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.MAGIC_CODE_INCORRECT) + user_token_client.user_token.get_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.connection, + channel_id=sample_active_flow_state.channel_id, + magic_code="123456" + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_message_success(self, mocker, sample_active_flow_state, user_token_client): + # setup + activity = self.create_activity(mocker, ActivityTypes.message, text="123456") + await self.helper_continue_flow_success(sample_active_flow_state, user_token_client, activity) + user_token_client.user_token.get_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.connection, + channel_id=sample_active_flow_state.channel_id, + magic_code="123456" + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_sign_in_verify_state_error(self, mocker, sample_active_flow_state, user_token_client): + # setup + user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) + activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/verifyState", value={ + "state": "magic_code" + }) + await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.OTHER) + user_token_client.user_token.get_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.connection, + channel_id=sample_active_flow_state.channel_id, + magic_code="magic_code" + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_sign_in_verify_success(self, mocker, sample_active_flow_state, user_token_client): + activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/verifyState", value={ + "state": "magic_code" + }) + await self.helper_continue_flow_success(sample_active_flow_state, user_token_client, activity) + user_token_client.user_token.get_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.connection, + channel_id=sample_active_flow_state.channel_id, + magic_code="magic_code" + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_sign_in_token_exchange_error(self, mocker, sample_active_flow_state, user_token_client): + token_exchange_request = {} + user_token_client.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse()) + activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/tokenExchange", value=token_exchange_request) + await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.OTHER) + user_token_client.user_token.exchange_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.connection, + channel_id=sample_active_flow_state.channel_id, + body=token_exchange_request + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_sign_in_token_exchange_success(self, mocker, sample_active_flow_state, user_token_client): + token_exchange_request = {} + user_token_client.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse(token=RES_TOKEN)) + activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/tokenExchange", value=token_exchange_request) + await self.helper_continue_flow_success(sample_active_flow_state, user_token_client, activity) + user_token_client.user_token.exchange_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.connection, + channel_id=sample_active_flow_state.channel_id, + body=token_exchange_request + ) + + @pytest.mark.asyncio + async def test_continue_flow_invalid_invoke_name(self, mocker, sample_active_flow_state, user_token_client): + with pytest.raises(ValueError): + activity = self.create_activity(mocker, ActivityTypes.invoke, name="other", value={}) + flow = OAuthFlow(sample_active_flow_state, user_token_client) + await flow.continue_flow(activity) + + @pytest.mark.asyncio + async def test_continue_flow_invalid_activity_type(self, mocker, sample_active_flow_state, user_token_client): + with pytest.raises(ValueError): + activity = self.create_activity(mocker, ActivityTypes.command, name="other", value={}) + flow = OAuthFlow(sample_active_flow_state, user_token_client) + await flow.continue_flow(activity) + + # robrandao: TODO -> test begin_or_continue_flow -> low priority for now \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py b/libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py index c4fb6c27..0f9ef960 100644 --- a/libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py +++ b/libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py @@ -1,89 +1,89 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import asyncio -import uuid -from datetime import datetime, timezone -from typing import Callable, List, Optional, Awaitable -from collections import deque - -from microsoft.agents.hosting.core.authorization import ClaimsIdentity -from microsoft.agents.activity import ( - Activity, - ActivityTypes, - ChannelAccount, - ConversationAccount, - ConversationReference, - Channels, - ResourceResponse, - RoleTypes, - InvokeResponse, -) -from microsoft.agents.hosting.core.channel_adapter import ChannelAdapter -from microsoft.agents.hosting.core.turn_context import TurnContext -from microsoft.agents.hosting.core.connector import UserTokenClient - -AgentCallbackHandler = Callable[[TurnContext], Awaitable] - - -# patch userTokenclient constructor -class MockUserTokenClient(UserTokenClient): - """A mock user token client for testing.""" - - def __init__(self, ...): - self._store = {} - self._exchange_store = {} - self._throw_on_exchange = {} - self._user_token = mocker.Mock() - self._agent_sign_in = mocker.Mock() - - def add_user_token( - self, - connection_name: str, - channel_id: str, - user_id: str, - token: str, - magic_code: str = None, - ): - """Add a token for a user that can be retrieved during testing.""" - key = self._get_key(connection_name, channel_id, user_id) - self._store[key] = (token, magic_code) - - def add_exchangeable_token( - self, - connection_name: str, - channel_id: str, - user_id: str, - exchangeable_item: str, - token: str, - ): - """Add an exchangeable token for a user that can be exchanged during testing.""" - key = self._get_exchange_key( - connection_name, channel_id, user_id, exchangeable_item - ) - self._exchange_store[key] = token - - def throw_on_exchange_request( - self, - connection_name: str, - channel_id: str, - user_id: str, - exchangeable_item: str, - ): - """Add an instruction to throw an exception during exchange requests.""" - key = self._get_exchange_key( - connection_name, channel_id, user_id, exchangeable_item - ) - self._throw_on_exchange[key] = True - - def _get_key(self, connection_name: str, channel_id: str, user_id: str) -> str: - return f"{connection_name}:{channel_id}:{user_id}" - - def _get_exchange_key( - self, - connection_name: str, - channel_id: str, - user_id: str, - exchangeable_item: str, - ) -> str: - return f"{connection_name}:{channel_id}:{user_id}:{exchangeable_item}" +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import asyncio +import uuid +from datetime import datetime, timezone +from typing import Callable, List, Optional, Awaitable +from collections import deque + +from microsoft.agents.hosting.core.authorization import ClaimsIdentity +from microsoft.agents.activity import ( + Activity, + ActivityTypes, + ChannelAccount, + ConversationAccount, + ConversationReference, + Channels, + ResourceResponse, + RoleTypes, + InvokeResponse, +) +from microsoft.agents.hosting.core.channel_adapter import ChannelAdapter +from microsoft.agents.hosting.core.turn_context import TurnContext +from microsoft.agents.hosting.core.connector import UserTokenClient + +AgentCallbackHandler = Callable[[TurnContext], Awaitable] + + +# patch userTokenclient constructor +class MockUserTokenClient(UserTokenClient): + """A mock user token client for testing.""" + + def __init__(self, ...): + self._store = {} + self._exchange_store = {} + self._throw_on_exchange = {} + self._user_token = mocker.Mock() + self._agent_sign_in = mocker.Mock() + + def add_user_token( + self, + connection_name: str, + channel_id: str, + user_id: str, + token: str, + magic_code: str = None, + ): + """Add a token for a user that can be retrieved during testing.""" + key = self._get_key(connection_name, channel_id, user_id) + self._store[key] = (token, magic_code) + + def add_exchangeable_token( + self, + connection_name: str, + channel_id: str, + user_id: str, + exchangeable_item: str, + token: str, + ): + """Add an exchangeable token for a user that can be exchanged during testing.""" + key = self._get_exchange_key( + connection_name, channel_id, user_id, exchangeable_item + ) + self._exchange_store[key] = token + + def throw_on_exchange_request( + self, + connection_name: str, + channel_id: str, + user_id: str, + exchangeable_item: str, + ): + """Add an instruction to throw an exception during exchange requests.""" + key = self._get_exchange_key( + connection_name, channel_id, user_id, exchangeable_item + ) + self._throw_on_exchange[key] = True + + def _get_key(self, connection_name: str, channel_id: str, user_id: str) -> str: + return f"{connection_name}:{channel_id}:{user_id}" + + def _get_exchange_key( + self, + connection_name: str, + channel_id: str, + user_id: str, + exchangeable_item: str, + ) -> str: + return f"{connection_name}:{channel_id}:{user_id}:{exchangeable_item}" diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py b/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py index 69f04ab3..ac184d46 100644 --- a/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py +++ b/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py @@ -1,248 +1,248 @@ -""" -Testing utilities for authorization functionality - -This module provides mock implementations and helper classes for testing authorization, -authentication, and token management scenarios. It includes test doubles for token -providers, connection managers, and authorization handlers that can be configured -to simulate various authentication states and flow conditions. -""" - -from microsoft.agents.hosting.core import ( - Connections, - AccessTokenProviderBase, - AuthHandler, - Authorization, - MemoryStorage, - oauth_flow, -) -from typing import Dict, Union -from microsoft.agents.hosting.core.authorization.agent_auth_configuration import ( - AgentAuthConfiguration, -) -from microsoft.agents.hosting.core.authorization.claims_identity import ClaimsIdentity - -from microsoft.agents.activity import TokenResponse - -from unittest.mock import Mock, AsyncMock - - -def create_test_auth_handler( - name: str, obo: bool = False, title: str = None, text: str = None -): - """ - Creates a test AuthHandler instance with standardized connection names. - - This helper function simplifies the creation of AuthHandler objects for testing - by automatically generating connection names based on the provided name and - optionally including On-Behalf-Of (OBO) connection configuration. - - Args: - name: Base name for the auth handler, used to generate connection names - obo: Whether to include On-Behalf-Of connection configuration - title: Optional title for the auth handler - text: Optional descriptive text for the auth handler - - Returns: - AuthHandler: Configured auth handler instance with test-friendly connection names - """ - return AuthHandler( - name, - abs_oauth_connection_name=f"{name}-abs-connection", - obo_connection_name=f"{name}-obo-connection" if obo else None, - title=title, - text=text, - ) - - -class TestingTokenProvider(AccessTokenProviderBase): - """ - Access token provider for unit tests. - - This test double simulates an access token provider that returns predictable - token values based on the provider name. It implements both standard token - acquisition and On-Behalf-Of (OBO) token flows for comprehensive testing - of authentication scenarios. - """ - - def __init__(self, name: str): - """ - Initialize the testing token provider. - - Args: - name: Identifier used to generate predictable token values - """ - self.name = name - - async def get_access_token( - self, resource_url: str, scopes: list[str], force_refresh: bool = False - ) -> str: - """ - Get an access token for the specified resource and scopes. - - Returns a predictable token string based on the provider name for testing. - - Args: (unused in test implementation) - resource_url: URL of the resource requiring authentication - scopes: List of OAuth scopes requested - force_refresh: Whether to force token refresh - - Returns: - str: Test token in format "{name}-token" - """ - return f"{self.name}-token" - - async def aquire_token_on_behalf_of( - self, scopes: list[str], user_assertion: str - ) -> str: - """ - Acquire a token on behalf of another user (OBO flow). - - Returns a predictable OBO token string for testing scenarios involving - delegated permissions and token exchange. - - Args: (unused in test implementation) - scopes: List of OAuth scopes requested for the OBO token - user_assertion: JWT token representing the user's identity - - Returns: - str: Test OBO token in format "{name}-obo-token" - """ - return f"{self.name}-obo-token" - - -class TestingConnectionManager(Connections): - """ - Connection manager for unit tests. - - This test double provides a simplified connection management interface that - returns TestingTokenProvider instances for all connection requests. It enables - testing of authorization flows without requiring actual OAuth configurations - or external authentication services. - """ - - def get_connection(self, connection_name: str) -> AccessTokenProviderBase: - """ - Get a token provider for the specified connection name. - - Args: - connection_name: Name of the OAuth connection - - Returns: - AccessTokenProviderBase: TestingTokenProvider configured with the connection name - """ - return TestingTokenProvider(connection_name) - - def get_default_connection(self) -> AccessTokenProviderBase: - """ - Get the default token provider. - - Returns: - AccessTokenProviderBase: TestingTokenProvider configured with "default" name - """ - return TestingTokenProvider("default") - - def get_token_provider( - self, claims_identity: ClaimsIdentity, service_url: str - ) -> AccessTokenProviderBase: - """ - Get a token provider based on claims identity and service URL. - - In this test implementation, returns the default connection regardless - of the provided parameters. - - Args: (unused in test implementation) - claims_identity: User's claims and identity information - service_url: URL of the service requiring authentication - - Returns: - AccessTokenProviderBase: The default TestingTokenProvider - """ - return self.get_default_connection() - - def get_default_connection_configuration(self) -> AgentAuthConfiguration: - """ - Get the default authentication configuration. - - Returns: - AgentAuthConfiguration: Empty configuration suitable for testing - """ - return AgentAuthConfiguration() - - -class TestingAuthorization(Authorization): - """ - Authorization system for comprehensive unit testing. - - This test double extends the Authorization class to provide a fully mocked - authorization environment suitable for testing various authentication scenarios. - It automatically configures auth handlers with mock OAuth flows that can simulate - different states like successful authentication, failed sign-in, or in-progress flows. - """ - - def __init__( - self, - auth_handlers: Dict[str, AuthHandler], - token: Union[str, None] = "default", - flow_started=False, - sign_in_failed=False, - ): - """ - Initialize the testing authorization system. - - Sets up a complete test authorization environment with memory storage, - test connection manager, and configures all provided auth handlers with - mock OAuth flows. - - Args: - auth_handlers: Dictionary mapping handler names to AuthHandler instances - token: Token value to use in mock responses. "default" uses auto-generated - tokens, None simulates no token available, or provide custom jwt token string - flow_started: Simulate OAuth flows that have already started - sign_in_failed: Simulate failed sign-in attempts - """ - # Initialize with test-friendly components - storage = MemoryStorage() - connection_manager = TestingConnectionManager() - super().__init__( - storage=storage, - auth_handlers=auth_handlers, - connection_manager=connection_manager, - service_url="a" - ) - - # Configure each auth handler with mock OAuth flow behavior - for auth_handler in self._auth_handlers.values(): - # Create default token response for this auth handler - default_token = TokenResponse( - connection_name=auth_handler.abs_oauth_connection_name, - token=f"{auth_handler.abs_oauth_connection_name}-token", - ) - - # Determine token response based on configuration - if token == "default": - token_response = default_token - elif token: - token_response = TokenResponse( - connection_name=auth_handler.abs_oauth_connection_name, - token=token, - ) - else: - token_response = None - - # Mock the OAuth flow with configurable behavior - auth_handler.flow = Mock( - get_user_token=AsyncMock(return_value=token_response), - _get_flow_state=AsyncMock( - # sign-in failed requires flow to be started - return_value=oauth_flow.FlowState( - flow_started=(flow_started or sign_in_failed) - ) - ), - begin_flow=AsyncMock(return_value=default_token), - # Mock flow continuation with optional failure simulation - continue_flow=AsyncMock( - return_value=None if sign_in_failed else default_token - ), - ) - - auth_handler.flow.flow_state = None +""" +Testing utilities for authorization functionality + +This module provides mock implementations and helper classes for testing authorization, +authentication, and token management scenarios. It includes test doubles for token +providers, connection managers, and authorization handlers that can be configured +to simulate various authentication states and flow conditions. +""" + +from microsoft.agents.hosting.core import ( + Connections, + AccessTokenProviderBase, + AuthHandler, + Authorization, + MemoryStorage, + OAuthFlow, +) +from typing import Dict, Union +from microsoft.agents.hosting.core.authorization.agent_auth_configuration import ( + AgentAuthConfiguration, +) +from microsoft.agents.hosting.core.authorization.claims_identity import ClaimsIdentity + +from microsoft.agents.activity import TokenResponse + +from unittest.mock import Mock, AsyncMock + + +def create_test_auth_handler( + name: str, obo: bool = False, title: str = None, text: str = None +): + """ + Creates a test AuthHandler instance with standardized connection names. + + This helper function simplifies the creation of AuthHandler objects for testing + by automatically generating connection names based on the provided name and + optionally including On-Behalf-Of (OBO) connection configuration. + + Args: + name: Base name for the auth handler, used to generate connection names + obo: Whether to include On-Behalf-Of connection configuration + title: Optional title for the auth handler + text: Optional descriptive text for the auth handler + + Returns: + AuthHandler: Configured auth handler instance with test-friendly connection names + """ + return AuthHandler( + name, + abs_oauth_connection_name=f"{name}-abs-connection", + obo_connection_name=f"{name}-obo-connection" if obo else None, + title=title, + text=text, + ) + + +class TestingTokenProvider(AccessTokenProviderBase): + """ + Access token provider for unit tests. + + This test double simulates an access token provider that returns predictable + token values based on the provider name. It implements both standard token + acquisition and On-Behalf-Of (OBO) token flows for comprehensive testing + of authentication scenarios. + """ + + def __init__(self, name: str): + """ + Initialize the testing token provider. + + Args: + name: Identifier used to generate predictable token values + """ + self.name = name + + async def get_access_token( + self, resource_url: str, scopes: list[str], force_refresh: bool = False + ) -> str: + """ + Get an access token for the specified resource and scopes. + + Returns a predictable token string based on the provider name for testing. + + Args: (unused in test implementation) + resource_url: URL of the resource requiring authentication + scopes: List of OAuth scopes requested + force_refresh: Whether to force token refresh + + Returns: + str: Test token in format "{name}-token" + """ + return f"{self.name}-token" + + async def aquire_token_on_behalf_of( + self, scopes: list[str], user_assertion: str + ) -> str: + """ + Acquire a token on behalf of another user (OBO flow). + + Returns a predictable OBO token string for testing scenarios involving + delegated permissions and token exchange. + + Args: (unused in test implementation) + scopes: List of OAuth scopes requested for the OBO token + user_assertion: JWT token representing the user's identity + + Returns: + str: Test OBO token in format "{name}-obo-token" + """ + return f"{self.name}-obo-token" + + +class TestingConnectionManager(Connections): + """ + Connection manager for unit tests. + + This test double provides a simplified connection management interface that + returns TestingTokenProvider instances for all connection requests. It enables + testing of authorization flows without requiring actual OAuth configurations + or external authentication services. + """ + + def get_connection(self, connection_name: str) -> AccessTokenProviderBase: + """ + Get a token provider for the specified connection name. + + Args: + connection_name: Name of the OAuth connection + + Returns: + AccessTokenProviderBase: TestingTokenProvider configured with the connection name + """ + return TestingTokenProvider(connection_name) + + def get_default_connection(self) -> AccessTokenProviderBase: + """ + Get the default token provider. + + Returns: + AccessTokenProviderBase: TestingTokenProvider configured with "default" name + """ + return TestingTokenProvider("default") + + def get_token_provider( + self, claims_identity: ClaimsIdentity, service_url: str + ) -> AccessTokenProviderBase: + """ + Get a token provider based on claims identity and service URL. + + In this test implementation, returns the default connection regardless + of the provided parameters. + + Args: (unused in test implementation) + claims_identity: User's claims and identity information + service_url: URL of the service requiring authentication + + Returns: + AccessTokenProviderBase: The default TestingTokenProvider + """ + return self.get_default_connection() + + def get_default_connection_configuration(self) -> AgentAuthConfiguration: + """ + Get the default authentication configuration. + + Returns: + AgentAuthConfiguration: Empty configuration suitable for testing + """ + return AgentAuthConfiguration() + + +class TestingAuthorization(Authorization): + """ + Authorization system for comprehensive unit testing. + + This test double extends the Authorization class to provide a fully mocked + authorization environment suitable for testing various authentication scenarios. + It automatically configures auth handlers with mock OAuth flows that can simulate + different states like successful authentication, failed sign-in, or in-progress flows. + """ + + def __init__( + self, + auth_handlers: Dict[str, AuthHandler], + token: Union[str, None] = "default", + flow_started=False, + sign_in_failed=False, + ): + """ + Initialize the testing authorization system. + + Sets up a complete test authorization environment with memory storage, + test connection manager, and configures all provided auth handlers with + mock OAuth flows. + + Args: + auth_handlers: Dictionary mapping handler names to AuthHandler instances + token: Token value to use in mock responses. "default" uses auto-generated + tokens, None simulates no token available, or provide custom jwt token string + flow_started: Simulate OAuth flows that have already started + sign_in_failed: Simulate failed sign-in attempts + """ + # Initialize with test-friendly components + storage = MemoryStorage() + connection_manager = TestingConnectionManager() + super().__init__( + storage=storage, + auth_handlers=auth_handlers, + connection_manager=connection_manager, + service_url="a" + ) + + # Configure each auth handler with mock OAuth flow behavior + for auth_handler in self._auth_handlers.values(): + # Create default token response for this auth handler + default_token = TokenResponse( + connection_name=auth_handler.abs_oauth_connection_name, + token=f"{auth_handler.abs_oauth_connection_name}-token", + ) + + # Determine token response based on configuration + if token == "default": + token_response = default_token + elif token: + token_response = TokenResponse( + connection_name=auth_handler.abs_oauth_connection_name, + token=token, + ) + else: + token_response = None + + # Mock the OAuth flow with configurable behavior + auth_handler.flow = Mock( + get_user_token=AsyncMock(return_value=token_response), + _get_flow_state=AsyncMock( + # sign-in failed requires flow to be started + return_value=oauth_flow.FlowState( + flow_started=(flow_started or sign_in_failed) + ) + ), + begin_flow=AsyncMock(return_value=default_token), + # Mock flow continuation with optional failure simulation + continue_flow=AsyncMock( + return_value=None if sign_in_failed else default_token + ), + ) + + auth_handler.flow.flow_state = None diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py b/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py index bf1aad03..595920ab 100644 --- a/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py +++ b/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py @@ -1,143 +1,149 @@ -from datetime import datetime - -from microsoft.agents.hosting.core.oauth.flow_state import FlowState, FlowStateTag - -MS_APP_ID = "__ms_app_id" -CHANNEL_ID = "__channel_id" -USER_ID = "__user_id" -ABS_OAUTH_CONNECTION_NAME = "__connection_name" -RES_TOKEN = "__res_token" - -DEF_ARGS = { - "ms_app_id": MS_APP_ID, - "channel_id": CHANNEL_ID, - "user_id": USER_ID, - "abs_oauth_connection_name": ABS_OAUTH_CONNECTION_NAME -} - -class FLOW_STATES: - - STARTED_FLOW = FlowState( - **DEF_ARGS, - tag=FlowStateTag.BEGIN, - attempts_remaining=1, - user_token="____", - expires_at=datetime.now().timestamp() + 1000000 - ) - STARTED_FLOW_ONE_RETRY = FlowState( - **DEF_ARGS, - tag=FlowStateTag.BEGIN, - attempts_remaining=2, - user_token="____", - expires_at=datetime.now().timestamp() + 1000000 - ) - ACTIVE_FLOW = FlowState( - **DEF_ARGS, - tag=FlowStateTag.CONTINUE, - attempts_remaining=2, - user_token="__token", - expires_at=datetime.now().timestamp() + 1000000 - ) - ACTIVE_FLOW_ONE_RETRY = FlowState( - **DEF_ARGS, - tag=FlowStateTag.CONTINUE, - attempts_remaining=1, - user_token="__token", - expires_at=datetime.now().timestamp() + 1000000 - ) - ACTIVE_EXP_FLOW = FlowState( - **DEF_ARGS, - tag=FlowStateTag.CONTINUE, - attempts_remaining=2, - user_token="__token", - expires_at=datetime.now().timestamp() - ) - COMPLETED_FLOW = FlowState( - **DEF_ARGS, - tag=FlowStateTag.COMPLETE, - attempts_remaining=2, - user_token="test_token", - expires_at=datetime.now().timestamp() + 1000000 - ) - FAIL_BY_ATTEMPTS_FLOW = FlowState( - **DEF_ARGS, - tag=FlowStateTag.FAILURE, - attempts_remaining=0, - expires_at=datetime.now().timestamp() + 1000000 - ) - - FAIL_BY_EXP_FLOW = FlowState( - **DEF_ARGS, - tag=FlowStateTag.FAILURE, - attempts_remaining=2, - expires_at=0 - ) - - @staticmethod - def clone_state_list(lst): - return [ flow_state.model_copy() for flow_state in lst ] - - @staticmethod - def ALL(): - return FLOW_STATES.clone_state_list([ - FLOW_STATES.STARTED_FLOW, - FLOW_STATES.STARTED_FLOW_ONE_RETRY, - FLOW_STATES.ACTIVE_FLOW, - FLOW_STATES.ACTIVE_FLOW_ONE_RETRY, - FLOW_STATES.ACTIVE_EXP_FLOW, - FLOW_STATES.COMPLETED_FLOW, - FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, - FLOW_STATES.FAIL_BY_EXP_FLOW - ]) - - @staticmethod - def FAILED(): - return FLOW_STATES.clone_state_list([ - FLOW_STATES.ACTIVE_EXP_FLOW, - FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, - FLOW_STATES.FAIL_BY_EXP_FLOW - ]) - - @staticmethod - def ACTIVE(): - return FLOW_STATES.clone_state_list([ - FLOW_STATES.STARTED_FLOW, - FLOW_STATES.STARTED_FLOW_ONE_RETRY, - FLOW_STATES.ACTIVE_FLOW, - FLOW_STATES.ACTIVE_FLOW_ONE_RETRY, - ]) - - @staticmethod - def INACTIVE(): - return FLOW_STATES.clone_state_list([ - FLOW_STATES.ACTIVE_EXP_FLOW, - FLOW_STATES.COMPLETED_FLOW, - FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, - FLOW_STATES.FAIL_BY_EXP_FLOW - ]) - -def flow_key(channel_id, user_id, handler_id): - return f"auth/{channel_id}/{user_id}/{handler_id}" - -STORAGE_SAMPLE_DICT = { - "user_id": "123", - "session_id": "abc", - flow_key("webchat", "Alice", "graph"): FLOW_STATES.COMPLETED_FLOW.model_copy(), - flow_key("webchat", "Alice", "github"): FLOW_STATES.ACTIVE_FLOW_ONE_RETRY.model_copy(), - flow_key("teams", "Alice", "graph"): FLOW_STATES.STARTED_FLOW.model_copy(), - flow_key("webchat", "Bob", "graph"): FLOW_STATES.ACTIVE_EXP_FLOW.model_copy(), - flow_key("teams", "Bob", "slack"): FLOW_STATES.STARTED_FLOW.model_copy(), - flow_key("webchat", "Chuck", "github"): FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW.model_copy(), -} - -def STORAGE_INIT_DATA(): - data = STORAGE_SAMPLE_DICT.copy() - for key, value in data.items(): - data[key] = value.model_copy() if isinstance(value, FlowState) else value - return data - -def update_data_with_flow_state(data, channel_id, user_id, auth_handler_id, flow_state): - data = data.copy() - key = f"auth/{channel_id}/{user_id}/{auth_handler_id}" - data[key] = flow_state.model_copy() +from datetime import datetime + +from microsoft.agents.hosting.core.storage.storage_test_utils import MockStoreItem +from microsoft.agents.hosting.core.oauth.flow_state import FlowState, FlowStateTag + +MS_APP_ID = "__ms_app_id" +CHANNEL_ID = "__channel_id" +USER_ID = "__user_id" +ABS_OAUTH_CONNECTION_NAME = "__connection_name" +RES_TOKEN = "__res_token" + +DEF_ARGS = { + "ms_app_id": MS_APP_ID, + "channel_id": CHANNEL_ID, + "user_id": USER_ID, + "connection": ABS_OAUTH_CONNECTION_NAME +} + +class FLOW_STATES: + + STARTED_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.BEGIN, + attempts_remaining=1, + user_token="____", + expires_at=datetime.now().timestamp() + 1000000 + ) + STARTED_FLOW_ONE_RETRY = FlowState( + **DEF_ARGS, + tag=FlowStateTag.BEGIN, + attempts_remaining=2, + user_token="____", + expires_at=datetime.now().timestamp() + 1000000 + ) + ACTIVE_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.CONTINUE, + attempts_remaining=2, + user_token="__token", + expires_at=datetime.now().timestamp() + 1000000 + ) + ACTIVE_FLOW_ONE_RETRY = FlowState( + **DEF_ARGS, + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + user_token="__token", + expires_at=datetime.now().timestamp() + 1000000 + ) + ACTIVE_EXP_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.CONTINUE, + attempts_remaining=2, + user_token="__token", + expires_at=datetime.now().timestamp() + ) + COMPLETED_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.COMPLETE, + attempts_remaining=2, + user_token="test_token", + expires_at=datetime.now().timestamp() + 1000000 + ) + FAIL_BY_ATTEMPTS_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.FAILURE, + attempts_remaining=0, + expires_at=datetime.now().timestamp() + 1000000 + ) + + FAIL_BY_EXP_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.FAILURE, + attempts_remaining=2, + expires_at=0 + ) + + @staticmethod + def clone_state_list(lst): + return [ flow_state.model_copy() for flow_state in lst ] + + @staticmethod + def ALL(): + return FLOW_STATES.clone_state_list([ + FLOW_STATES.STARTED_FLOW, + FLOW_STATES.STARTED_FLOW_ONE_RETRY, + FLOW_STATES.ACTIVE_FLOW, + FLOW_STATES.ACTIVE_FLOW_ONE_RETRY, + FLOW_STATES.ACTIVE_EXP_FLOW, + FLOW_STATES.COMPLETED_FLOW, + FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, + FLOW_STATES.FAIL_BY_EXP_FLOW + ]) + + @staticmethod + def FAILED(): + return FLOW_STATES.clone_state_list([ + FLOW_STATES.ACTIVE_EXP_FLOW, + FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, + FLOW_STATES.FAIL_BY_EXP_FLOW + ]) + + @staticmethod + def ACTIVE(): + return FLOW_STATES.clone_state_list([ + FLOW_STATES.STARTED_FLOW, + FLOW_STATES.STARTED_FLOW_ONE_RETRY, + FLOW_STATES.ACTIVE_FLOW, + FLOW_STATES.ACTIVE_FLOW_ONE_RETRY, + ]) + + @staticmethod + def INACTIVE(): + return FLOW_STATES.clone_state_list([ + FLOW_STATES.ACTIVE_EXP_FLOW, + FLOW_STATES.COMPLETED_FLOW, + FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, + FLOW_STATES.FAIL_BY_EXP_FLOW + ]) + +def flow_key(channel_id, user_id, handler_id): + return f"auth/{channel_id}/{user_id}/{handler_id}" + +def update_flow_state_handler(flow_state, handler): + flow_state = flow_state.model_copy() + flow_state.auth_handler_id = handler + return flow_state + +STORAGE_SAMPLE_DICT = { + "user_id": MockStoreItem({"id": "123"}), + "session_id": MockStoreItem({"id": "abc"}), + flow_key("webchat", "Alice", "graph"): update_flow_state_handler(FLOW_STATES.COMPLETED_FLOW.model_copy(), "graph"), + flow_key("webchat", "Alice", "github"): update_flow_state_handler(FLOW_STATES.ACTIVE_FLOW_ONE_RETRY.model_copy(), "github"), + flow_key("teams", "Alice", "graph"): update_flow_state_handler(FLOW_STATES.STARTED_FLOW.model_copy(), "graph"), + flow_key("webchat", "Bob", "graph"): update_flow_state_handler(FLOW_STATES.ACTIVE_EXP_FLOW.model_copy(), "graph"), + flow_key("teams", "Bob", "slack"): update_flow_state_handler(FLOW_STATES.STARTED_FLOW.model_copy(), "slack"), + flow_key("webchat", "Chuck", "github"): update_flow_state_handler(FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW.model_copy(), "github"), +} + +def STORAGE_INIT_DATA(): + data = STORAGE_SAMPLE_DICT.copy() + for key, value in data.items(): + data[key] = value.model_copy() if isinstance(value, FlowState) else value + return data + +def update_data_with_flow_state(data, channel_id, user_id, auth_handler_id, flow_state): + data = data.copy() + key = f"auth/{channel_id}/{user_id}/{auth_handler_id}" + data[key] = flow_state.model_copy() return data \ No newline at end of file From e3a59e92af0317ac2ba5929646fb3750f650b711 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Thu, 21 Aug 2025 13:25:52 -0700 Subject: [PATCH 16/32] Cleaning code and removing unused args --- .../agents/hosting/core/app/agent_application.py | 16 ++++++---------- .../hosting/core/app/oauth/authorization.py | 4 +--- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py index 39bed333..0f75bc42 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py @@ -29,10 +29,12 @@ from microsoft.agents.activity import ( Activity, ActivityTypes, + ActionTypes, ConversationUpdateTypes, MessageReactionTypes, MessageUpdateTypes, InvokeResponse, + TokenResponse, OAuthCard, Attachment, CardAction @@ -602,15 +604,11 @@ def turn_state_factory(self, func: Callable[[TurnContext], Awaitable[StateT]]): return func async def _handle_flow_response(self, context: TurnContext, flow_response: FlowResponse) -> None: - flow_state: FlowState = flow_response.flow_state - in_flow_activity = flow_response.in_flow_activity - - if in_flow_activity: - await context.send_activity(in_flow_activity) if flow_state.tag == FlowStateTag.BEGIN: # Create the OAuth card + sign_in_resource = flow_response.sign_in_resource o_card: Attachment = CardFactory.oauth_card( OAuthCard( text=self.messages_configuration.get("card_title", "Sign in"), @@ -623,11 +621,10 @@ async def _handle_flow_response(self, context: TurnContext, flow_response: FlowR channel_data=None, ) ], - token_exchange_resource=signing_resource.token_exchange_resource, - token_post_resource=signing_resource.token_post_resource, + token_exchange_resource=sign_in_resource.token_exchange_resource, + token_post_resource=sign_in_resource.token_post_resource, ) ) - # Send the card to the user await context.send_activity(MessageFactory.attachment(o_card)) elif flow_state.tag == FlowStateTag.FAILURE: @@ -669,14 +666,13 @@ async def _on_turn_auth_intercept(self, context: TurnContext, turn_state: TurnSt token_response: TokenResponse = new_flow_state.token_response saved_activity: Activity = new_flow_state.continuation_activity.model_copy() - if token_response and token_response.token: + if token_response: new_context = copy(context) new_context.activity = saved_activity logger.info( "Resending continuation activity %s", saved_activity.text ) await self.on_turn(new_context) - turn_state.delete_value(Authorization.SIGN_IN_STATE_KEY) # robrandao: TODOTODO await turn_state.save(context) return True diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index 04375c91..3ba22f4e 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -308,7 +308,6 @@ async def begin_or_continue_flow( context: TurnContext, turn_state: TurnState, auth_handler_id: str, - sec_route: bool = True, ) -> FlowResponse: """ Begins or continues an OAuth flow. @@ -320,9 +319,8 @@ async def begin_or_continue_flow( Returns: The token response from the OAuth provider. + """ - # robrandao: TODO -> is_started_from_route and sec_route - async with self.open_flow(context, auth_handler_id) as flow: flow_response: FlowResponse = await flow.begin_or_continue_flow(context) From 51b6c17bb87885eaed555805cb16a67c3b07016b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Thu, 21 Aug 2025 16:10:05 -0700 Subject: [PATCH 17/32] Fixing test cases and implementing cached client --- .../microsoft/agents/hosting/core/__init__.py | 2 +- .../hosting/core/app/agent_application.py | 40 +++--- .../hosting/core/app/oauth/authorization.py | 133 ++++++++---------- .../agents/hosting/core/oauth/flow_state.py | 6 +- .../hosting/core/oauth/flow_storage_client.py | 45 ++++-- .../agents/hosting/core/oauth/oauth_flow.py | 120 ++++++++-------- 6 files changed, 188 insertions(+), 158 deletions(-) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py index 59ef6e1c..124de2cd 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py @@ -20,7 +20,7 @@ from .app.typing_indicator import TypingIndicator # App Auth -from .app.auth import ( +from .app.oauth import ( Authorization, AuthorizationHandlers, AuthHandler, diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py index 0f75bc42..bcae6f63 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py @@ -609,15 +609,29 @@ async def _handle_flow_response(self, context: TurnContext, flow_response: FlowR if flow_state.tag == FlowStateTag.BEGIN: # Create the OAuth card sign_in_resource = flow_response.sign_in_resource + # for auth_handler in self._auth_handlers.values(): + # # Create OAuth flow with configuration + # messages_config = {} + # if auth_handler.title: + # ["card_title"] = auth_handler.title + # if auth_handler.text: + # messages_config["button_text"] = auth_handler.text + + # logger.debug(f"Configuring OAuth flow for handler: {auth_handler.name}") + # auth_handler.flow = AuthFlow( + # storage=storage, + # abs_oauth_connection_name=auth_handler.abs_oauth_connection_name, + # messages_configuration=messages_config if messages_config else None, + handler = self._auth.resolve_handler(flow_state.auth_handler_id) o_card: Attachment = CardFactory.oauth_card( OAuthCard( - text=self.messages_configuration.get("card_title", "Sign in"), + text="Sign in", connection_name=flow_state.connection, buttons=[ CardAction( - title=self.messages_configuration.get("button_text", "Sign in"), + title="Sign in", type=ActionTypes.signin, - value=signing_resource.sign_in_link, + value=sign_in_resource.sign_in_link, channel_data=None, ) ], @@ -630,28 +644,22 @@ async def _handle_flow_response(self, context: TurnContext, flow_response: FlowR elif flow_state.tag == FlowStateTag.FAILURE: if flow_state.reached_max_retries(): await context.send_activity( - MessageFactory.text( - self.messages_configuration.get( - "max_retries_reached_messages", - "Sign-in failed. Please try again later.", - ) - ) + MessageFactory.text("Sign-in failed. Max retries reached. Please try again later.") ) elif flow_state.is_expired(): await context.send_activity( - MessageFactory.text( - self.messages_configuration.get( - "session_expired_messages", - "Sign-in session expired. Please try again.", - ) - )) + MessageFactory.text("Sign-in session expired. Please try again.") + ) else: logger.warning("Sign-in flow failed for unknown reasons.") await context.send_activity("Sign-in failed. Please try again.") async def _on_turn_auth_intercept(self, context: TurnContext, turn_state: TurnState) -> bool: + print("*"*5) prev_flow_state = await self._auth.get_active_flow_state(context) + print(prev_flow_state) + print("*"*5) if self._auth and prev_flow_state: logger.debug("Sign-in flow is active for context: %s", context.activity.id) @@ -819,7 +827,7 @@ async def _on_activity(self, context: TurnContext, state: StateT): flow_response: FlowResponse = await self._auth.begin_or_continue_flow( context, state, auth_handler_id ) - await self._handle_flow_response(context, flow_response.in_flow_activity) + await self._handle_flow_response(context, flow_response) sign_in_complete = flow_response.flow_state.tag == FlowStateTag.COMPLETE if not sign_in_complete: break diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index 3ba22f4e..8f57b80a 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -12,7 +12,7 @@ Connections, AccessTokenProviderBase, ) -from microsoft.agents.hosting.core.storage import Storage +from microsoft.agents.hosting.core.storage import Storage, MemoryStorage from microsoft.agents.activity import TokenResponse from microsoft.agents.hosting.core.connector.client import UserTokenClient @@ -29,7 +29,6 @@ logger = logging.getLogger(__name__) - class Authorization: """ Class responsible for managing authorization and OAuth flows. @@ -42,6 +41,7 @@ def __init__( connection_manager: Connections, auth_handlers: dict[str, AuthHandler] = None, auto_signin: bool = None, + use_cache: bool = False, **kwargs, ): """ @@ -56,24 +56,14 @@ def __init__( """ if not storage: raise ValueError("Storage is required for Authorization") - # if not auth_handlers: - # raise ValueError("At least one AuthHandler must be provided") - - # user_state = UserState(storage) - self.__storage = storage - self.__connection_manager = connection_manager + self._storage = storage + self._connection_manager = connection_manager auth_configuration: Dict = kwargs.get("AGENTAPPLICATION", {}).get( "USERAUTHORIZATION", {} ) - # self.__auto_signin = ( - # auto_signin - # if auto_signin is not None - # else auth_configuration.get("AUTOSIGNIN", False) - # ) - handlers_config: Dict[str, Dict] = auth_configuration.get("HANDLERS") if not auth_handlers and handlers_config: auth_handlers = { @@ -83,31 +73,19 @@ def __init__( for handler_name, config in handlers_config.items() } - self.__auth_handlers = auth_handlers or {} - self.__sign_in_success_handler: Optional[ + self._auth_handlers = auth_handlers or {} + self._sign_in_success_handler: Optional[ Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] ] = None - self.__sign_in_failure_handler: Optional[ + self._sign_in_failure_handler: Optional[ Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] ] = None - # # Configure each auth handler - # for auth_handler in self.__auth_handlers.values(): - # # Create OAuth flow with configuration - # messages_config = {} - # if auth_handler.title: - # ["card_title"] = auth_handler.title - # if auth_handler.text: - # messages_config["button_text"] = auth_handler.text - - # logger.debug(f"Configuring OAuth flow for handler: {auth_handler.name}") - # auth_handler.flow = AuthFlow( - # storage=storage, - # abs_oauth_connection_name=auth_handler.abs_oauth_connection_name, - # messages_configuration=messages_config if messages_config else None, - # ) - - def __ids_from_context(self, context: TurnContext) -> tuple[str, str]: + self._cache = None + if use_cache: + self._cache = MemoryStorage() + + def _ids_from_context(self, context: TurnContext) -> tuple[str, str]: """Checks and returns IDs necessary to load a new or existing flow. Raises a ValueError if channel ID or user ID are missing. @@ -121,7 +99,7 @@ def __ids_from_context(self, context: TurnContext) -> tuple[str, str]: return context.activity.channel_id, context.activity.from_property.id - async def __load_flow( + async def _load_flow( self, context: TurnContext, auth_handler_id: str = "" @@ -138,7 +116,7 @@ async def __load_flow( The FlowStorageClient corresponds to the channel and user info. The FlowState returned is the flow state for the given channel/user/handler - triple at the time of creating the flow. + triple at the time of reading from storage and before creating the flow. """ user_token_client: UserTokenClient = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) @@ -146,15 +124,22 @@ async def __load_flow( auth_handler: AuthHandler = self.resolve_handler(auth_handler_id) auth_handler_id = auth_handler.name - channel_id, user_id = self.__ids_from_context(context) + channel_id, user_id = self._ids_from_context(context) ms_app_id = context.turn_state.get(context.adapter.AGENT_IDENTITY_KEY).claims["aud"] # try to load existing state - flow_storage_client = FlowStorageClient(channel_id, user_id, self.__storage) + flow_storage_client = FlowStorageClient(channel_id, user_id, self._storage) + logger.info("Loading OAuth flow state from storage") flow_state: FlowState = await flow_storage_client.read(auth_handler_id) if not flow_state: + # breakpoint() + # print("\n"*3) + # print(channel_id, user_id, auth_handler_id, auth_handler.abs_oauth_connection_name, ms_app_id) + # print("\n"*3) + # breakpoint() + logger.info("No existing flow state found, creating new flow state") flow_state = FlowState( channel_id=channel_id, user_id=user_id, @@ -165,7 +150,7 @@ async def __load_flow( await flow_storage_client.write(flow_state) flow = OAuthFlow(flow_state, user_token_client) - return flow, flow_storage_client, flow_state + return flow, flow_storage_client @asynccontextmanager async def open_flow(self, context: TurnContext, auth_handler_id: str = "") -> AsyncIterator[OAuthFlow]: @@ -183,13 +168,10 @@ async def open_flow(self, context: TurnContext, auth_handler_id: str = "") -> As if not context: raise ValueError("context is required") - flow, flow_storage_client, init_flow_state = await self.__load_flow(context, auth_handler_id) + flow, flow_storage_client = await self._load_flow(context, auth_handler_id) yield flow - - # persist state - new_flow_state = flow.flow_state - if new_flow_state != init_flow_state: - await flow_storage_client.write(new_flow_state) + logger.info("Saving OAuth flow state to storage") + await flow_storage_client.write(flow.flow_state) async def get_token( self, context: TurnContext, auth_handler_id: str @@ -204,8 +186,9 @@ async def get_token( Returns: The token response from the OAuth provider. """ + logger.info(f"Getting token for auth handler: {auth_handler_id}") async with self.open_flow(context, auth_handler_id) as flow: - return await flow.get_user_token(context) + return await flow.get_user_token() async def exchange_token( self, @@ -224,11 +207,12 @@ async def exchange_token( Returns: The token response from the OAuth provider. """ + logger.info(f"Exchanging token for scopes: {scopes}") async with self.open_flow(context, auth_handler_id) as flow: token_response = await flow.get_user_token() - if token_response and self.__is_exchangeable(token_response.token): - return await self.__handle_obo(token_response.token, scopes, auth_handler_id) + if token_response and self._is_exchangeable(token_response.token): + return await self._handle_obo(token_response.token, scopes, auth_handler_id) return TokenResponse() @@ -239,12 +223,12 @@ async def exchange_token( # token_response = await auth_handler.flow.get_user_token(context) - # if self.__is_exchangeable(token_response.token if token_response else None): - # return await self.__handle_obo(token_response.token, scopes, auth_handler_id) + # if self._is_exchangeable(token_response.token if token_response else None): + # return await self._handle_obo(token_response.token, scopes, auth_handler_id) # return token_response - def __is_exchangeable(self, token: str) -> bool: + def _is_exchangeable(self, token: str) -> bool: """ Checks if a token is exchangeable (has api:// audience). @@ -263,7 +247,7 @@ def __is_exchangeable(self, token: str) -> bool: logger.exception("Failed to decode token to check audience") return False - async def __handle_obo( + async def _handle_obo( self, token: str, scopes: list[str], handler_id: str = None ) -> TokenResponse: """ @@ -279,7 +263,7 @@ async def __handle_obo( """ auth_handler = self.resolve_handler(handler_id) - token_provider: AccessTokenProviderBase = self.__connection_manager.get_connection( + token_provider: AccessTokenProviderBase = self._connection_manager.get_connection( auth_handler.obo_connection_name ) @@ -295,9 +279,10 @@ async def __handle_obo( async def get_active_flow_state(self, context: TurnContext) -> Optional[FlowState]: """Gets the first active flow state for the current context.""" - channel_id, user_id = self.__ids_from_context(context) - flow_storage_client = FlowStorageClient(channel_id, user_id, self.__storage) - for auth_handler_id in self.__auth_handlers.keys(): + channel_id, user_id = self._ids_from_context(context) + # TODO -> single read + flow_storage_client = FlowStorageClient(channel_id, user_id, self._storage) + for auth_handler_id in self._auth_handlers.keys(): flow_state = await flow_storage_client.read(auth_handler_id) if flow_state and flow_state.is_active(): return flow_state @@ -321,15 +306,17 @@ async def begin_or_continue_flow( The token response from the OAuth provider. """ + logger.debug("Beginning OAuth flow") async with self.open_flow(context, auth_handler_id) as flow: - flow_response: FlowResponse = await flow.begin_or_continue_flow(context) + flow_response: FlowResponse = await flow.begin_or_continue_flow(context.activity) flow_state: FlowState = flow_response.flow_state + # stayed completed TODO if flow_state.tag == FlowStateTag.COMPLETE: - self.__sign_in_success_handler(context, turn_state, flow_state.auth_handler_id) + self._sign_in_success_handler(context, turn_state, flow_state.auth_handler_id) elif flow_state.tag == FlowStateTag.FAILURE: - self.__sign_in_failure_handler(context, turn_state, flow_state.auth_handler_id, flow_response.flow_error_tag) + self._sign_in_failure_handler(context, turn_state, flow_state.auth_handler_id, flow_response.flow_error_tag) return flow_response @@ -344,15 +331,16 @@ def resolve_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: The resolved auth handler. """ if auth_handler_id: - if auth_handler_id not in self.__auth_handlers: + if auth_handler_id not in self._auth_handlers: + breakpoint() logger.error(f"Auth handler '{auth_handler_id}' not found") raise ValueError(f"Auth handler '{auth_handler_id}' not found") - return self.__auth_handlers[auth_handler_id] + return self._auth_handlers[auth_handler_id] # Return the first handler if no ID specified - return next(iter(self.__auth_handlers.values())) + return next(iter(self._auth_handlers.values())) - async def __sign_out( + async def _sign_out( self, context: TurnContext, auth_handler_ids: Iterable[str], @@ -366,11 +354,12 @@ async def __sign_out( Deletes the associated flow states from storage. """ for auth_handler_id in auth_handler_ids: - flow, flow_storage_client, initial_flow_state = await self.__load_flow(context, auth_handler_id) - if initial_flow_state: - logger.info(f"Signing out from handler: {auth_handler_id}") - await flow.sign_out() - await flow_storage_client.delete(auth_handler_id) + flow, flow_storage_client = await self._load_flow(context, auth_handler_id) + # ensure that the id is valid + self.resolve_handler(auth_handler_id) + logger.info(f"Signing out from handler: {auth_handler_id}") + await flow.sign_out() + await flow_storage_client.delete(auth_handler_id) async def sign_out( self, @@ -389,9 +378,9 @@ async def sign_out( Deletes the associated flow state(s) from storage. """ if auth_handler_id: - await self.__sign_out(context, [auth_handler_id]) + await self._sign_out(context, [auth_handler_id]) else: - await self.__sign_out(context, self.__auth_handlers.keys()) + await self._sign_out(context, self._auth_handlers.keys()) def on_sign_in_success( self, @@ -403,7 +392,7 @@ def on_sign_in_success( Args: handler: The handler function to call on successful sign-in. """ - self.__sign_in_success_handler = handler + self._sign_in_success_handler = handler def on_sign_in_failure( self, @@ -414,4 +403,4 @@ def on_sign_in_failure( Args: handler: The handler function to call on sign-in failure. """ - self.__sign_in_failure_handler = handler \ No newline at end of file + self._sign_in_failure_handler = handler \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py index dc580537..f7065af0 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py @@ -30,7 +30,7 @@ class FlowErrorTag(Enum): NONE = "none" MAGIC_FORMAT = "magic_format" MAGIC_CODE_INCORRECT = "magic_code_incorrect" - OTHER = "OTHER" + OTHER = "other" class FlowState(BaseModel, StoreItem): """Represents the state of an OAuthFlow""" @@ -43,7 +43,7 @@ class FlowState(BaseModel, StoreItem): connection: str = "" auth_handler_id: str = "" - expires_at: float = 0 + expiration: float = 0 continuation_activity: Optional[Activity] = None attempts_remaining: int = 0 tag: FlowStateTag = FlowStateTag.NOT_STARTED @@ -56,7 +56,7 @@ def from_json_to_store_item(json_data: dict) -> "FlowState": return FlowState.model_validate(json_data) def is_expired(self) -> bool: - return datetime.now().timestamp() >= self.expires_at + return datetime.now().timestamp() >= self.expiration def reached_max_attempts(self) -> bool: return self.attempts_remaining <= 0 diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py index 4ca41432..3cc126a9 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py @@ -3,10 +3,22 @@ from typing import Optional -from ..storage import Storage +from ..storage import Storage, MemoryStorage from .flow_state import FlowState +class DummyStorage(Storage): + + async def read(self, keys: list[str], **kwargs) -> dict[str, FlowState]: + return {} + + async def write(self, changes: dict[str, FlowState]) -> None: + pass + + async def delete(self, keys: list[str]) -> None: + pass + # this could be generalized, if needed +# not generally thread or async safe class FlowStorageClient: """Wrapper around Storage that manages sign-in state specific to each user and channel. @@ -17,7 +29,8 @@ def __init__( self, channel_id: str, user_id: str, - storage: Storage + storage: Storage, + cache_class: type[Storage] = None ): """ Parameters @@ -30,32 +43,44 @@ def __init__( if not user_id or not channel_id: raise ValueError("FlowStorageClient.__init__(): channel_id and user_id must be set.") - self.__base_key = f"auth/{channel_id}/{user_id}/" - self.__storage = storage + self._base_key = f"auth/{channel_id}/{user_id}/" + self._storage = storage + if cache_class is None: + cache_class = DummyStorage + self._cache = cache_class() @property def base_key(self) -> str: """Returns the prefix used for flow state storage isolation.""" - return self.__base_key + return self._base_key def key(self, auth_handler_id: str) -> str: """Creates a storage key for a specific sign-in handler.""" - return f"{self.__base_key}{auth_handler_id}" + return f"{self._base_key}{auth_handler_id}" async def read(self, auth_handler_id: str) -> Optional[FlowState]: """Reads the flow state for a specific authentication handler.""" key: str = self.key(auth_handler_id) - data = await self.__storage.read([key], target_cls=FlowState) + data = await self._cache.read([key], target_cls=FlowState) if key not in data: - return None + data = await self._storage.read([key], target_cls=FlowState) + if key not in data: + return None + await self._cache.write({key: data[key]}) return FlowState.model_validate(data.get(key)) async def write(self, value: FlowState) -> None: """Saves the flow state for a specific authentication handler.""" key: str = self.key(value.auth_handler_id) - await self.__storage.write({key: value}) + cached_state = await self._cache.read([key], target_cls=FlowState) + if not cached_state or cached_state != value: + await self._cache.write({key: value}) + await self._storage.write({key: value}) async def delete(self, auth_handler_id: str) -> None: """Deletes the flow state for a specific authentication handler.""" key: str = self.key(auth_handler_id) - await self.__storage.delete([key]) + cached_state = await self._cache.read([key], target_cls=FlowState) + if cached_state: + await self._cache.delete([key]) + await self._storage.delete([key]) \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py index 37e904b5..0db64651 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py @@ -68,20 +68,20 @@ def __init__( not flow_state.user_id): raise ValueError("OAuthFlow.__init__: flow_state must have ms_app_id, channel_id, user_id, connection defined") - self.__flow_state = flow_state.model_copy() - self.__abs_oauth_connection_name = self.__flow_state.connection - self.__ms_app_id = self.__flow_state.ms_app_id - self.__channel_id = self.__flow_state.channel_id - self.__user_id = self.__flow_state.user_id + self._flow_state = flow_state.model_copy() + self._abs_oauth_connection_name = self._flow_state.connection + self._ms_app_id = self._flow_state.ms_app_id + self._channel_id = self._flow_state.channel_id + self._user_id = self._flow_state.user_id - self.__user_token_client = user_token_client + self._user_token_client = user_token_client - self.__flow_duration = kwargs.get("flow_duration", 60000) # defaults to 60 seconds - self.__max_attempts = kwargs.get("max_attempts", 3) # defaults to 3 max attempts + self._default_expires_in = kwargs.get("default_flow_duration", 60000) # default to 60 seconds + self._max_attempts = kwargs.get("max_attempts", 3) # defaults to 3 max attempts @property def flow_state(self) -> FlowState: - return self.__flow_state.model_copy() + return self._flow_state.model_copy() async def get_user_token(self, magic_code: str = None) -> TokenResponse: """Get the user token based on the context. @@ -96,14 +96,22 @@ async def get_user_token(self, magic_code: str = None) -> TokenResponse: Notes: flow_state.user_token is updated with the latest token. """ - token_response: TokenResponse = await self.__user_token_client.user_token.get_token( - user_id=self.__user_id, - connection_name=self.__abs_oauth_connection_name, - channel_id=self.__channel_id, - magic_code=magic_code - ) + if magic_code: + token_response: TokenResponse = await self._user_token_client.user_token.get_token( + user_id=self._user_id, + connection_name=self._abs_oauth_connection_name, + channel_id=self._channel_id, + magic_code=magic_code + ) + else: + token_response: TokenResponse = await self._user_token_client.user_token.get_token( + user_id=self._user_id, + connection_name=self._abs_oauth_connection_name, + channel_id=self._channel_id, + ) if token_response: - self.__flow_state.user_token = token_response.token + self._flow_state.user_token = token_response.token + self._flow_state.expiration = token_response.expiration return token_response async def sign_out(self) -> None: @@ -112,19 +120,19 @@ async def sign_out(self) -> None: Sets the flow state tag to NOT_STARTED Resets the flow state user_token field """ - await self.__user_token_client.user_token.sign_out( - user_id=self.__user_id, - connection_name=self.__abs_oauth_connection_name, - channel_id=self.__channel_id + await self._user_token_client.user_token.sign_out( + user_id=self._user_id, + connection_name=self._abs_oauth_connection_name, + channel_id=self._channel_id ) - self.__flow_state.user_token = "" - self.__flow_state.tag = FlowStateTag.NOT_STARTED + self._flow_state.user_token = "" + self._flow_state.tag = FlowStateTag.NOT_STARTED - def __use_attempt(self) -> None: + def _use_attempt(self) -> None: """Decrements the remaining attempts for the flow, checking for failure.""" - self.__flow_state.attempts_remaining -= 1 - if self.__flow_state.attempts_remaining <= 0: - self.__flow_state.tag = FlowStateTag.FAILURE + self._flow_state.attempts_remaining -= 1 + if self._flow_state.attempts_remaining <= 0: + self._flow_state.tag = FlowStateTag.FAILURE async def begin_flow(self, activity: Activity) -> FlowResponse: """Begins the OAuthFlow. @@ -144,28 +152,29 @@ async def begin_flow(self, activity: Activity) -> FlowResponse: token_response = await self.get_user_token() if token_response: return FlowResponse( - flow_state=self.__flow_state, + flow_state=self._flow_state, token_response=token_response ) - self.__flow_state.tag = FlowStateTag.BEGIN - self.__flow_state.expires_at = datetime.now().timestamp() + self.__flow_duration - self.__flow_state.attempts_remaining = self.__max_attempts - self.__flow_state.user_token = "" + self._flow_state.tag = FlowStateTag.BEGIN + self._flow_state.expiration = datetime.now().timestamp() + self._flow_state.attempts_remaining = self._max_attempts + self._flow_state.user_token = "" + self._flow_state.continuation_activity = activity.model_copy() token_exchange_state = TokenExchangeState( - connection_name=self.__abs_oauth_connection_name, + connection_name=self._abs_oauth_connection_name, conversation=activity.get_conversation_reference(), relates_to=activity.relates_to, - ms_app_id=self.__ms_app_id + ms_app_id=self._ms_app_id ) - sign_in_resource = await self.__user_token_client.agent_sign_in.get_sign_in_resource( + sign_in_resource = await self._user_token_client.agent_sign_in.get_sign_in_resource( state=token_exchange_state.get_encoded_state()) - return FlowResponse(flow_state=self.__flow_state, sign_in_resource=sign_in_resource) + return FlowResponse(flow_state=self._flow_state, sign_in_resource=sign_in_resource) - async def __continue_from_message(self, activity: Activity) -> tuple[TokenResponse, FlowErrorTag]: + async def _continue_from_message(self, activity: Activity) -> tuple[TokenResponse, FlowErrorTag]: """Handles the continuation of the flow from a message activity.""" magic_code: str = activity.text if magic_code and magic_code.isdigit() and len(magic_code) == 6: @@ -178,20 +187,20 @@ async def __continue_from_message(self, activity: Activity) -> tuple[TokenRespon else: return TokenResponse(), FlowErrorTag.MAGIC_FORMAT - async def __continue_from_invoke_verify_state(self, activity: Activity) -> TokenResponse: + async def _continue_from_invoke_verify_state(self, activity: Activity) -> TokenResponse: """Handles the continuation of the flow from an invoke activity for verifying state.""" token_verify_state = activity.value magic_code: str = token_verify_state.get("state") token_response: TokenResponse = await self.get_user_token(magic_code) return token_response - async def __continue_from_invoke_token_exchange(self, activity: Activity) -> TokenResponse: + async def _continue_from_invoke_token_exchange(self, activity: Activity) -> TokenResponse: """Handles the continuation of the flow from an invoke activity for token exchange.""" token_exchange_request = activity.value - token_response = await self.__user_token_client.user_token.exchange_token( - user_id=self.__user_id, - connection_name=self.__abs_oauth_connection_name, - channel_id=self.__channel_id, + token_response = await self._user_token_client.user_token.exchange_token( + user_id=self._user_id, + connection_name=self._abs_oauth_connection_name, + channel_id=self._channel_id, body=token_exchange_request ) return token_response @@ -208,19 +217,19 @@ async def continue_flow(self, activity: Activity) -> FlowResponse: """ logger.debug("Continuing auth flow...") - if not self.__flow_state.is_active(): - self.__flow_state.tag = FlowStateTag.FAILURE - return FlowResponse(flow_state=self.__flow_state) + if not self._flow_state.is_active(): + self._flow_state.tag = FlowStateTag.FAILURE + return FlowResponse(flow_state=self._flow_state) flow_error_tag = FlowErrorTag.NONE if activity.type == ActivityTypes.message: - token_response, flow_error_tag = await self.__continue_from_message(activity) + token_response, flow_error_tag = await self._continue_from_message(activity) elif (activity.type == ActivityTypes.invoke and activity.name == "signin/verifyState"): - token_response = await self.__continue_from_invoke_verify_state(activity) + token_response = await self._continue_from_invoke_verify_state(activity) elif (activity.type == ActivityTypes.invoke and activity.name == "signin/tokenExchange"): - token_response = await self.__continue_from_invoke_token_exchange(activity) + token_response = await self._continue_from_invoke_token_exchange(activity) else: raise ValueError("Unknown activity type") @@ -228,16 +237,15 @@ async def continue_flow(self, activity: Activity) -> FlowResponse: flow_error_tag = FlowErrorTag.OTHER if flow_error_tag != FlowErrorTag.NONE: - self.__flow_state.tag = FlowStateTag.CONTINUE - self.__use_attempt() + self._flow_state.tag = FlowStateTag.CONTINUE + self._use_attempt() else: - self.__flow_state.tag = FlowStateTag.COMPLETE - self.__flow_state.expires_at = datetime.now().timestamp() + self.__flow_duration - self.__flow_state.user_token = token_response.token - + self._flow_state.tag = FlowStateTag.COMPLETE + self._flow_state.expiration = token_response.expiration + self._flow_state.user_token = token_response.token return FlowResponse( - flow_state=self.__flow_state.model_copy(), + flow_state=self._flow_state.model_copy(), flow_error_tag=flow_error_tag, token_response=token_response ) @@ -251,7 +259,7 @@ async def begin_or_continue_flow(self, activity: Activity) -> FlowResponse: Returns: A FlowResponse object containing the updated flow state and any token response. """ - if self.__flow_state.is_active(): + if self._flow_state.is_active(): return await self.continue_flow(activity) else: return await self.begin_flow(activity) \ No newline at end of file From 81a81696e4b355511f2ffa74a4bb20acebbe5ea9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 09:40:36 -0700 Subject: [PATCH 18/32] Added refresh() and adjusted tests --- .../agents/hosting/core/oauth/flow_state.py | 6 +- .../tests/test_flow_state.py | 61 +++++++++++++------ 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py index f7065af0..01916b09 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py @@ -62,4 +62,8 @@ def reached_max_attempts(self) -> bool: return self.attempts_remaining <= 0 def is_active(self) -> bool: - return not self.is_expired() and not self.reached_max_attempts() and self.tag in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE] \ No newline at end of file + return not self.is_expired() and not self.reached_max_attempts() and self.tag in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE] + + def refresh(self): + if self.tag in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE, FlowStateTag.COMPLETE] and self.is_expired(): + self.tag = FlowStateTag.NOT_STARTED \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py b/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py index 0721b63f..d4f3c737 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py @@ -6,18 +6,41 @@ class TestFlowState: + @pytest.mark.parametrize( + "original_flow_state, refresh_to_not_started", + [ + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expiration=datetime.now().timestamp()), + True), + (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expiration=datetime.now().timestamp()), + True), + (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expiration=datetime.now().timestamp()-100), + True), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expiration=datetime.now().timestamp()+1000), + False), + (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=-1, expiration=datetime.now().timestamp()), + False), + ] + ) + def test_refresh(self, original_flow_state, refresh_to_not_started): + new_flow_state = original_flow_state.model_copy() + new_flow_state.refresh() + expected_flow_state = original_flow_state.model_copy() + if refresh_to_not_started: + expected_flow_state.tag = FlowStateTag.NOT_STARTED + assert new_flow_state == expected_flow_state + @pytest.mark.parametrize( "flow_state, expected", [ - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expires_at=datetime.now().timestamp()), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expiration=datetime.now().timestamp()), True), - (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expires_at=datetime.now().timestamp()), + (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expiration=datetime.now().timestamp()), True), - (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expires_at=datetime.now().timestamp()-100), + (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expiration=datetime.now().timestamp()-100), True), - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expires_at=datetime.now().timestamp()+1000), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expiration=datetime.now().timestamp()+1000), False), - (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=-1, expires_at=datetime.now().timestamp()+1000), + (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=-1, expiration=datetime.now().timestamp()+1000), False), ] ) @@ -27,15 +50,15 @@ def test_is_expired(self, flow_state, expected): @pytest.mark.parametrize( "flow_state, expected", [ - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expires_at=datetime.now().timestamp()), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expiration=datetime.now().timestamp()), True), - (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expires_at=datetime.now().timestamp()), + (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expiration=datetime.now().timestamp()), False), - (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expires_at=datetime.now().timestamp()-100), + (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expiration=datetime.now().timestamp()-100), True), - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expires_at=datetime.now().timestamp()-100), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expiration=datetime.now().timestamp()-100), False), - (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=-1, expires_at=datetime.now().timestamp()), + (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=-1, expiration=datetime.now().timestamp()), True), ] ) @@ -45,23 +68,23 @@ def test_reached_max_attempts(self, flow_state, expected): @pytest.mark.parametrize( "flow_state, expected", [ - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expires_at=datetime.now().timestamp()), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expiration=datetime.now().timestamp()), False), - (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expires_at=datetime.now().timestamp()), + (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expiration=datetime.now().timestamp()), False), - (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expires_at=datetime.now().timestamp()-100), + (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expiration=datetime.now().timestamp()-100), False), - (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=1, expires_at=datetime.now().timestamp()-100), + (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=1, expiration=datetime.now().timestamp()-100), False), - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=2, expires_at=datetime.now().timestamp()+1000), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=2, expiration=datetime.now().timestamp()+1000), True), - (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=0, expires_at=datetime.now().timestamp()+1000), + (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=0, expiration=datetime.now().timestamp()+1000), False), - (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=-1, expires_at=datetime.now().timestamp()+1000), + (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=-1, expiration=datetime.now().timestamp()+1000), False), - (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=1, expires_at=datetime.now().timestamp()+1000), + (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=1, expiration=datetime.now().timestamp()+1000), False), - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expires_at=datetime.now().timestamp()+1000), + (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expiration=datetime.now().timestamp()+1000), True) ] ) From 3619c5d1ed1f16d62fc1ed3c77f4e219510715a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 09:55:39 -0700 Subject: [PATCH 19/32] Aligned more test cases --- .../hosting/core/app/oauth/authorization.py | 22 +-- .../hosting/core/oauth/flow_storage_client.py | 12 +- .../agents/hosting/core/oauth/oauth_flow.py | 21 ++- .../tests/test_oauth_flow.py | 125 +++++++++++++++++- .../tests/tools/testing_oauth.py | 24 ++-- 5 files changed, 169 insertions(+), 35 deletions(-) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index 8f57b80a..7195f6ba 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -186,7 +186,7 @@ async def get_token( Returns: The token response from the OAuth provider. """ - logger.info(f"Getting token for auth handler: {auth_handler_id}") + logger.info("Getting token for auth handler: %s", auth_handler_id) async with self.open_flow(context, auth_handler_id) as flow: return await flow.get_user_token() @@ -207,7 +207,7 @@ async def exchange_token( Returns: The token response from the OAuth provider. """ - logger.info(f"Exchanging token for scopes: {scopes}") + logger.info("Exchanging token for scopes: %s", scopes) async with self.open_flow(context, auth_handler_id) as flow: token_response = await flow.get_user_token() @@ -292,7 +292,7 @@ async def begin_or_continue_flow( self, context: TurnContext, turn_state: TurnState, - auth_handler_id: str, + auth_handler_id: str = "", ) -> FlowResponse: """ Begins or continues an OAuth flow. @@ -306,16 +306,21 @@ async def begin_or_continue_flow( The token response from the OAuth provider. """ - logger.debug("Beginning OAuth flow") + if not auth_handler_id: + auth_handler_id = self.resolve_handler().name + + logger.debug("Beginning or continuing OAuth flow") async with self.open_flow(context, auth_handler_id) as flow: + prev_tag = flow.flow_state.state_tag flow_response: FlowResponse = await flow.begin_or_continue_flow(context.activity) flow_state: FlowState = flow_response.flow_state - # stayed completed TODO - if flow_state.tag == FlowStateTag.COMPLETE: + if flow_state.tag == FlowStateTag.COMPLETE and prev_tag != FlowStateTag.COMPLETE: + logger.debug("Calling Authorization sign in success handler") self._sign_in_success_handler(context, turn_state, flow_state.auth_handler_id) elif flow_state.tag == FlowStateTag.FAILURE: + logger.debug("Calling Authorization sign in failure handler") self._sign_in_failure_handler(context, turn_state, flow_state.auth_handler_id, flow_response.flow_error_tag) return flow_response @@ -332,8 +337,7 @@ def resolve_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: """ if auth_handler_id: if auth_handler_id not in self._auth_handlers: - breakpoint() - logger.error(f"Auth handler '{auth_handler_id}' not found") + logger.error("Auth handler '%s' not found", auth_handler_id) raise ValueError(f"Auth handler '{auth_handler_id}' not found") return self._auth_handlers[auth_handler_id] @@ -357,7 +361,7 @@ async def _sign_out( flow, flow_storage_client = await self._load_flow(context, auth_handler_id) # ensure that the id is valid self.resolve_handler(auth_handler_id) - logger.info(f"Signing out from handler: {auth_handler_id}") + logger.info("Signing out from handler: %s", auth_handler_id) await flow.sign_out() await flow_storage_client.delete(auth_handler_id) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py index 3cc126a9..384e2802 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py @@ -3,10 +3,10 @@ from typing import Optional -from ..storage import Storage, MemoryStorage +from ..storage import Storage from .flow_state import FlowState -class DummyStorage(Storage): +class DummyCache(Storage): async def read(self, keys: list[str], **kwargs) -> dict[str, FlowState]: return {} @@ -17,8 +17,10 @@ async def write(self, changes: dict[str, FlowState]) -> None: async def delete(self, keys: list[str]) -> None: pass -# this could be generalized, if needed -# not generally thread or async safe +# this could be generalized. Ideas: +# - CachedStorage class for two-tier storage +# - Namespaced/PrefixedStorage class for namespacing keying +# not generally thread or async safe (operations are not atomic) class FlowStorageClient: """Wrapper around Storage that manages sign-in state specific to each user and channel. @@ -46,7 +48,7 @@ def __init__( self._base_key = f"auth/{channel_id}/{user_id}/" self._storage = storage if cache_class is None: - cache_class = DummyStorage + cache_class = DummyCache self._cache = cache_class() @property diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py index 0db64651..45ad6239 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py @@ -28,6 +28,7 @@ class FlowResponse(BaseModel): flow_error_tag: FlowErrorTag = FlowErrorTag.NONE token_response: Optional[TokenResponse] = None sign_in_resource: Optional[SignInResource] = None + continuation_activity: Optional[Activity] = None class OAuthFlow: """ @@ -69,6 +70,7 @@ def __init__( raise ValueError("OAuthFlow.__init__: flow_state must have ms_app_id, channel_id, user_id, connection defined") self._flow_state = flow_state.model_copy() + self._abs_oauth_connection_name = self._flow_state.connection self._ms_app_id = self._flow_state.ms_app_id self._channel_id = self._flow_state.channel_id @@ -133,7 +135,7 @@ def _use_attempt(self) -> None: self._flow_state.attempts_remaining -= 1 if self._flow_state.attempts_remaining <= 0: self._flow_state.tag = FlowStateTag.FAILURE - + async def begin_flow(self, activity: Activity) -> FlowResponse: """Begins the OAuthFlow. @@ -216,10 +218,11 @@ async def continue_flow(self, activity: Activity) -> FlowResponse: """ logger.debug("Continuing auth flow...") - + if not self._flow_state.is_active(): + logger.debug("OAuth flow is not active, cannot continue") self._flow_state.tag = FlowStateTag.FAILURE - return FlowResponse(flow_state=self._flow_state) + return FlowResponse(flow_state=self._flow_state.model_copy(), token_response=None) flow_error_tag = FlowErrorTag.NONE if activity.type == ActivityTypes.message: @@ -247,7 +250,8 @@ async def continue_flow(self, activity: Activity) -> FlowResponse: return FlowResponse( flow_state=self._flow_state.model_copy(), flow_error_tag=flow_error_tag, - token_response=token_response + token_response=token_response, + continuation_activity=self._flow_state.continuation_activity ) async def begin_or_continue_flow(self, activity: Activity) -> FlowResponse: @@ -259,7 +263,12 @@ async def begin_or_continue_flow(self, activity: Activity) -> FlowResponse: Returns: A FlowResponse object containing the updated flow state and any token response. """ + self._flow_state.refresh() + if self._flow_state.tag == FlowStateTag.COMPLETE: # robrandao: TODO -> test + logger.debug("OAuth flow has already been completed, nothing to do") + return FlowResponse(flow_state=self._flow_state.model_copy(), token_response=TokenResponse(token=self._flow_state.user_token)) + if self._flow_state.is_active(): return await self.continue_flow(activity) - else: - return await self.begin_flow(activity) \ No newline at end of file + + return await self.begin_flow(activity) \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py index 0a33c28b..a970349c 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py @@ -12,7 +12,8 @@ from microsoft.agents.hosting.core.oauth import ( OAuthFlow, FlowErrorTag, - FlowStateTag + FlowStateTag, + FlowResponse ) from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase from microsoft.agents.hosting.core.connector.user_token_client_base import UserTokenClientBase @@ -68,6 +69,10 @@ def sample_failed_flow_state(self, request): @pytest.fixture(params=FLOW_STATES.INACTIVE()) def sample_inactive_flow_state(self, request): return request.param.model_copy() + + @pytest.fixture(params=[ flow_state for flow_state in FLOW_STATES.INACTIVE() if flow_state.tag != FlowStateTag.COMPLETE]) + def sample_inactive_flow_state_not_completed(self, request): + return request.param.model_copy() @pytest.fixture(params=FLOW_STATES.ACTIVE()) def sample_active_flow_state(self, request): @@ -115,6 +120,7 @@ async def test_get_user_token_success(self, sample_flow_state, user_token_client flow = OAuthFlow(sample_flow_state, user_token_client) expected_final_flow_state = sample_flow_state expected_final_flow_state.user_token = RES_TOKEN + expected_final_flow_state.expiration = None # test token_response = await flow.get_user_token() @@ -127,7 +133,7 @@ async def test_get_user_token_success(self, sample_flow_state, user_token_client user_id=USER_ID, connection_name=ABS_OAUTH_CONNECTION_NAME, channel_id=CHANNEL_ID, - magic_code=None + # magic_code=None ) @pytest.mark.asyncio @@ -147,7 +153,7 @@ async def test_get_user_token_failure(self, mocker, sample_flow_state): user_id=USER_ID, connection_name=ABS_OAUTH_CONNECTION_NAME, channel_id=CHANNEL_ID, - magic_code=None + # magic_code=None ) @pytest.mark.asyncio @@ -176,6 +182,7 @@ async def test_begin_flow_easy_case(self, mocker, sample_flow_state, user_token_ activity = mocker.Mock(spec=Activity) expected_flow_state = sample_flow_state expected_flow_state.user_token = RES_TOKEN + expected_flow_state.expiration = None # test response = await flow.begin_flow(activity) @@ -194,7 +201,7 @@ async def test_begin_flow_easy_case(self, mocker, sample_flow_state, user_token_ channel_id=CHANNEL_ID, # magic_code=None is an implementation detail, and ideally # shouldn't be part of the test - magic_code=None + # magic_code=None ) @pytest.mark.asyncio @@ -220,13 +227,14 @@ async def test_begin_flow_long_case(self, mocker, sample_flow_state, user_token_ expected_flow_state.user_token = "" expected_flow_state.tag = FlowStateTag.BEGIN expected_flow_state.attempts_remaining = 3 + expected_flow_state.continuation_activity = activity # test response = await flow.begin_flow(activity) # verify flow_state flow_state = flow.flow_state - expected_flow_state.expires_at = flow_state.expires_at # robrandao: TODO -> ignore this for now + expected_flow_state.expiration = flow_state.expiration # robrandao: TODO -> ignore this for now assert flow_state == response.flow_state assert flow_state == expected_flow_state @@ -281,7 +289,7 @@ async def helper_continue_flow_success(self, active_flow_state, user_token_clien # test flow_response = await flow.continue_flow(activity) flow_state = flow.flow_state - expected_flow_state.expires_at = flow_state.expires_at # robrandao: TODO -> ignore this for now + expected_flow_state.expiration = flow_state.expiration # robrandao: TODO -> ignore this for now # verify assert flow_response.flow_state == flow_state @@ -390,4 +398,107 @@ async def test_continue_flow_invalid_activity_type(self, mocker, sample_active_f flow = OAuthFlow(sample_active_flow_state, user_token_client) await flow.continue_flow(activity) - # robrandao: TODO -> test begin_or_continue_flow -> low priority for now \ No newline at end of file + @pytest.mark.asyncio + async def test_begin_or_continue_flow_not_started_flow( + self, + mocker + ): + # setup + sample_flow_state = FLOW_STATES.NOT_STARTED_FLOW.model_copy() + expected_response = FlowResponse( + flow_state = sample_flow_state, + token_response = TokenResponse(token=sample_flow_state.user_token), + ) + mocker.patch.object(OAuthFlow, "begin_flow", return_value=expected_response) + + activity_mock = mocker.Mock() + flow = OAuthFlow(sample_flow_state, mocker.Mock()) + + # test + actual_response = await flow.begin_or_continue_flow(activity_mock) + + # verify + assert actual_response is expected_response + OAuthFlow.begin_flow.assert_called_once_with(activity_mock) + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_inactive_flow( + self, + mocker, + sample_inactive_flow_state_not_completed, + ): + # setup + expected_response = FlowResponse( + flow_state = sample_inactive_flow_state_not_completed, + token_response = TokenResponse(), + ) + mocker.patch.object(OAuthFlow, "begin_flow", return_value=expected_response) + + flow = OAuthFlow(sample_inactive_flow_state_not_completed, mocker.Mock()) + + # test + activity_mock = mocker.Mock() + actual_response = await flow.begin_or_continue_flow(activity_mock) + + # verify + assert actual_response is expected_response + OAuthFlow.begin_flow.assert_called_once_with(activity_mock) + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_active_flow( + self, + mocker, + sample_active_flow_state, + ): + # setup + expected_response = FlowResponse( + flow_state = sample_active_flow_state, + token_response = TokenResponse(token=sample_active_flow_state.user_token), + ) + mocker.patch.object(OAuthFlow, "continue_flow", return_value=expected_response) + + flow = OAuthFlow(sample_active_flow_state, mocker.Mock()) + + # test + activity_mock = mocker.Mock() + actual_response = await flow.begin_or_continue_flow(activity_mock) + + # verify + assert actual_response is expected_response + OAuthFlow.continue_flow.assert_called_once_with(activity_mock) + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_stale_flow_state( + self, + mocker + ): + flow_state = FLOW_STATES.ACTIVE_EXP_FLOW.model_copy() + expected_response = FlowResponse() + + mocker.patch.object(OAuthFlow, "begin_flow", return_value=expected_response) + + flow = OAuthFlow(flow_state, mocker.Mock()) + actual_response = await flow.begin_or_continue_flow(None) + + assert actual_response is expected_response + OAuthFlow.begin_flow.assert_called_once_with(None) + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_completed_flow_state( + self, + mocker + ): + flow_state = FLOW_STATES.COMPLETED_FLOW.model_copy() + expected_response = FlowResponse( + flow_state = flow_state, + token_response = TokenResponse(token=flow_state.user_token) + ) + mocker.patch.object(OAuthFlow, "begin_flow", return_value=None) + mocker.patch.object(OAuthFlow, "continue_flow", return_value=None) + + flow = OAuthFlow(flow_state, mocker.Mock()) + actual_response = await flow.begin_or_continue_flow(None) + + assert actual_response == expected_response + OAuthFlow.begin_flow.assert_not_called() + OAuthFlow.continue_flow.assert_not_called() \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py b/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py index 595920ab..9a9142d3 100644 --- a/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py +++ b/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py @@ -18,60 +18,68 @@ class FLOW_STATES: + NOT_STARTED_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.NOT_STARTED, + attempts_remaining=1, + user_token="____", + expiration=datetime.now().timestamp() + 1000000 + ) + STARTED_FLOW = FlowState( **DEF_ARGS, tag=FlowStateTag.BEGIN, attempts_remaining=1, user_token="____", - expires_at=datetime.now().timestamp() + 1000000 + expiration=datetime.now().timestamp() + 1000000 ) STARTED_FLOW_ONE_RETRY = FlowState( **DEF_ARGS, tag=FlowStateTag.BEGIN, attempts_remaining=2, user_token="____", - expires_at=datetime.now().timestamp() + 1000000 + expiration=datetime.now().timestamp() + 1000000 ) ACTIVE_FLOW = FlowState( **DEF_ARGS, tag=FlowStateTag.CONTINUE, attempts_remaining=2, user_token="__token", - expires_at=datetime.now().timestamp() + 1000000 + expiration=datetime.now().timestamp() + 1000000 ) ACTIVE_FLOW_ONE_RETRY = FlowState( **DEF_ARGS, tag=FlowStateTag.CONTINUE, attempts_remaining=1, user_token="__token", - expires_at=datetime.now().timestamp() + 1000000 + expiration=datetime.now().timestamp() + 1000000 ) ACTIVE_EXP_FLOW = FlowState( **DEF_ARGS, tag=FlowStateTag.CONTINUE, attempts_remaining=2, user_token="__token", - expires_at=datetime.now().timestamp() + expiration=datetime.now().timestamp() ) COMPLETED_FLOW = FlowState( **DEF_ARGS, tag=FlowStateTag.COMPLETE, attempts_remaining=2, user_token="test_token", - expires_at=datetime.now().timestamp() + 1000000 + expiration=datetime.now().timestamp() + 1000000 ) FAIL_BY_ATTEMPTS_FLOW = FlowState( **DEF_ARGS, tag=FlowStateTag.FAILURE, attempts_remaining=0, - expires_at=datetime.now().timestamp() + 1000000 + expiration=datetime.now().timestamp() + 1000000 ) FAIL_BY_EXP_FLOW = FlowState( **DEF_ARGS, tag=FlowStateTag.FAILURE, attempts_remaining=2, - expires_at=0 + expiration=0 ) @staticmethod From 0b12dafd01ac7e32d0f2d9fbebfdc6400efd2bff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 10:42:43 -0700 Subject: [PATCH 20/32] Fixed expiration time handling --- .../agents/activity/token_response.py | 11 + .../hosting/core/app/oauth/authorization.py | 2 +- .../agents/hosting/core/oauth/oauth_flow.py | 4 +- .../tests/old_test_authorization.py | 213 ------------------ .../tests/test_authorization.py | 30 ++- .../tests/test_oauth_flow.py | 4 +- .../tests/tools/mock_user_token_client.py | 4 +- 7 files changed, 44 insertions(+), 224 deletions(-) delete mode 100644 libraries/microsoft-agents-hosting-core/tests/old_test_authorization.py diff --git a/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py b/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py index 682c534b..64444c92 100644 --- a/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py +++ b/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from dateutil import parser + from .agents_model import AgentsModel from ._type_aliases import NonEmptyString @@ -23,3 +25,12 @@ class TokenResponse(AgentsModel): token: NonEmptyString = None expiration: NonEmptyString = None channel_id: NonEmptyString = None + + def __bool__(self): + return bool(self.token) + + @property + def expiration_timestamp(self) -> float: + if not self.expiration: + return 0.0 + return parser.isoparse(self.expiration).timestamp() \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index 7195f6ba..0d1029e5 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -311,7 +311,7 @@ async def begin_or_continue_flow( logger.debug("Beginning or continuing OAuth flow") async with self.open_flow(context, auth_handler_id) as flow: - prev_tag = flow.flow_state.state_tag + prev_tag = flow.flow_state.tag flow_response: FlowResponse = await flow.begin_or_continue_flow(context.activity) flow_state: FlowState = flow_response.flow_state diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py index 45ad6239..f4ab9830 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py @@ -113,7 +113,7 @@ async def get_user_token(self, magic_code: str = None) -> TokenResponse: ) if token_response: self._flow_state.user_token = token_response.token - self._flow_state.expiration = token_response.expiration + self._flow_state.expiration = token_response.expiration_timestamp return token_response async def sign_out(self) -> None: @@ -244,7 +244,7 @@ async def continue_flow(self, activity: Activity) -> FlowResponse: self._use_attempt() else: self._flow_state.tag = FlowStateTag.COMPLETE - self._flow_state.expiration = token_response.expiration + self._flow_state.expiration = token_response.expiration_timestamp self._flow_state.user_token = token_response.token return FlowResponse( diff --git a/libraries/microsoft-agents-hosting-core/tests/old_test_authorization.py b/libraries/microsoft-agents-hosting-core/tests/old_test_authorization.py deleted file mode 100644 index e7ee8b69..00000000 --- a/libraries/microsoft-agents-hosting-core/tests/old_test_authorization.py +++ /dev/null @@ -1,213 +0,0 @@ -import pytest -from .tools.testing_authorization import ( - TestingAuthorization, - create_test_auth_handler, -) -from .tools.testing_utility import TestingUtility -import jwt -from unittest.mock import Mock, AsyncMock -from microsoft.agents.hosting.core import SignInState -from microsoft.agents.hosting.core.oauth_flow import FlowState - - -class TestAuthorization: - def setup_method(self): - self.turn_context = TestingUtility.create_empty_context() - - @pytest.mark.asyncio - async def test_get_token_single_handler(self): - """ - Test Authorization - get_token() with single Auth Handler - """ - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - } - ) - - token_res = await auth.get_token(self.turn_context) - auth_handler = auth.resolver_handler("auth-handler") - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" - - @pytest.mark.asyncio - async def test_get_token_multiple_handlers(self): - """ - Test Authorization - get_token() with multiple Auth Handlers - """ - auth_handlers = { - "auth-handler": create_test_auth_handler("test-auth-a"), - "auth-handler-obo": create_test_auth_handler("test-auth-b", obo=True), - "auth-handler-with-title": create_test_auth_handler( - "test-auth-c", title="test-title" - ), - "auth-handler-with-title-text": create_test_auth_handler( - "test-auth-d", title="test-title", text="test-text" - ), - } - auth = TestingAuthorization(auth_handlers=auth_handlers) - for id, auth_handler in auth_handlers.items(): - # test value propogation - token_res = await auth.get_token(self.turn_context, id) - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" - - @pytest.mark.asyncio - async def test_exchange_token_valid_token(self): - valid_token = jwt.encode({"aud": "api://botframework.test.api"}, "") - scopes = ["scope-a"] - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth", obo=True), - }, - token=valid_token, - ) - token_res = await auth.exchange_token(self.turn_context, scopes=scopes) - assert ( - token_res.token - == f"{auth.resolver_handler().obo_connection_name}-obo-token" - ) - - @pytest.mark.asyncio - async def test_exchange_token_invalid_token(self): - invalid_token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") - scopes = ["scope-a"] - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth"), - }, - token=invalid_token, - ) - token_res = await auth.exchange_token(self.turn_context, scopes=scopes) - assert token_res.token == invalid_token - - @pytest.mark.asyncio - async def test_get_flow_state_unavailable(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - } - ) - - assert auth.get_flow_state() == FlowState() - - @pytest.mark.asyncio - async def test_begin_or_continue_flow_not_started(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - }, - token=None, - ) - mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) - - token_res = await auth.begin_or_continue_flow( - self.turn_context, - mock_turn_state, - "auth-handler", - ) - # Test value propogation - auth_handler = auth.resolver_handler("auth-handler") - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" - - # Test function calls - auth_handler.flow._get_flow_state.assert_called_once() - auth_handler.flow.begin_flow.assert_called_once() - mock_turn_state.save.assert_called_once_with(self.turn_context) - mock_turn_state.set_value.assert_called_once_with( - auth.SIGN_IN_STATE_KEY, - SignInState( - continuation_activity=self.turn_context.activity, - handler_id="auth-handler", - ), - ) - - @pytest.mark.asyncio - async def test_begin_or_continue_flow_started(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - }, - token=None, - flow_started=True, - ) - mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) - token_res = await auth.begin_or_continue_flow( - self.turn_context, - mock_turn_state, - "auth-handler", - ) - - # Test value propogation - auth_handler = auth.resolver_handler("auth-handler") - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" - - # Test function calls - auth_handler.flow._get_flow_state.assert_called_once() - auth_handler.flow.continue_flow.assert_called_once() - mock_turn_state.save.assert_called_once_with(self.turn_context) - mock_turn_state.delete_value.assert_called_once_with(auth.SIGN_IN_STATE_KEY) - - @pytest.mark.asyncio - async def test_begin_or_continue_flow_started_sign_in_success(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - }, - token=None, - flow_started=True, - ) - mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) - auth.on_sign_in_success(AsyncMock()) - - token_res = await auth.begin_or_continue_flow( - self.turn_context, - mock_turn_state, - "auth-handler", - ) - - # Test value propogation - auth_handler = auth.resolver_handler("auth-handler") - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" - - # Test function calls - auth_handler.flow._get_flow_state.assert_called_once() - auth_handler.flow.continue_flow.assert_called_once() - mock_turn_state.save.assert_called_once_with(self.turn_context) - mock_turn_state.delete_value.assert_called_once_with(auth.SIGN_IN_STATE_KEY) - auth._sign_in_handler.assert_called_once_with( - self.turn_context, mock_turn_state, "auth-handler" - ) - - @pytest.mark.asyncio - async def test_begin_or_continue_flow_started_sign_in_failure(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - }, - token=None, - sign_in_failed=True, - ) - mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) - auth.on_sign_in_failure(AsyncMock()) - - token_res = await auth.begin_or_continue_flow( - self.turn_context, - mock_turn_state, - "auth-handler", - ) - - # Test value propogation - auth_handler = auth.resolver_handler("auth-handler") - assert not token_res - - # Test function calls - auth_handler.flow._get_flow_state.assert_called_once() - auth_handler.flow.continue_flow.assert_called_once() - mock_turn_state.save.assert_called_once_with(self.turn_context) - auth._sign_in_failed_handler.assert_called_once_with( - self.turn_context, mock_turn_state, "auth-handler" - ) diff --git a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py index 0bd5f91e..e820b40d 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py @@ -275,7 +275,7 @@ async def test_open_flow_success_modified_complete_flow( # verify actual_flow_state = await flow_storage_client.read(auth_handler_id) - expected_flow_state.expires_at = res_flow_state.expires_at # we won't check this for now + expected_flow_state.expiration = res_flow_state.expiration # we won't check this for now assert res_flow_state == expected_flow_state assert actual_flow_state == expected_flow_state @@ -310,7 +310,7 @@ async def test_open_flow_success_modified_failure( # verify actual_flow_state = await flow_storage_client.read(auth_handler_id) - expected_flow_state.expires_at = actual_flow_state.expires_at # we won't check this for now + expected_flow_state.expiration = actual_flow_state.expiration # we won't check this for now assert flow_response.flow_error_tag == FlowErrorTag.MAGIC_FORMAT assert res_flow_state == expected_flow_state @@ -344,7 +344,7 @@ async def test_open_flow_success_modified_signout( # verify actual_flow_state = await flow_storage_client.read(auth_handler_id) - expected_flow_state.expires_at = actual_flow_state.expires_at # we won't check this for now + expected_flow_state.expiration = actual_flow_state.expiration # we won't check this for now assert actual_flow_state == expected_flow_state @pytest.mark.asyncio @@ -463,6 +463,30 @@ def on_sign_in_failure(context, turn_state, auth_handler_id, err): assert context.dummy_val == "github" assert flow_response.token_response == TokenResponse(token="token") + @pytest.mark.asyncio + async def test_begin_or_continue_flow_already_completed( + self, + mocker, + auth + ): + # robrandao: TODO -> lower priority -> more testing here + # setup + context = self.create_context(mocker, "webchat", "Alice") + + context.dummy_val = None + def on_sign_in_success(context, turn_state, auth_handler_id): + context.dummy_val = auth_handler_id + def on_sign_in_failure(context, turn_state, auth_handler_id, err): + context.dummy_val = str(err) + + # test + auth.on_sign_in_success(on_sign_in_success) + auth.on_sign_in_failure(on_sign_in_failure) + flow_response = await auth.begin_or_continue_flow(context, None, "graph") + assert context.dummy_val == None + assert flow_response.token_response == TokenResponse(token="test_token") + assert flow_response.continuation_activity is None + @pytest.mark.asyncio async def test_begin_or_continue_flow_failure( self, diff --git a/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py index a970349c..d1176395 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py @@ -120,7 +120,7 @@ async def test_get_user_token_success(self, sample_flow_state, user_token_client flow = OAuthFlow(sample_flow_state, user_token_client) expected_final_flow_state = sample_flow_state expected_final_flow_state.user_token = RES_TOKEN - expected_final_flow_state.expiration = None + expected_final_flow_state.expiration = 0.0 # test token_response = await flow.get_user_token() @@ -182,7 +182,7 @@ async def test_begin_flow_easy_case(self, mocker, sample_flow_state, user_token_ activity = mocker.Mock(spec=Activity) expected_flow_state = sample_flow_state expected_flow_state.user_token = RES_TOKEN - expected_flow_state.expiration = None + expected_flow_state.expiration = 0.0 # test response = await flow.begin_flow(activity) diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py b/libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py index 0f9ef960..c7ef14eb 100644 --- a/libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py +++ b/libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py @@ -30,12 +30,10 @@ class MockUserTokenClient(UserTokenClient): """A mock user token client for testing.""" - def __init__(self, ...): + def __init__(self): self._store = {} self._exchange_store = {} self._throw_on_exchange = {} - self._user_token = mocker.Mock() - self._agent_sign_in = mocker.Mock() def add_user_token( self, From 90f85428ef740d3edcece95d836226b8da22a4f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 11:36:48 -0700 Subject: [PATCH 21/32] Changed magic_code keyword to code in tests --- .../tests/test_oauth_flow.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py index d1176395..05de428d 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py @@ -133,7 +133,7 @@ async def test_get_user_token_success(self, sample_flow_state, user_token_client user_id=USER_ID, connection_name=ABS_OAUTH_CONNECTION_NAME, channel_id=CHANNEL_ID, - # magic_code=None + code=None ) @pytest.mark.asyncio @@ -153,7 +153,7 @@ async def test_get_user_token_failure(self, mocker, sample_flow_state): user_id=USER_ID, connection_name=ABS_OAUTH_CONNECTION_NAME, channel_id=CHANNEL_ID, - # magic_code=None + code=None ) @pytest.mark.asyncio @@ -199,9 +199,7 @@ async def test_begin_flow_easy_case(self, mocker, sample_flow_state, user_token_ user_id=USER_ID, connection_name=ABS_OAUTH_CONNECTION_NAME, channel_id=CHANNEL_ID, - # magic_code=None is an implementation detail, and ideally - # shouldn't be part of the test - # magic_code=None + code=None ) @pytest.mark.asyncio @@ -315,7 +313,7 @@ async def test_continue_flow_active_message_magic_code_error(self, mocker, sampl user_id=sample_active_flow_state.user_id, connection_name=sample_active_flow_state.connection, channel_id=sample_active_flow_state.channel_id, - magic_code="123456" + code="123456" ) @pytest.mark.asyncio @@ -327,7 +325,7 @@ async def test_continue_flow_active_message_success(self, mocker, sample_active_ user_id=sample_active_flow_state.user_id, connection_name=sample_active_flow_state.connection, channel_id=sample_active_flow_state.channel_id, - magic_code="123456" + code="123456" ) @pytest.mark.asyncio @@ -342,7 +340,7 @@ async def test_continue_flow_active_sign_in_verify_state_error(self, mocker, sam user_id=sample_active_flow_state.user_id, connection_name=sample_active_flow_state.connection, channel_id=sample_active_flow_state.channel_id, - magic_code="magic_code" + code="magic_code" ) @pytest.mark.asyncio @@ -355,7 +353,7 @@ async def test_continue_flow_active_sign_in_verify_success(self, mocker, sample_ user_id=sample_active_flow_state.user_id, connection_name=sample_active_flow_state.connection, channel_id=sample_active_flow_state.channel_id, - magic_code="magic_code" + code="magic_code" ) @pytest.mark.asyncio From 91f2e61c20c1d162c39a0d3b106a2e953caa3bfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 14:11:41 -0700 Subject: [PATCH 22/32] Fixed expiration time issue --- .../hosting/core/app/agent_application.py | 49 ++++++++--------- .../agents/hosting/core/oauth/flow_state.py | 10 +++- .../agents/hosting/core/oauth/oauth_flow.py | 53 +++++++++++-------- .../tests/test_oauth_flow.py | 4 +- 4 files changed, 63 insertions(+), 53 deletions(-) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py index bcae6f63..152878cb 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py @@ -609,20 +609,6 @@ async def _handle_flow_response(self, context: TurnContext, flow_response: FlowR if flow_state.tag == FlowStateTag.BEGIN: # Create the OAuth card sign_in_resource = flow_response.sign_in_resource - # for auth_handler in self._auth_handlers.values(): - # # Create OAuth flow with configuration - # messages_config = {} - # if auth_handler.title: - # ["card_title"] = auth_handler.title - # if auth_handler.text: - # messages_config["button_text"] = auth_handler.text - - # logger.debug(f"Configuring OAuth flow for handler: {auth_handler.name}") - # auth_handler.flow = AuthFlow( - # storage=storage, - # abs_oauth_connection_name=auth_handler.abs_oauth_connection_name, - # messages_configuration=messages_config if messages_config else None, - handler = self._auth.resolve_handler(flow_state.auth_handler_id) o_card: Attachment = CardFactory.oauth_card( OAuthCard( text="Sign in", @@ -642,7 +628,7 @@ async def _handle_flow_response(self, context: TurnContext, flow_response: FlowR # Send the card to the user await context.send_activity(MessageFactory.attachment(o_card)) elif flow_state.tag == FlowStateTag.FAILURE: - if flow_state.reached_max_retries(): + if flow_state.reached_max_attempts(): await context.send_activity( MessageFactory.text("Sign-in failed. Max retries reached. Please try again later.") ) @@ -655,23 +641,31 @@ async def _handle_flow_response(self, context: TurnContext, flow_response: FlowR await context.send_activity("Sign-in failed. Please try again.") async def _on_turn_auth_intercept(self, context: TurnContext, turn_state: TurnState) -> bool: - - print("*"*5) + logger.debug("Checking for active sign-in flow for context: %s with activity type %s", context.activity.id, context.activity.type) prev_flow_state = await self._auth.get_active_flow_state(context) - print(prev_flow_state) - print("*"*5) - if self._auth and prev_flow_state: + if prev_flow_state: + logger.debug("Previous flow state: %s", { + "user_id": prev_flow_state.user_id, + "connection": prev_flow_state.connection, + "channel_id": prev_flow_state.channel_id, + "auth_handler_id": prev_flow_state.auth_handler_id, + "tag": prev_flow_state.tag, + "expiration": prev_flow_state.expiration, + }) + if (prev_flow_state and + (prev_flow_state.tag == FlowStateTag.NOT_STARTED or prev_flow_state.is_active()) + and context.activity.type in [ActivityTypes.message, ActivityTypes.invoke]): logger.debug("Sign-in flow is active for context: %s", context.activity.id) flow_response: FlowResponse = await self._auth.begin_or_continue_flow( - context, turn_state, prev_flow_state.handler_id + context, turn_state, prev_flow_state.auth_handler_id ) await self._handle_flow_response(context, flow_response) new_flow_state: FlowState = flow_response.flow_state - token_response: TokenResponse = new_flow_state.token_response + token_response: TokenResponse = flow_response.token_response saved_activity: Activity = new_flow_state.continuation_activity.model_copy() if token_response: @@ -682,9 +676,8 @@ async def _on_turn_auth_intercept(self, context: TurnContext, turn_state: TurnSt ) await self.on_turn(new_context) await turn_state.save(context) - return True - - return False + return True # early return from _on_turn + return False # continue _on_turn async def on_turn(self, context: TurnContext): logger.debug( @@ -702,7 +695,7 @@ async def _on_turn(self, context: TurnContext): logger.debug("Initializing turn state") turn_state = await self._initialize_state(context) - if await self._on_turn_auth_intercept(context, turn_state): + if self._auth and await self._on_turn_auth_intercept(context, turn_state): return logger.debug("Running before turn middleware") @@ -726,12 +719,10 @@ async def _on_turn(self, context: TurnContext): ) await self._on_error(context, err) finally: - logger.debug("Stopping typing indicator") self.typing.stop() async def _start_typing(self, context: TurnContext): if self._options.start_typing_timer: - logger.debug("Starting typing indicator for context") await self.typing.start(context) def _remove_mentions(self, context: TurnContext): @@ -824,10 +815,12 @@ async def _on_activity(self, context: TurnContext, state: StateT): else: sign_in_complete = False for auth_handler_id in route.auth_handlers: + logger.debug("Beginning or continuing flow for auth handler %s", auth_handler_id) flow_response: FlowResponse = await self._auth.begin_or_continue_flow( context, state, auth_handler_id ) await self._handle_flow_response(context, flow_response) + logger.debug("Flow response flow_state.tag: %s", flow_response.flow_state.tag) sign_in_complete = flow_response.flow_state.tag == FlowStateTag.COMPLETE if not sign_in_complete: break diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py index 01916b09..4533a18f 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py @@ -49,7 +49,15 @@ class FlowState(BaseModel, StoreItem): tag: FlowStateTag = FlowStateTag.NOT_STARTED def store_item_to_json(self) -> dict: - return self.model_dump() + data = self.model_dump() + if self.continuation_activity: + omit_if_empty = { + k + for k, v in self.continuation_activity + if isinstance(v, list) and not v + } + data["continuation_activity"] = {k: v for k, v in self.continuation_activity if k not in omit_if_empty and v is not None} + return data @staticmethod def from_json_to_store_item(json_data: dict) -> "FlowState": diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py index f4ab9830..38b18436 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py @@ -69,6 +69,8 @@ def __init__( not flow_state.user_id): raise ValueError("OAuthFlow.__init__: flow_state must have ms_app_id, channel_id, user_id, connection defined") + logger.debug("Initializing OAuthFlow with flow state: %s", flow_state) + self._flow_state = flow_state.model_copy() self._abs_oauth_connection_name = self._flow_state.connection @@ -78,9 +80,13 @@ def __init__( self._user_token_client = user_token_client - self._default_expires_in = kwargs.get("default_flow_duration", 60000) # default to 60 seconds + self._default_flow_duration = kwargs.get("default_flow_duration", 300000) # default to 300 seconds self._max_attempts = kwargs.get("max_attempts", 3) # defaults to 3 max attempts + logger.debug("OAuthFlow initialized with connection: %s, ms_app_id: %s, channel_id: %s, user_id: %s", + self._abs_oauth_connection_name, self._ms_app_id, self._channel_id, self._user_id) + logger.debug("Default flow duration: %d ms, Max attempts: %d", self._default_flow_duration, self._max_attempts) + @property def flow_state(self) -> FlowState: return self._flow_state.model_copy() @@ -98,22 +104,18 @@ async def get_user_token(self, magic_code: str = None) -> TokenResponse: Notes: flow_state.user_token is updated with the latest token. """ - if magic_code: - token_response: TokenResponse = await self._user_token_client.user_token.get_token( - user_id=self._user_id, - connection_name=self._abs_oauth_connection_name, - channel_id=self._channel_id, - magic_code=magic_code - ) - else: - token_response: TokenResponse = await self._user_token_client.user_token.get_token( - user_id=self._user_id, - connection_name=self._abs_oauth_connection_name, - channel_id=self._channel_id, - ) + logger.info("Getting user token for user_id: %s, connection: %s", self._user_id, self._abs_oauth_connection_name) + token_response: TokenResponse = await self._user_token_client.user_token.get_token( + user_id=self._user_id, + connection_name=self._abs_oauth_connection_name, + channel_id=self._channel_id, + code=magic_code + ) if token_response: + logger.info("User token obtained successfully: %s", token_response) self._flow_state.user_token = token_response.token - self._flow_state.expiration = token_response.expiration_timestamp + self._flow_state.expiration = datetime.now().timestamp() + self._default_flow_duration + return token_response async def sign_out(self) -> None: @@ -122,6 +124,7 @@ async def sign_out(self) -> None: Sets the flow state tag to NOT_STARTED Resets the flow state user_token field """ + logger.info("Signing out user_id: %s from connection: %s", self._user_id, self._abs_oauth_connection_name) await self._user_token_client.user_token.sign_out( user_id=self._user_id, connection_name=self._abs_oauth_connection_name, @@ -135,6 +138,7 @@ def _use_attempt(self) -> None: self._flow_state.attempts_remaining -= 1 if self._flow_state.attempts_remaining <= 0: self._flow_state.tag = FlowStateTag.FAILURE + logger.debug("Using an attempt for the OAuth flow. Attempts remaining after use: %d", self._flow_state.attempts_remaining) async def begin_flow(self, activity: Activity) -> FlowResponse: """Begins the OAuthFlow. @@ -148,9 +152,6 @@ async def begin_flow(self, activity: Activity) -> FlowResponse: Notes: The flow state is reset if a token is not obtained from cache. """ - - # init flow state - token_response = await self.get_user_token() if token_response: return FlowResponse( @@ -158,8 +159,10 @@ async def begin_flow(self, activity: Activity) -> FlowResponse: token_response=token_response ) + logger.debug("Starting new OAuth flow") self._flow_state.tag = FlowStateTag.BEGIN - self._flow_state.expiration = datetime.now().timestamp() + self._flow_state.expiration = datetime.now().timestamp() + self._default_flow_duration + self._flow_state.attempts_remaining = self._max_attempts self._flow_state.user_token = "" self._flow_state.continuation_activity = activity.model_copy() @@ -174,6 +177,8 @@ async def begin_flow(self, activity: Activity) -> FlowResponse: sign_in_resource = await self._user_token_client.agent_sign_in.get_sign_in_resource( state=token_exchange_state.get_encoded_state()) + logger.debug("Sign-in resource obtained successfully: %s", sign_in_resource) + return FlowResponse(flow_state=self._flow_state, sign_in_resource=sign_in_resource) async def _continue_from_message(self, activity: Activity) -> tuple[TokenResponse, FlowErrorTag]: @@ -223,7 +228,7 @@ async def continue_flow(self, activity: Activity) -> FlowResponse: logger.debug("OAuth flow is not active, cannot continue") self._flow_state.tag = FlowStateTag.FAILURE return FlowResponse(flow_state=self._flow_state.model_copy(), token_response=None) - + flow_error_tag = FlowErrorTag.NONE if activity.type == ActivityTypes.message: token_response, flow_error_tag = await self._continue_from_message(activity) @@ -234,18 +239,20 @@ async def continue_flow(self, activity: Activity) -> FlowResponse: and activity.name == "signin/tokenExchange"): token_response = await self._continue_from_invoke_token_exchange(activity) else: - raise ValueError("Unknown activity type") + raise ValueError(f"Unknown activity type {activity.type}") if not token_response and flow_error_tag == FlowErrorTag.NONE: flow_error_tag = FlowErrorTag.OTHER if flow_error_tag != FlowErrorTag.NONE: + logger.debug("Flow error occurred: %s", flow_error_tag) self._flow_state.tag = FlowStateTag.CONTINUE self._use_attempt() else: self._flow_state.tag = FlowStateTag.COMPLETE - self._flow_state.expiration = token_response.expiration_timestamp + self._flow_state.expiration = datetime.now().timestamp() + self._default_flow_duration self._flow_state.user_token = token_response.token + logger.debug("OAuth flow completed successfully, got TokenResponse: %s", token_response) return FlowResponse( flow_state=self._flow_state.model_copy(), @@ -269,6 +276,8 @@ async def begin_or_continue_flow(self, activity: Activity) -> FlowResponse: return FlowResponse(flow_state=self._flow_state.model_copy(), token_response=TokenResponse(token=self._flow_state.user_token)) if self._flow_state.is_active(): + logger.debug("Active flow, continuing...") return await self.continue_flow(activity) + logger.debug("No active flow, beginning new flow...") return await self.begin_flow(activity) \ No newline at end of file diff --git a/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py index 05de428d..85261597 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py @@ -120,7 +120,6 @@ async def test_get_user_token_success(self, sample_flow_state, user_token_client flow = OAuthFlow(sample_flow_state, user_token_client) expected_final_flow_state = sample_flow_state expected_final_flow_state.user_token = RES_TOKEN - expected_final_flow_state.expiration = 0.0 # test token_response = await flow.get_user_token() @@ -128,6 +127,7 @@ async def test_get_user_token_success(self, sample_flow_state, user_token_client # verify assert token == RES_TOKEN + expected_final_flow_state.expiration = flow.flow_state.expiration assert flow.flow_state == expected_final_flow_state user_token_client.user_token.get_token.assert_called_once_with( user_id=USER_ID, @@ -182,13 +182,13 @@ async def test_begin_flow_easy_case(self, mocker, sample_flow_state, user_token_ activity = mocker.Mock(spec=Activity) expected_flow_state = sample_flow_state expected_flow_state.user_token = RES_TOKEN - expected_flow_state.expiration = 0.0 # test response = await flow.begin_flow(activity) # verify flow_state = flow.flow_state + expected_flow_state.expiration = flow_state.expiration assert flow_state == expected_flow_state assert response.flow_state == flow_state From 6c3d5c2af60d8cade17a370d2dd3a5dd7d96abd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 14:13:21 -0700 Subject: [PATCH 23/32] Undid TokenResponse changes --- .../microsoft/agents/activity/token_response.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py b/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py index 64444c92..682c534b 100644 --- a/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py +++ b/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py @@ -1,8 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from dateutil import parser - from .agents_model import AgentsModel from ._type_aliases import NonEmptyString @@ -25,12 +23,3 @@ class TokenResponse(AgentsModel): token: NonEmptyString = None expiration: NonEmptyString = None channel_id: NonEmptyString = None - - def __bool__(self): - return bool(self.token) - - @property - def expiration_timestamp(self) -> float: - if not self.expiration: - return 0.0 - return parser.isoparse(self.expiration).timestamp() \ No newline at end of file From 3e473096a5c42afda47778c8f4b51a9b90ffd216 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 14:14:23 -0700 Subject: [PATCH 24/32] Updated logging --- .../hosting/core/app/oauth/authorization.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index 0d1029e5..1f9e3a80 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -76,10 +76,10 @@ def __init__( self._auth_handlers = auth_handlers or {} self._sign_in_success_handler: Optional[ Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] - ] = None + ] = lambda *args: None self._sign_in_failure_handler: Optional[ Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] - ] = None + ] = lambda *args: None self._cache = None if use_cache: @@ -134,11 +134,6 @@ async def _load_flow( flow_state: FlowState = await flow_storage_client.read(auth_handler_id) if not flow_state: - # breakpoint() - # print("\n"*3) - # print(channel_id, user_id, auth_handler_id, auth_handler.abs_oauth_connection_name, ms_app_id) - # print("\n"*3) - # breakpoint() logger.info("No existing flow state found, creating new flow state") flow_state = FlowState( channel_id=channel_id, @@ -166,6 +161,7 @@ async def open_flow(self, context: TurnContext, auth_handler_id: str = "") -> As if not yet present in storage. """ if not context: + logger.error("No context provided to open_flow") raise ValueError("context is required") flow, flow_storage_client = await self._load_flow(context, auth_handler_id) @@ -212,6 +208,7 @@ async def exchange_token( token_response = await flow.get_user_token() if token_response and self._is_exchangeable(token_response.token): + logger.debug("Token is exchangeable, performing OBO flow") return await self._handle_obo(token_response.token, scopes, auth_handler_id) return TokenResponse() @@ -244,7 +241,7 @@ def _is_exchangeable(self, token: str) -> bool: aud = payload.get("aud") return isinstance(aud, str) and aud.startswith("api://") except Exception: - logger.exception("Failed to decode token to check audience") + logger.error("Failed to decode token to check audience") return False async def _handle_obo( @@ -279,8 +276,8 @@ async def _handle_obo( async def get_active_flow_state(self, context: TurnContext) -> Optional[FlowState]: """Gets the first active flow state for the current context.""" + logger.debug("Getting active flow state") channel_id, user_id = self._ids_from_context(context) - # TODO -> single read flow_storage_client = FlowStorageClient(channel_id, user_id, self._storage) for auth_handler_id in self._auth_handlers.keys(): flow_state = await flow_storage_client.read(auth_handler_id) From 6127b157ec777facd3b043b04294266d09301ff6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 14:17:17 -0700 Subject: [PATCH 25/32] Adding some comments --- .../microsoft/agents/hosting/core/app/agent_application.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py index 152878cb..175e82e3 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py @@ -604,6 +604,7 @@ def turn_state_factory(self, func: Callable[[TurnContext], Awaitable[StateT]]): return func async def _handle_flow_response(self, context: TurnContext, flow_response: FlowResponse) -> None: + """Handles CONTINUE and FAILURE flow responses, sending activities back.""" flow_state: FlowState = flow_response.flow_state if flow_state.tag == FlowStateTag.BEGIN: @@ -641,6 +642,7 @@ async def _handle_flow_response(self, context: TurnContext, flow_response: FlowR await context.send_activity("Sign-in failed. Please try again.") async def _on_turn_auth_intercept(self, context: TurnContext, turn_state: TurnState) -> bool: + """Intercepts the turn to check for active authentication flows.""" logger.debug("Checking for active sign-in flow for context: %s with activity type %s", context.activity.id, context.activity.type) prev_flow_state = await self._auth.get_active_flow_state(context) if prev_flow_state: @@ -652,6 +654,9 @@ async def _on_turn_auth_intercept(self, context: TurnContext, turn_state: TurnSt "tag": prev_flow_state.tag, "expiration": prev_flow_state.expiration, }) + # proceed if there is an existing flow to continue + # new flows should be initiated in _on_activity + # this can be reorganized later... but it works for now if (prev_flow_state and (prev_flow_state.tag == FlowStateTag.NOT_STARTED or prev_flow_state.is_active()) and context.activity.type in [ActivityTypes.message, ActivityTypes.invoke]): From a21df6ab3862ab4b296d5c5ab6364b0a46068d55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 14:22:37 -0700 Subject: [PATCH 26/32] Removed unused cache field --- .../microsoft/agents/hosting/core/app/oauth/authorization.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index 1f9e3a80..669f0f00 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -81,10 +81,6 @@ def __init__( Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] ] = lambda *args: None - self._cache = None - if use_cache: - self._cache = MemoryStorage() - def _ids_from_context(self, context: TurnContext) -> tuple[str, str]: """Checks and returns IDs necessary to load a new or existing flow. From d051be0c0f24e57cc86971214fb8b51c06deb529 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 14:36:44 -0700 Subject: [PATCH 27/32] Added back the __bool__ op on TokenResponse --- .../microsoft/agents/activity/token_response.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py b/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py index 682c534b..c027f0a7 100644 --- a/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py +++ b/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py @@ -23,3 +23,6 @@ class TokenResponse(AgentsModel): token: NonEmptyString = None expiration: NonEmptyString = None channel_id: NonEmptyString = None + + def __bool__(self): + return bool(self.token) \ No newline at end of file From 3b243835d8764b23c7a6544ea63361d3cd2c1534 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 14:46:53 -0700 Subject: [PATCH 28/32] Revised some comments --- .../hosting/core/app/oauth/authorization.py | 32 +++++-------------- .../agents/hosting/core/oauth/flow_state.py | 10 +----- .../hosting/core/oauth/flow_storage_client.py | 11 ++++--- .../tests/test_oauth_flow.py | 2 +- 4 files changed, 16 insertions(+), 39 deletions(-) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index 669f0f00..3cf2ffc7 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -99,7 +99,7 @@ async def _load_flow( self, context: TurnContext, auth_handler_id: str = "" - ) -> tuple[OAuthFlow, FlowStorageClient, FlowState]: + ) -> tuple[OAuthFlow, FlowStorageClient]: """Loads the OAuth flow for a specific auth handler. Args: @@ -109,10 +109,7 @@ async def _load_flow( Returns: The OAuthFlow returned corresponds to the flow associated with the chosen handler, and the channel and user info found in the context. - - The FlowStorageClient corresponds to the channel and user info. - The FlowState returned is the flow state for the given channel/user/handler - triple at the time of reading from storage and before creating the flow. + The FlowStorageClient corresponds to the same channel and user info. """ user_token_client: UserTokenClient = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) @@ -150,6 +147,7 @@ async def open_flow(self, context: TurnContext, auth_handler_id: str = "") -> As Args: context: The context object for the current turn. auth_handler_id: ID of the auth handler to use. + If none provided, uses the first handler. Yields: OAuthFlow: @@ -173,7 +171,7 @@ async def get_token( Args: context: The context object for the current turn. - auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. + auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. Returns: The token response from the OAuth provider. @@ -209,18 +207,6 @@ async def exchange_token( return TokenResponse() - # auth_handler = self.resolver_handler(auth_handler_id) - # if not auth_handler.flow: - # logger.error("OAuth flow is not configured for the auth handler") - # raise ValueError("OAuth flow is not configured for the auth handler") - - # token_response = await auth_handler.flow.get_user_token(context) - - # if self._is_exchangeable(token_response.token if token_response else None): - # return await self._handle_obo(token_response.token, scopes, auth_handler_id) - - # return token_response - def _is_exchangeable(self, token: str) -> bool: """ Checks if a token is exchangeable (has api:// audience). @@ -287,12 +273,11 @@ async def begin_or_continue_flow( turn_state: TurnState, auth_handler_id: str = "", ) -> FlowResponse: - """ - Begins or continues an OAuth flow. + """Begins or continues an OAuth flow. Args: context: The context object for the current turn. - state: The state object for the current turn. + turn_state: The state object for the current turn. auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. Returns: @@ -319,8 +304,7 @@ async def begin_or_continue_flow( return flow_response def resolve_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: - """ - Resolves the auth handler to use based on the provided ID. + """Resolves the auth handler to use based on the provided ID. Args: auth_handler_id: Optional ID of the auth handler to resolve, defaults to first handler. @@ -346,7 +330,7 @@ async def _sign_out( Args: context: The context object for the current turn. - auth_handler_ids: List of auth handler IDs to sign out from. + auth_handler_ids: Iterable of auth handler IDs to sign out from. Deletes the associated flow states from storage. """ diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py index 4533a18f..012892aa 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py @@ -49,15 +49,7 @@ class FlowState(BaseModel, StoreItem): tag: FlowStateTag = FlowStateTag.NOT_STARTED def store_item_to_json(self) -> dict: - data = self.model_dump() - if self.continuation_activity: - omit_if_empty = { - k - for k, v in self.continuation_activity - if isinstance(v, list) and not v - } - data["continuation_activity"] = {k: v for k, v in self.continuation_activity if k not in omit_if_empty and v is not None} - return data + return self.model_dump(mode="json", exclude_unset=True, by_alias=True) @staticmethod def from_json_to_store_item(json_data: dict) -> "FlowState": diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py index 384e2802..74acc117 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py @@ -35,11 +35,12 @@ def __init__( cache_class: type[Storage] = None ): """ - Parameters - context: The TurnContext for the current conversation. Used to isolate - data across channels and users. This defines the prefix used to - access storage. - storage: The Storage instance used to persist flow state data. + Args: + channel_id: used to create the prefix + user_id: used to create the prefix + storage: the backing storage + cache_class: the cache class to use (defaults to DummyCache, which performs no caching). + This cache's lifetime is tied to the FlowStorageClient instance. """ if not user_id or not channel_id: diff --git a/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py index 85261597..f4cd727d 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py @@ -141,7 +141,7 @@ async def test_get_user_token_failure(self, mocker, sample_flow_state): # setup user_token_client = self.create_user_token_client(mocker, get_token_return=None) flow = OAuthFlow(sample_flow_state, user_token_client) - expected_final_flow_state = flow.flow_state # robrandao: TODO -> what happens if fails and has user_token? + expected_final_flow_state = flow.flow_state # test token_response = await flow.get_user_token() From 014ce86747cb5ee9e4207cc0779d93fcf246e72b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 14:49:53 -0700 Subject: [PATCH 29/32] Updated default flow duration to 10 minutes --- .../microsoft/agents/hosting/core/oauth/oauth_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py index 38b18436..42055883 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py @@ -80,7 +80,7 @@ def __init__( self._user_token_client = user_token_client - self._default_flow_duration = kwargs.get("default_flow_duration", 300000) # default to 300 seconds + self._default_flow_duration = kwargs.get("default_flow_duration", 10 * 60 * 1000) # default to 10 minutes self._max_attempts = kwargs.get("max_attempts", 3) # defaults to 3 max attempts logger.debug("OAuthFlow initialized with connection: %s, ms_app_id: %s, channel_id: %s, user_id: %s", From be7c8db97745df6f341895ce61c085c87b35c48b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 14:55:14 -0700 Subject: [PATCH 30/32] Changing default flow duration --- .../microsoft/agents/hosting/core/oauth/oauth_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py index 42055883..86016a87 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py @@ -80,7 +80,7 @@ def __init__( self._user_token_client = user_token_client - self._default_flow_duration = kwargs.get("default_flow_duration", 10 * 60 * 1000) # default to 10 minutes + self._default_flow_duration = kwargs.get("default_flow_duration", 10 * 60) # default to 10 minutes self._max_attempts = kwargs.get("max_attempts", 3) # defaults to 3 max attempts logger.debug("OAuthFlow initialized with connection: %s, ms_app_id: %s, channel_id: %s, user_id: %s", From 45ec2f455d8efebba65c87832d858cead61a7967 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 15:05:50 -0700 Subject: [PATCH 31/32] Reformatted with black --- .../agents/activity/token_response.py | 2 +- .../microsoft/agents/hosting/core/__init__.py | 6 +- .../hosting/core/app/agent_application.py | 82 +++-- .../agents/hosting/core/app/oauth/__init__.py | 9 +- .../hosting/core/app/oauth/auth_handler.py | 2 + .../hosting/core/app/oauth/authorization.py | 75 +++-- .../agents/hosting/core/oauth/__init__.py | 10 +- .../agents/hosting/core/oauth/flow_state.py | 26 +- .../hosting/core/oauth/flow_storage_client.py | 10 +- .../agents/hosting/core/oauth/oauth_flow.py | 194 +++++++---- .../core/storage/storage_test_utils.py | 2 + .../tests/test_authorization.py | 317 ++++++++++-------- .../tests/test_flow_state.py | 251 +++++++++++--- .../tests/test_flow_storage_client.py | 85 +++-- .../tests/test_oauth_flow.py | 300 +++++++++++------ .../tests/tools/testing_adapter.py | 1 + .../tests/tools/testing_authorization.py | 2 +- .../tests/tools/testing_oauth.py | 201 ++++++----- 18 files changed, 998 insertions(+), 577 deletions(-) diff --git a/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py b/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py index c027f0a7..00d6aa91 100644 --- a/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py +++ b/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py @@ -25,4 +25,4 @@ class TokenResponse(AgentsModel): channel_id: NonEmptyString = None def __bool__(self): - return bool(self.token) \ No newline at end of file + return bool(self.token) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py index 124de2cd..f5d07cef 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py @@ -42,14 +42,14 @@ from .authorization.jwt_token_validator import JwtTokenValidator from .authorization.auth_types import AuthTypes -#OAuth +# OAuth from .oauth import ( FlowState, FlowStateTag, FlowErrorTag, FlowResponse, FlowStorageClient, - OAuthFlow + OAuthFlow, ) # Client API @@ -167,5 +167,5 @@ "FlowErrorTag", "FlowResponse", "FlowStorageClient", - "OAuthFlow" + "OAuthFlow", ] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py index 175e82e3..6084a36c 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py @@ -37,7 +37,7 @@ TokenResponse, OAuthCard, Attachment, - CardAction + CardAction, ) from .. import CardFactory, MessageFactory @@ -602,11 +602,13 @@ def turn_state_factory(self, func: Callable[[TurnContext], Awaitable[StateT]]): logger.debug(f"Setting custom turn state factory: {func.__name__}") self._turn_state_factory = func return func - - async def _handle_flow_response(self, context: TurnContext, flow_response: FlowResponse) -> None: + + async def _handle_flow_response( + self, context: TurnContext, flow_response: FlowResponse + ) -> None: """Handles CONTINUE and FAILURE flow responses, sending activities back.""" flow_state: FlowState = flow_response.flow_state - + if flow_state.tag == FlowStateTag.BEGIN: # Create the OAuth card sign_in_resource = flow_response.sign_in_resource @@ -631,7 +633,9 @@ async def _handle_flow_response(self, context: TurnContext, flow_response: FlowR elif flow_state.tag == FlowStateTag.FAILURE: if flow_state.reached_max_attempts(): await context.send_activity( - MessageFactory.text("Sign-in failed. Max retries reached. Please try again later.") + MessageFactory.text( + "Sign-in failed. Max retries reached. Please try again later." + ) ) elif flow_state.is_expired(): await context.send_activity( @@ -641,25 +645,39 @@ async def _handle_flow_response(self, context: TurnContext, flow_response: FlowR logger.warning("Sign-in flow failed for unknown reasons.") await context.send_activity("Sign-in failed. Please try again.") - async def _on_turn_auth_intercept(self, context: TurnContext, turn_state: TurnState) -> bool: + async def _on_turn_auth_intercept( + self, context: TurnContext, turn_state: TurnState + ) -> bool: """Intercepts the turn to check for active authentication flows.""" - logger.debug("Checking for active sign-in flow for context: %s with activity type %s", context.activity.id, context.activity.type) + logger.debug( + "Checking for active sign-in flow for context: %s with activity type %s", + context.activity.id, + context.activity.type, + ) prev_flow_state = await self._auth.get_active_flow_state(context) if prev_flow_state: - logger.debug("Previous flow state: %s", { - "user_id": prev_flow_state.user_id, - "connection": prev_flow_state.connection, - "channel_id": prev_flow_state.channel_id, - "auth_handler_id": prev_flow_state.auth_handler_id, - "tag": prev_flow_state.tag, - "expiration": prev_flow_state.expiration, - }) + logger.debug( + "Previous flow state: %s", + { + "user_id": prev_flow_state.user_id, + "connection": prev_flow_state.connection, + "channel_id": prev_flow_state.channel_id, + "auth_handler_id": prev_flow_state.auth_handler_id, + "tag": prev_flow_state.tag, + "expiration": prev_flow_state.expiration, + }, + ) # proceed if there is an existing flow to continue # new flows should be initiated in _on_activity # this can be reorganized later... but it works for now - if (prev_flow_state and - (prev_flow_state.tag == FlowStateTag.NOT_STARTED or prev_flow_state.is_active()) - and context.activity.type in [ActivityTypes.message, ActivityTypes.invoke]): + if ( + prev_flow_state + and ( + prev_flow_state.tag == FlowStateTag.NOT_STARTED + or prev_flow_state.is_active() + ) + and context.activity.type in [ActivityTypes.message, ActivityTypes.invoke] + ): logger.debug("Sign-in flow is active for context: %s", context.activity.id) @@ -676,13 +694,11 @@ async def _on_turn_auth_intercept(self, context: TurnContext, turn_state: TurnSt if token_response: new_context = copy(context) new_context.activity = saved_activity - logger.info( - "Resending continuation activity %s", saved_activity.text - ) + logger.info("Resending continuation activity %s", saved_activity.text) await self.on_turn(new_context) await turn_state.save(context) - return True # early return from _on_turn - return False # continue _on_turn + return True # early return from _on_turn + return False # continue _on_turn async def on_turn(self, context: TurnContext): logger.debug( @@ -820,13 +836,23 @@ async def _on_activity(self, context: TurnContext, state: StateT): else: sign_in_complete = False for auth_handler_id in route.auth_handlers: - logger.debug("Beginning or continuing flow for auth handler %s", auth_handler_id) - flow_response: FlowResponse = await self._auth.begin_or_continue_flow( - context, state, auth_handler_id + logger.debug( + "Beginning or continuing flow for auth handler %s", + auth_handler_id, + ) + flow_response: FlowResponse = ( + await self._auth.begin_or_continue_flow( + context, state, auth_handler_id + ) ) await self._handle_flow_response(context, flow_response) - logger.debug("Flow response flow_state.tag: %s", flow_response.flow_state.tag) - sign_in_complete = flow_response.flow_state.tag == FlowStateTag.COMPLETE + logger.debug( + "Flow response flow_state.tag: %s", + flow_response.flow_state.tag, + ) + sign_in_complete = ( + flow_response.flow_state.tag == FlowStateTag.COMPLETE + ) if not sign_in_complete: break diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py index c964ae2f..7c962a43 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py @@ -1,10 +1,5 @@ -from .authorization import ( - Authorization -) -from .auth_handler import ( - AuthHandler, - AuthorizationHandlers -) +from .authorization import Authorization +from .auth_handler import AuthHandler, AuthorizationHandlers __all__ = [ "Authorization", diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py index b7afd9b1..ddde6e9a 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + class AuthHandler: """ Interface defining an authorization handler for OAuth flows. @@ -42,5 +43,6 @@ def __init__( f"AuthHandler initialized: name={self.name}, title={self.title}, text={self.text} abs_connection_name={self.abs_oauth_connection_name} obo_connection_name={self.obo_connection_name}" ) + # # Type alias for authorization handlers dictionary AuthorizationHandlers = Dict[str, AuthHandler] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index 3cf2ffc7..145ac895 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -17,18 +17,13 @@ from microsoft.agents.hosting.core.connector.client import UserTokenClient from ...turn_context import TurnContext -from ...oauth import ( - OAuthFlow, - FlowResponse, - FlowState, - FlowStateTag, - FlowStorageClient -) +from ...oauth import OAuthFlow, FlowResponse, FlowState, FlowStateTag, FlowStorageClient from ..state.turn_state import TurnState from .auth_handler import AuthHandler logger = logging.getLogger(__name__) + class Authorization: """ Class responsible for managing authorization and OAuth flows. @@ -87,19 +82,17 @@ def _ids_from_context(self, context: TurnContext) -> tuple[str, str]: Raises a ValueError if channel ID or user ID are missing. """ if ( - not context.activity.channel_id or - not context.activity.from_property or - not context.activity.from_property.id + not context.activity.channel_id + or not context.activity.from_property + or not context.activity.from_property.id ): raise ValueError("Channel ID and User ID are required") - + return context.activity.channel_id, context.activity.from_property.id async def _load_flow( - self, - context: TurnContext, - auth_handler_id: str = "" - ) -> tuple[OAuthFlow, FlowStorageClient]: + self, context: TurnContext, auth_handler_id: str = "" + ) -> tuple[OAuthFlow, FlowStorageClient]: """Loads the OAuth flow for a specific auth handler. Args: @@ -111,15 +104,19 @@ async def _load_flow( chosen handler, and the channel and user info found in the context. The FlowStorageClient corresponds to the same channel and user info. """ - user_token_client: UserTokenClient = context.turn_state.get(context.adapter.USER_TOKEN_CLIENT_KEY) - + user_token_client: UserTokenClient = context.turn_state.get( + context.adapter.USER_TOKEN_CLIENT_KEY + ) + # resolve handler id auth_handler: AuthHandler = self.resolve_handler(auth_handler_id) auth_handler_id = auth_handler.name channel_id, user_id = self._ids_from_context(context) - ms_app_id = context.turn_state.get(context.adapter.AGENT_IDENTITY_KEY).claims["aud"] + ms_app_id = context.turn_state.get(context.adapter.AGENT_IDENTITY_KEY).claims[ + "aud" + ] # try to load existing state flow_storage_client = FlowStorageClient(channel_id, user_id, self._storage) @@ -133,7 +130,7 @@ async def _load_flow( user_id=user_id, auth_handler_id=auth_handler_id, connection=auth_handler.abs_oauth_connection_name, - ms_app_id=ms_app_id + ms_app_id=ms_app_id, ) await flow_storage_client.write(flow_state) @@ -141,7 +138,9 @@ async def _load_flow( return flow, flow_storage_client @asynccontextmanager - async def open_flow(self, context: TurnContext, auth_handler_id: str = "") -> AsyncIterator[OAuthFlow]: + async def open_flow( + self, context: TurnContext, auth_handler_id: str = "" + ) -> AsyncIterator[OAuthFlow]: """Loads an OAuth flow and saves changes the changes to storage if any are made. Args: @@ -242,10 +241,10 @@ async def _handle_obo( """ auth_handler = self.resolve_handler(handler_id) - token_provider: AccessTokenProviderBase = self._connection_manager.get_connection( - auth_handler.obo_connection_name + token_provider: AccessTokenProviderBase = ( + self._connection_manager.get_connection(auth_handler.obo_connection_name) ) - + logger.info("Attempting to exchange token on behalf of user") new_token = await token_provider.aquire_token_on_behalf_of( scopes=scopes, @@ -282,7 +281,7 @@ async def begin_or_continue_flow( Returns: The token response from the OAuth provider. - + """ if not auth_handler_id: auth_handler_id = self.resolve_handler().name @@ -290,16 +289,28 @@ async def begin_or_continue_flow( logger.debug("Beginning or continuing OAuth flow") async with self.open_flow(context, auth_handler_id) as flow: prev_tag = flow.flow_state.tag - flow_response: FlowResponse = await flow.begin_or_continue_flow(context.activity) - + flow_response: FlowResponse = await flow.begin_or_continue_flow( + context.activity + ) + flow_state: FlowState = flow_response.flow_state - if flow_state.tag == FlowStateTag.COMPLETE and prev_tag != FlowStateTag.COMPLETE: + if ( + flow_state.tag == FlowStateTag.COMPLETE + and prev_tag != FlowStateTag.COMPLETE + ): logger.debug("Calling Authorization sign in success handler") - self._sign_in_success_handler(context, turn_state, flow_state.auth_handler_id) + self._sign_in_success_handler( + context, turn_state, flow_state.auth_handler_id + ) elif flow_state.tag == FlowStateTag.FAILURE: logger.debug("Calling Authorization sign in failure handler") - self._sign_in_failure_handler(context, turn_state, flow_state.auth_handler_id, flow_response.flow_error_tag) + self._sign_in_failure_handler( + context, + turn_state, + flow_state.auth_handler_id, + flow_response.flow_error_tag, + ) return flow_response @@ -320,14 +331,14 @@ def resolve_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: # Return the first handler if no ID specified return next(iter(self._auth_handlers.values())) - + async def _sign_out( self, context: TurnContext, auth_handler_ids: Iterable[str], ) -> None: """Signs out from the specified auth handlers. - + Args: context: The context object for the current turn. auth_handler_ids: Iterable of auth handler IDs to sign out from. @@ -384,4 +395,4 @@ def on_sign_in_failure( Args: handler: The handler function to call on sign-in failure. """ - self._sign_in_failure_handler = handler \ No newline at end of file + self._sign_in_failure_handler = handler diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py index c9730300..79858343 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py @@ -1,8 +1,4 @@ -from .flow_state import ( - FlowState, - FlowStateTag, - FlowErrorTag -) +from .flow_state import FlowState, FlowStateTag, FlowErrorTag from .flow_storage_client import FlowStorageClient from .oauth_flow import OAuthFlow, FlowResponse @@ -12,5 +8,5 @@ "FlowErrorTag", "FlowResponse", "FlowStorageClient", - "OAuthFlow" -] \ No newline at end of file + "OAuthFlow", +] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py index 012892aa..1ac105df 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py @@ -11,9 +11,10 @@ from ..storage import StoreItem + class FlowStateTag(Enum): """Represents the top-level state of an OAuthFlow - + For instance, a flow can arrive at an error, but its broader state may still be CONTINUE if the flow can still progress @@ -25,18 +26,21 @@ class FlowStateTag(Enum): FAILURE = "failure" COMPLETE = "complete" + class FlowErrorTag(Enum): """Represents the various error states that can occur during an OAuthFlow""" + NONE = "none" MAGIC_FORMAT = "magic_format" MAGIC_CODE_INCORRECT = "magic_code_incorrect" OTHER = "other" + class FlowState(BaseModel, StoreItem): """Represents the state of an OAuthFlow""" user_token: str = "" - + channel_id: str = "" user_id: str = "" ms_app_id: str = "" @@ -60,10 +64,18 @@ def is_expired(self) -> bool: def reached_max_attempts(self) -> bool: return self.attempts_remaining <= 0 - + def is_active(self) -> bool: - return not self.is_expired() and not self.reached_max_attempts() and self.tag in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE] - + return ( + not self.is_expired() + and not self.reached_max_attempts() + and self.tag in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE] + ) + def refresh(self): - if self.tag in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE, FlowStateTag.COMPLETE] and self.is_expired(): - self.tag = FlowStateTag.NOT_STARTED \ No newline at end of file + if ( + self.tag + in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE, FlowStateTag.COMPLETE] + and self.is_expired() + ): + self.tag = FlowStateTag.NOT_STARTED diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py index 74acc117..7ab03879 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py @@ -6,6 +6,7 @@ from ..storage import Storage from .flow_state import FlowState + class DummyCache(Storage): async def read(self, keys: list[str], **kwargs) -> dict[str, FlowState]: @@ -17,6 +18,7 @@ async def write(self, changes: dict[str, FlowState]) -> None: async def delete(self, keys: list[str]) -> None: pass + # this could be generalized. Ideas: # - CachedStorage class for two-tier storage # - Namespaced/PrefixedStorage class for namespacing keying @@ -32,7 +34,7 @@ def __init__( channel_id: str, user_id: str, storage: Storage, - cache_class: type[Storage] = None + cache_class: type[Storage] = None, ): """ Args: @@ -44,7 +46,9 @@ def __init__( """ if not user_id or not channel_id: - raise ValueError("FlowStorageClient.__init__(): channel_id and user_id must be set.") + raise ValueError( + "FlowStorageClient.__init__(): channel_id and user_id must be set." + ) self._base_key = f"auth/{channel_id}/{user_id}/" self._storage = storage @@ -86,4 +90,4 @@ async def delete(self, auth_handler_id: str) -> None: cached_state = await self._cache.read([key], target_cls=FlowState) if cached_state: await self._cache.delete([key]) - await self._storage.delete([key]) \ No newline at end of file + await self._storage.delete([key]) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py index 86016a87..b5f0d5bb 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py @@ -14,7 +14,7 @@ ActivityTypes, TokenExchangeState, TokenResponse, - SignInResource + SignInResource, ) from ..connector.client import UserTokenClient @@ -22,14 +22,17 @@ logger = logging.getLogger(__name__) + class FlowResponse(BaseModel): """Represents the response for a flow operation.""" + flow_state: FlowState = FlowState() flow_error_tag: FlowErrorTag = FlowErrorTag.NONE token_response: Optional[TokenResponse] = None sign_in_resource: Optional[SignInResource] = None continuation_activity: Optional[Activity] = None + class OAuthFlow: """ Manages the OAuth flow. @@ -45,10 +48,7 @@ class OAuthFlow: """ def __init__( - self, - flow_state: FlowState, - user_token_client: UserTokenClient, - **kwargs + self, flow_state: FlowState, user_token_client: UserTokenClient, **kwargs ): """ Arguments: @@ -61,18 +61,24 @@ def __init__( set when starting a flow (default: 3). """ if not flow_state or not user_token_client: - raise ValueError("OAuthFlow.__init__(): flow_state and user_token_client are required") - - if (not flow_state.connection or - not flow_state.ms_app_id or - not flow_state.channel_id or - not flow_state.user_id): - raise ValueError("OAuthFlow.__init__: flow_state must have ms_app_id, channel_id, user_id, connection defined") - + raise ValueError( + "OAuthFlow.__init__(): flow_state and user_token_client are required" + ) + + if ( + not flow_state.connection + or not flow_state.ms_app_id + or not flow_state.channel_id + or not flow_state.user_id + ): + raise ValueError( + "OAuthFlow.__init__: flow_state must have ms_app_id, channel_id, user_id, connection defined" + ) + logger.debug("Initializing OAuthFlow with flow state: %s", flow_state) - + self._flow_state = flow_state.model_copy() - + self._abs_oauth_connection_name = self._flow_state.connection self._ms_app_id = self._flow_state.ms_app_id self._channel_id = self._flow_state.channel_id @@ -80,20 +86,31 @@ def __init__( self._user_token_client = user_token_client - self._default_flow_duration = kwargs.get("default_flow_duration", 10 * 60) # default to 10 minutes - self._max_attempts = kwargs.get("max_attempts", 3) # defaults to 3 max attempts - - logger.debug("OAuthFlow initialized with connection: %s, ms_app_id: %s, channel_id: %s, user_id: %s", - self._abs_oauth_connection_name, self._ms_app_id, self._channel_id, self._user_id) - logger.debug("Default flow duration: %d ms, Max attempts: %d", self._default_flow_duration, self._max_attempts) + self._default_flow_duration = kwargs.get( + "default_flow_duration", 10 * 60 + ) # default to 10 minutes + self._max_attempts = kwargs.get("max_attempts", 3) # defaults to 3 max attempts + + logger.debug( + "OAuthFlow initialized with connection: %s, ms_app_id: %s, channel_id: %s, user_id: %s", + self._abs_oauth_connection_name, + self._ms_app_id, + self._channel_id, + self._user_id, + ) + logger.debug( + "Default flow duration: %d ms, Max attempts: %d", + self._default_flow_duration, + self._max_attempts, + ) @property def flow_state(self) -> FlowState: return self._flow_state.model_copy() - + async def get_user_token(self, magic_code: str = None) -> TokenResponse: """Get the user token based on the context. - + Args: magic_code (str, optional): Defaults to None. The magic code for user authentication. @@ -104,42 +121,57 @@ async def get_user_token(self, magic_code: str = None) -> TokenResponse: Notes: flow_state.user_token is updated with the latest token. """ - logger.info("Getting user token for user_id: %s, connection: %s", self._user_id, self._abs_oauth_connection_name) - token_response: TokenResponse = await self._user_token_client.user_token.get_token( - user_id=self._user_id, - connection_name=self._abs_oauth_connection_name, - channel_id=self._channel_id, - code=magic_code + logger.info( + "Getting user token for user_id: %s, connection: %s", + self._user_id, + self._abs_oauth_connection_name, + ) + token_response: TokenResponse = ( + await self._user_token_client.user_token.get_token( + user_id=self._user_id, + connection_name=self._abs_oauth_connection_name, + channel_id=self._channel_id, + code=magic_code, + ) ) if token_response: logger.info("User token obtained successfully: %s", token_response) self._flow_state.user_token = token_response.token - self._flow_state.expiration = datetime.now().timestamp() + self._default_flow_duration + self._flow_state.expiration = ( + datetime.now().timestamp() + self._default_flow_duration + ) return token_response - + async def sign_out(self) -> None: """Sign out the user. - + Sets the flow state tag to NOT_STARTED Resets the flow state user_token field """ - logger.info("Signing out user_id: %s from connection: %s", self._user_id, self._abs_oauth_connection_name) + logger.info( + "Signing out user_id: %s from connection: %s", + self._user_id, + self._abs_oauth_connection_name, + ) await self._user_token_client.user_token.sign_out( user_id=self._user_id, connection_name=self._abs_oauth_connection_name, - channel_id=self._channel_id + channel_id=self._channel_id, ) self._flow_state.user_token = "" self._flow_state.tag = FlowStateTag.NOT_STARTED - + def _use_attempt(self) -> None: """Decrements the remaining attempts for the flow, checking for failure.""" self._flow_state.attempts_remaining -= 1 if self._flow_state.attempts_remaining <= 0: self._flow_state.tag = FlowStateTag.FAILURE - logger.debug("Using an attempt for the OAuth flow. Attempts remaining after use: %d", self._flow_state.attempts_remaining) - + logger.debug( + "Using an attempt for the OAuth flow. Attempts remaining after use: %d", + self._flow_state.attempts_remaining, + ) + async def begin_flow(self, activity: Activity) -> FlowResponse: """Begins the OAuthFlow. @@ -155,13 +187,14 @@ async def begin_flow(self, activity: Activity) -> FlowResponse: token_response = await self.get_user_token() if token_response: return FlowResponse( - flow_state=self._flow_state, - token_response=token_response + flow_state=self._flow_state, token_response=token_response ) - + logger.debug("Starting new OAuth flow") self._flow_state.tag = FlowStateTag.BEGIN - self._flow_state.expiration = datetime.now().timestamp() + self._default_flow_duration + self._flow_state.expiration = ( + datetime.now().timestamp() + self._default_flow_duration + ) self._flow_state.attempts_remaining = self._max_attempts self._flow_state.user_token = "" @@ -171,17 +204,24 @@ async def begin_flow(self, activity: Activity) -> FlowResponse: connection_name=self._abs_oauth_connection_name, conversation=activity.get_conversation_reference(), relates_to=activity.relates_to, - ms_app_id=self._ms_app_id + ms_app_id=self._ms_app_id, ) - sign_in_resource = await self._user_token_client.agent_sign_in.get_sign_in_resource( - state=token_exchange_state.get_encoded_state()) + sign_in_resource = ( + await self._user_token_client.agent_sign_in.get_sign_in_resource( + state=token_exchange_state.get_encoded_state() + ) + ) logger.debug("Sign-in resource obtained successfully: %s", sign_in_resource) - return FlowResponse(flow_state=self._flow_state, sign_in_resource=sign_in_resource) - - async def _continue_from_message(self, activity: Activity) -> tuple[TokenResponse, FlowErrorTag]: + return FlowResponse( + flow_state=self._flow_state, sign_in_resource=sign_in_resource + ) + + async def _continue_from_message( + self, activity: Activity + ) -> tuple[TokenResponse, FlowErrorTag]: """Handles the continuation of the flow from a message activity.""" magic_code: str = activity.text if magic_code and magic_code.isdigit() and len(magic_code) == 6: @@ -193,28 +233,32 @@ async def _continue_from_message(self, activity: Activity) -> tuple[TokenRespons return token_response, FlowErrorTag.MAGIC_CODE_INCORRECT else: return TokenResponse(), FlowErrorTag.MAGIC_FORMAT - - async def _continue_from_invoke_verify_state(self, activity: Activity) -> TokenResponse: + + async def _continue_from_invoke_verify_state( + self, activity: Activity + ) -> TokenResponse: """Handles the continuation of the flow from an invoke activity for verifying state.""" token_verify_state = activity.value magic_code: str = token_verify_state.get("state") token_response: TokenResponse = await self.get_user_token(magic_code) return token_response - - async def _continue_from_invoke_token_exchange(self, activity: Activity) -> TokenResponse: + + async def _continue_from_invoke_token_exchange( + self, activity: Activity + ) -> TokenResponse: """Handles the continuation of the flow from an invoke activity for token exchange.""" token_exchange_request = activity.value token_response = await self._user_token_client.user_token.exchange_token( user_id=self._user_id, connection_name=self._abs_oauth_connection_name, channel_id=self._channel_id, - body=token_exchange_request + body=token_exchange_request, ) return token_response async def continue_flow(self, activity: Activity) -> FlowResponse: """Continues the OAuth flow based on the incoming activity. - + Args: activity: The incoming activity to continue the flow with. @@ -223,20 +267,26 @@ async def continue_flow(self, activity: Activity) -> FlowResponse: """ logger.debug("Continuing auth flow...") - + if not self._flow_state.is_active(): logger.debug("OAuth flow is not active, cannot continue") self._flow_state.tag = FlowStateTag.FAILURE - return FlowResponse(flow_state=self._flow_state.model_copy(), token_response=None) - + return FlowResponse( + flow_state=self._flow_state.model_copy(), token_response=None + ) + flow_error_tag = FlowErrorTag.NONE if activity.type == ActivityTypes.message: token_response, flow_error_tag = await self._continue_from_message(activity) - elif (activity.type == ActivityTypes.invoke - and activity.name == "signin/verifyState"): + elif ( + activity.type == ActivityTypes.invoke + and activity.name == "signin/verifyState" + ): token_response = await self._continue_from_invoke_verify_state(activity) - elif (activity.type == ActivityTypes.invoke - and activity.name == "signin/tokenExchange"): + elif ( + activity.type == ActivityTypes.invoke + and activity.name == "signin/tokenExchange" + ): token_response = await self._continue_from_invoke_token_exchange(activity) else: raise ValueError(f"Unknown activity type {activity.type}") @@ -250,20 +300,25 @@ async def continue_flow(self, activity: Activity) -> FlowResponse: self._use_attempt() else: self._flow_state.tag = FlowStateTag.COMPLETE - self._flow_state.expiration = datetime.now().timestamp() + self._default_flow_duration + self._flow_state.expiration = ( + datetime.now().timestamp() + self._default_flow_duration + ) self._flow_state.user_token = token_response.token - logger.debug("OAuth flow completed successfully, got TokenResponse: %s", token_response) + logger.debug( + "OAuth flow completed successfully, got TokenResponse: %s", + token_response, + ) return FlowResponse( flow_state=self._flow_state.model_copy(), flow_error_tag=flow_error_tag, token_response=token_response, - continuation_activity=self._flow_state.continuation_activity + continuation_activity=self._flow_state.continuation_activity, ) async def begin_or_continue_flow(self, activity: Activity) -> FlowResponse: """Begins a new OAuth flow or continues an existing one based on the activity. - + Args: activity: The incoming activity to begin or continue the flow with. @@ -271,13 +326,16 @@ async def begin_or_continue_flow(self, activity: Activity) -> FlowResponse: A FlowResponse object containing the updated flow state and any token response. """ self._flow_state.refresh() - if self._flow_state.tag == FlowStateTag.COMPLETE: # robrandao: TODO -> test + if self._flow_state.tag == FlowStateTag.COMPLETE: # robrandao: TODO -> test logger.debug("OAuth flow has already been completed, nothing to do") - return FlowResponse(flow_state=self._flow_state.model_copy(), token_response=TokenResponse(token=self._flow_state.user_token)) - + return FlowResponse( + flow_state=self._flow_state.model_copy(), + token_response=TokenResponse(token=self._flow_state.user_token), + ) + if self._flow_state.is_active(): logger.debug("Active flow, continuing...") return await self.continue_flow(activity) logger.debug("No active flow, beginning new flow...") - return await self.begin_flow(activity) \ No newline at end of file + return await self.begin_flow(activity) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage_test_utils.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage_test_utils.py index e095cbd6..280f63e7 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage_test_utils.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage_test_utils.py @@ -129,6 +129,7 @@ async def equals(self, other) -> bool: for key in self._key_history: if key not in self._memory: if len(await other.read([key], target_cls=MockStoreItem)) > 0: + breakpoint() return False # key should not exist in other continue @@ -138,6 +139,7 @@ async def equals(self, other) -> bool: res = await other.read([key], target_cls=target_cls) if key not in res or item != res[key]: + breakpoint() return False return True diff --git a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py index e820b40d..d80f2a9a 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py @@ -2,14 +2,13 @@ import jwt -from microsoft.agents.activity import ( - ActivityTypes, - TokenResponse -) +from microsoft.agents.activity import ActivityTypes, TokenResponse from microsoft.agents.hosting.core import MemoryStorage from microsoft.agents.hosting.core.storage.storage_test_utils import StorageBaseline from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase -from microsoft.agents.hosting.core.connector.user_token_client_base import UserTokenClientBase +from microsoft.agents.hosting.core.connector.user_token_client_base import ( + UserTokenClientBase, +) from microsoft.agents.hosting.core.app.oauth import Authorization from microsoft.agents.hosting.core.oauth import ( @@ -17,24 +16,27 @@ FlowErrorTag, FlowStateTag, FlowResponse, - OAuthFlow + OAuthFlow, ) # test constants from .tools.testing_oauth import * from .tools.testing_authorization import ( TestingConnectionManager as MockConnectionManager, - create_test_auth_handler + create_test_auth_handler, ) + class TestUtils: - def create_context(self, - mocker, - channel_id="__channel_id", - user_id="__user_id", - user_token_client=None): - + def create_context( + self, + mocker, + channel_id="__channel_id", + user_id="__user_id", + user_token_client=None, + ): + if not user_token_client: user_token_client = self.create_mock_user_token_client(mocker) @@ -51,7 +53,7 @@ def create_context(self, "__agent_identity_key": agent_identity, } return turn_context - + def create_mock_user_token_client( self, mocker, @@ -64,11 +66,11 @@ def create_mock_user_token_client( ) mock_user_token_client_class.user_token.sign_out = mocker.AsyncMock() return mock_user_token_client_class - + @pytest.fixture def mock_user_token_client_class(self, mocker): return self.create_mock_user_token_client(mocker) - + def create_mock_oauth_flow_class(self, mocker, token_response): mock_oauth_flow_class = mocker.Mock(spec=OAuthFlow) # mock_oauth_flow_class.get_user_token = mocker.AsyncMock(return_value=token_response) @@ -86,9 +88,9 @@ def mock_oauth_flow_class(self, mocker): # mock_flow_class.get_user_token = mocker.AsyncMock(return_value=TokenResponse(token=RES_TOKEN)) # mock_flow_class.sign_out = mocker.AsyncMock() # mocker.patch.object(OAuthFlow, "get_user_token") - + # return mock_flow_class - + @pytest.fixture def turn_context(self, mocker): return self.create_context(mocker, "__channel_id", "__user_id", "__connection") @@ -99,7 +101,7 @@ def create_user_token_client(self, mocker, get_token_return=""): user_token_client.user_token = mocker.Mock(spec=UserTokenBase) user_token_client.user_token.get_token = mocker.AsyncMock() user_token_client.user_token.sign_out = mocker.AsyncMock() - + return_value = TokenResponse() if isinstance(get_token_return, TokenResponse): return_value = get_token_return @@ -108,20 +110,22 @@ def create_user_token_client(self, mocker, get_token_return=""): user_token_client.user_token.get_token.return_value = return_value return user_token_client - + @pytest.fixture def user_token_client(self, mocker): return self.create_user_token_client(mocker, get_token_return=RES_TOKEN) - + @pytest.fixture def auth_handlers(self): handlers = {} for key in STORAGE_INIT_DATA().keys(): if key.startswith("auth/"): - auth_handler_name = key[key.rindex("/")+1:] - handlers[auth_handler_name] = create_test_auth_handler(auth_handler_name, True) + auth_handler_name = key[key.rindex("/") + 1 :] + handlers[auth_handler_name] = create_test_auth_handler( + auth_handler_name, True + ) return handlers - + @pytest.fixture def connection_manager(self): return MockConnectionManager() @@ -130,27 +134,40 @@ def connection_manager(self): def auth(self, connection_manager, storage, auth_handlers): return Authorization(storage, connection_manager, auth_handlers) + class TestAuthorizationUtils(TestUtils): def create_storage(self): return MemoryStorage(STORAGE_INIT_DATA()) - + @pytest.fixture def storage(self): return self.create_storage() - + @pytest.fixture def baseline_storage(self): return StorageBaseline(STORAGE_INIT_DATA()) - - def patch_flow(self, mocker, flow_response=None, token=None,): - mocker.patch.object(OAuthFlow, "get_user_token", return_value=TokenResponse(token=token)) + + def patch_flow( + self, + mocker, + flow_response=None, + token=None, + ): + mocker.patch.object( + OAuthFlow, "get_user_token", return_value=TokenResponse(token=token) + ) mocker.patch.object(OAuthFlow, "sign_out") - mocker.patch.object(OAuthFlow, "begin_or_continue_flow", return_value=flow_response) + mocker.patch.object( + OAuthFlow, "begin_or_continue_flow", return_value=flow_response + ) + class TestAuthorization(TestAuthorizationUtils): - def test_init_configuration_variants(self,storage, connection_manager, auth_handlers): + def test_init_configuration_variants( + self, storage, connection_manager, auth_handlers + ): """Test initialization of authorization with different configuration variants.""" AGENTAPPLICATION = { "USERAUTHORIZATION": { @@ -160,9 +177,10 @@ def test_init_configuration_variants(self,storage, connection_manager, auth_hand "title": handler.title, "text": handler.text, "abs_oauth_connection_name": handler.abs_oauth_connection_name, - "obo_connection_name": handler.obo_connection_name + "obo_connection_name": handler.obo_connection_name, } - } for handler_name, handler in auth_handlers.items() + } + for handler_name, handler in auth_handlers.items() } } } @@ -170,36 +188,33 @@ def test_init_configuration_variants(self,storage, connection_manager, auth_hand storage, connection_manager, auth_handlers=None, - AGENTAPPLICATION=AGENTAPPLICATION + AGENTAPPLICATION=AGENTAPPLICATION, ) auth_with_handlers_list = Authorization( - storage, - connection_manager, - auth_handlers=auth_handlers + storage, connection_manager, auth_handlers=auth_handlers ) for auth_handler_name in auth_handlers.keys(): auth_handler_a = auth_with_config_obj.resolve_handler(auth_handler_name) auth_handler_b = auth_with_handlers_list.resolve_handler(auth_handler_name) - + assert auth_handler_a.name == auth_handler_b.name assert auth_handler_a.title == auth_handler_b.title assert auth_handler_a.text == auth_handler_b.text - assert auth_handler_a.abs_oauth_connection_name == auth_handler_b.abs_oauth_connection_name - assert auth_handler_a.obo_connection_name == auth_handler_b.obo_connection_name + assert ( + auth_handler_a.abs_oauth_connection_name + == auth_handler_b.abs_oauth_connection_name + ) + assert ( + auth_handler_a.obo_connection_name == auth_handler_b.obo_connection_name + ) @pytest.mark.asyncio - @pytest.mark.parametrize("auth_handler_id, channel_id, user_id", - [ - ["missing", "webchat", "Alice"], - ["handler", "teams", "Bob"] - ]) + @pytest.mark.parametrize( + "auth_handler_id, channel_id, user_id", + [["missing", "webchat", "Alice"], ["handler", "teams", "Bob"]], + ) async def test_open_flow_value_error( - self, - mocker, - auth, - auth_handler_id, - channel_id, - user_id + self, mocker, auth, auth_handler_id, channel_id, user_id ): """Test opening a flow with a missing auth handler.""" context = self.create_context(mocker, channel_id, user_id) @@ -208,12 +223,14 @@ async def test_open_flow_value_error( pass @pytest.mark.asyncio - @pytest.mark.parametrize("auth_handler_id, channel_id, user_id", + @pytest.mark.parametrize( + "auth_handler_id, channel_id, user_id", [ ["", "webchat", "Alice"], ["graph", "teams", "Bob"], - ["slack", "webchat", "Chuck"] - ]) + ["slack", "webchat", "Chuck"], + ], + ) async def test_open_flow_readonly( self, mocker, @@ -222,7 +239,7 @@ async def test_open_flow_readonly( auth_handlers, auth_handler_id, channel_id, - user_id + user_id, ): """Test opening a flow and not modifying it.""" # setup @@ -235,7 +252,9 @@ async def test_open_flow_readonly( expected_flow_state = flow.flow_state # verify - actual_flow_state = await flow_storage_client.read(auth.resolve_handler(auth_handler_id).name) + actual_flow_state = await flow_storage_client.read( + auth.resolve_handler(auth_handler_id).name + ) assert actual_flow_state == expected_flow_state @pytest.mark.asyncio @@ -245,17 +264,14 @@ async def test_open_flow_success_modified_complete_flow( storage, connection_manager, mock_user_token_client_class, - auth_handlers + auth_handlers, ): # setup channel_id = "teams" user_id = "Alice" auth_handler_id = "graph" - self.create_user_token_client( - mocker, - get_token_return=RES_TOKEN - ) + self.create_user_token_client(mocker, get_token_return=RES_TOKEN) context = self.create_context(mocker, channel_id, user_id) context.activity.type = ActivityTypes.message @@ -275,7 +291,9 @@ async def test_open_flow_success_modified_complete_flow( # verify actual_flow_state = await flow_storage_client.read(auth_handler_id) - expected_flow_state.expiration = res_flow_state.expiration # we won't check this for now + expected_flow_state.expiration = ( + res_flow_state.expiration + ) # we won't check this for now assert res_flow_state == expected_flow_state assert actual_flow_state == expected_flow_state @@ -310,7 +328,9 @@ async def test_open_flow_success_modified_failure( # verify actual_flow_state = await flow_storage_client.read(auth_handler_id) - expected_flow_state.expiration = actual_flow_state.expiration # we won't check this for now + expected_flow_state.expiration = ( + actual_flow_state.expiration + ) # we won't check this for now assert flow_response.flow_error_tag == FlowErrorTag.MAGIC_FORMAT assert res_flow_state == expected_flow_state @@ -318,11 +338,7 @@ async def test_open_flow_success_modified_failure( @pytest.mark.asyncio async def test_open_flow_success_modified_signout( - self, - mocker, - storage, - connection_manager, - auth_handlers + self, mocker, storage, connection_manager, auth_handlers ): # setup channel_id = "webchat" @@ -344,83 +360,76 @@ async def test_open_flow_success_modified_signout( # verify actual_flow_state = await flow_storage_client.read(auth_handler_id) - expected_flow_state.expiration = actual_flow_state.expiration # we won't check this for now + expected_flow_state.expiration = ( + actual_flow_state.expiration + ) # we won't check this for now assert actual_flow_state == expected_flow_state @pytest.mark.asyncio - async def test_get_token_success( - self, - mocker, - auth - ): + async def test_get_token_success(self, mocker, auth): mock_user_token_client_class = self.create_user_token_client( + mocker, get_token_return=TokenResponse(token="token") + ) + context = self.create_context( mocker, - get_token_return=TokenResponse(token="token") + "__channel_id", + "__user_id", + user_token_client=mock_user_token_client_class, ) - context = self.create_context(mocker, "__channel_id", "__user_id", user_token_client=mock_user_token_client_class) assert await auth.get_token(context, "slack") == TokenResponse(token="token") mock_user_token_client_class.user_token.get_token.assert_called_once() @pytest.mark.asyncio - async def test_get_token_empty_response( - self, - mocker, - auth - ): + async def test_get_token_empty_response(self, mocker, auth): mock_user_token_client_class = self.create_user_token_client( + mocker, get_token_return=TokenResponse() + ) + context = self.create_context( mocker, - get_token_return=TokenResponse() + "__channel_id", + "__user_id", + user_token_client=mock_user_token_client_class, ) - context = self.create_context(mocker, "__channel_id", "__user_id", user_token_client=mock_user_token_client_class) assert await auth.get_token(context, "graph") == TokenResponse() mock_user_token_client_class.user_token.get_token.assert_called_once() @pytest.mark.asyncio async def test_get_token_error( - self, - turn_context, - storage, - connection_manager, - auth_handlers + self, turn_context, storage, connection_manager, auth_handlers ): auth = Authorization(storage, connection_manager, auth_handlers) with pytest.raises(ValueError): await auth.get_token(turn_context, "missing-handler") @pytest.mark.asyncio - async def test_exchange_token_no_token( - self, - mocker, - turn_context, - auth - ): + async def test_exchange_token_no_token(self, mocker, turn_context, auth): self.create_mock_oauth_flow_class(mocker, TokenResponse()) res = await auth.exchange_token(turn_context, ["scope"], "github") assert res == TokenResponse() @pytest.mark.asyncio - async def test_exchange_token_not_exchangeable( - self, - mocker, - turn_context, - auth - ): + async def test_exchange_token_not_exchangeable(self, mocker, turn_context, auth): token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") - self.create_mock_oauth_flow_class(mocker, TokenResponse(connection_name="github", token=token)) + self.create_mock_oauth_flow_class( + mocker, TokenResponse(connection_name="github", token=token) + ) res = await auth.exchange_token(turn_context, ["scope"], "github") assert res == TokenResponse() @pytest.mark.asyncio - async def test_exchange_token_valid_exchangeable( - self, - turn_context, - mocker, - auth - ): + async def test_exchange_token_valid_exchangeable(self, turn_context, mocker, auth): token = jwt.encode({"aud": "api://botframework.test.api"}, "") - self.create_mock_oauth_flow_class(mocker, TokenResponse(connection_name="github", token=token)) - mock_user_token_client_class = self.create_mock_user_token_client(mocker, token=token) - mock_user_token_client_class.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse(scopes=["scope"], token=token, connection_name="github")) + self.create_mock_oauth_flow_class( + mocker, TokenResponse(connection_name="github", token=token) + ) + mock_user_token_client_class = self.create_mock_user_token_client( + mocker, token=token + ) + mock_user_token_client_class.user_token.exchange_token = mocker.AsyncMock( + return_value=TokenResponse( + scopes=["scope"], token=token, connection_name="github" + ) + ) res = await auth.exchange_token(turn_context, ["scope"], "github") assert res == TokenResponse(token="github-obo-connection-obo-token") @@ -428,7 +437,10 @@ async def test_exchange_token_valid_exchangeable( async def test_get_active_flow_state(self, mocker, auth): context = self.create_context(mocker, "webchat", "Alice") actual_flow_state = await auth.get_active_flow_state(context) - assert actual_flow_state == STORAGE_SAMPLE_DICT[flow_key("webchat", "Alice", "github")] + assert ( + actual_flow_state + == STORAGE_SAMPLE_DICT[flow_key("webchat", "Alice", "github")] + ) @pytest.mark.asyncio async def test_get_active_flow_state_missing(self, mocker, auth): @@ -437,22 +449,26 @@ async def test_get_active_flow_state_missing(self, mocker, auth): assert res is None @pytest.mark.asyncio - async def test_begin_or_continue_flow_success( - self, - mocker, - auth - ): + async def test_begin_or_continue_flow_success(self, mocker, auth): # robrandao: TODO -> lower priority -> more testing here # setup - mocker.patch.object(OAuthFlow, "begin_or_continue_flow", return_value=FlowResponse( - token_response=TokenResponse(token="token"), - flow_state=FlowState(tag=FlowStateTag.COMPLETE, auth_handler_id="github") - )) + mocker.patch.object( + OAuthFlow, + "begin_or_continue_flow", + return_value=FlowResponse( + token_response=TokenResponse(token="token"), + flow_state=FlowState( + tag=FlowStateTag.COMPLETE, auth_handler_id="github" + ), + ), + ) context = self.create_context(mocker, "webchat", "Alice") context.dummy_val = None + def on_sign_in_success(context, turn_state, auth_handler_id): context.dummy_val = auth_handler_id + def on_sign_in_failure(context, turn_state, auth_handler_id, err): context.dummy_val = str(err) @@ -464,18 +480,16 @@ def on_sign_in_failure(context, turn_state, auth_handler_id, err): assert flow_response.token_response == TokenResponse(token="token") @pytest.mark.asyncio - async def test_begin_or_continue_flow_already_completed( - self, - mocker, - auth - ): + async def test_begin_or_continue_flow_already_completed(self, mocker, auth): # robrandao: TODO -> lower priority -> more testing here # setup context = self.create_context(mocker, "webchat", "Alice") context.dummy_val = None + def on_sign_in_success(context, turn_state, auth_handler_id): context.dummy_val = auth_handler_id + def on_sign_in_failure(context, turn_state, auth_handler_id, err): context.dummy_val = str(err) @@ -489,23 +503,28 @@ def on_sign_in_failure(context, turn_state, auth_handler_id, err): @pytest.mark.asyncio async def test_begin_or_continue_flow_failure( - self, - mocker, - mock_oauth_flow_class, - auth - ): + self, mocker, mock_oauth_flow_class, auth + ): # robrandao: TODO -> lower priority -> more testing here # setup - mocker.patch.object(OAuthFlow, "begin_or_continue_flow", return_value=FlowResponse( - token_response=TokenResponse(token="token"), - flow_state=FlowState(tag=FlowStateTag.FAILURE, auth_handler_id="github"), - flow_state_error=FlowErrorTag.MAGIC_FORMAT - )) + mocker.patch.object( + OAuthFlow, + "begin_or_continue_flow", + return_value=FlowResponse( + token_response=TokenResponse(token="token"), + flow_state=FlowState( + tag=FlowStateTag.FAILURE, auth_handler_id="github" + ), + flow_state_error=FlowErrorTag.MAGIC_FORMAT, + ), + ) context = self.create_context(mocker, "webchat", "Alice") context.dummy_val = None + def on_sign_in_success(context, turn_state, auth_handler_id): context.dummy_val = auth_handler_id + def on_sign_in_failure(context, turn_state, auth_handler_id, err): context.dummy_val = str(err) @@ -536,7 +555,7 @@ async def test_sign_out_individual( storage, baseline_storage, connection_manager, - auth_handlers + auth_handlers, ): # setup storage_client = FlowStorageClient("teams", "Alice", storage) @@ -547,7 +566,10 @@ async def test_sign_out_individual( await auth.sign_out(context, "graph") # verify - assert await storage.read([storage_client.key("graph")], target_cls=FlowState) == {} + assert ( + await storage.read([storage_client.key("graph")], target_cls=FlowState) + == {} + ) OAuthFlow.sign_out.assert_called_once() @pytest.mark.asyncio @@ -560,7 +582,7 @@ async def test_sign_out_all( storage, baseline_storage, connection_manager, - auth_handlers + auth_handlers, ): # setup storage_client = FlowStorageClient("webchat", "Alice", storage) @@ -569,8 +591,17 @@ async def test_sign_out_all( context = self.create_context(mocker, "webchat", "Alice") await auth.sign_out(context) - # verify - assert await storage.read([storage_client.key("graph")], target_cls=FlowState) == {} - assert await storage.read([storage_client.key("github")], target_cls=FlowState) == {} - assert await storage.read([storage_client.key("slack")], target_cls=FlowState) == {} - OAuthFlow.sign_out.assert_called() # ignore red squiggly -> mocked + # verify + assert ( + await storage.read([storage_client.key("graph")], target_cls=FlowState) + == {} + ) + assert ( + await storage.read([storage_client.key("github")], target_cls=FlowState) + == {} + ) + assert ( + await storage.read([storage_client.key("slack")], target_cls=FlowState) + == {} + ) + OAuthFlow.sign_out.assert_called() # ignore red squiggly -> mocked diff --git a/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py b/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py index d4f3c737..84c669e6 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py @@ -4,22 +4,53 @@ from microsoft.agents.hosting.core.oauth.flow_state import FlowState, FlowStateTag + class TestFlowState: @pytest.mark.parametrize( "original_flow_state, refresh_to_not_started", [ - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expiration=datetime.now().timestamp()), - True), - (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expiration=datetime.now().timestamp()), - True), - (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expiration=datetime.now().timestamp()-100), - True), - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expiration=datetime.now().timestamp()+1000), - False), - (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=-1, expiration=datetime.now().timestamp()), - False), - ] + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=0, + expiration=datetime.now().timestamp(), + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.BEGIN, + attempts_remaining=1, + expiration=datetime.now().timestamp(), + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.COMPLETE, + attempts_remaining=0, + expiration=datetime.now().timestamp() - 100, + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + expiration=datetime.now().timestamp() + 1000, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.FAILURE, + attempts_remaining=-1, + expiration=datetime.now().timestamp(), + ), + False, + ), + ], ) def test_refresh(self, original_flow_state, refresh_to_not_started): new_flow_state = original_flow_state.model_copy() @@ -32,17 +63,47 @@ def test_refresh(self, original_flow_state, refresh_to_not_started): @pytest.mark.parametrize( "flow_state, expected", [ - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expiration=datetime.now().timestamp()), - True), - (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expiration=datetime.now().timestamp()), - True), - (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expiration=datetime.now().timestamp()-100), - True), - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expiration=datetime.now().timestamp()+1000), - False), - (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=-1, expiration=datetime.now().timestamp()+1000), - False), - ] + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=0, + expiration=datetime.now().timestamp(), + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.BEGIN, + attempts_remaining=1, + expiration=datetime.now().timestamp(), + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.COMPLETE, + attempts_remaining=0, + expiration=datetime.now().timestamp() - 100, + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + expiration=datetime.now().timestamp() + 1000, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.FAILURE, + attempts_remaining=-1, + expiration=datetime.now().timestamp() + 1000, + ), + False, + ), + ], ) def test_is_expired(self, flow_state, expected): assert flow_state.is_expired() == expected @@ -50,17 +111,47 @@ def test_is_expired(self, flow_state, expected): @pytest.mark.parametrize( "flow_state, expected", [ - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expiration=datetime.now().timestamp()), - True), - (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expiration=datetime.now().timestamp()), - False), - (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expiration=datetime.now().timestamp()-100), - True), - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expiration=datetime.now().timestamp()-100), - False), - (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=-1, expiration=datetime.now().timestamp()), - True), - ] + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=0, + expiration=datetime.now().timestamp(), + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.BEGIN, + attempts_remaining=1, + expiration=datetime.now().timestamp(), + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.COMPLETE, + attempts_remaining=0, + expiration=datetime.now().timestamp() - 100, + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + expiration=datetime.now().timestamp() - 100, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.FAILURE, + attempts_remaining=-1, + expiration=datetime.now().timestamp(), + ), + True, + ), + ], ) def test_reached_max_attempts(self, flow_state, expected): assert flow_state.reached_max_attempts() == expected @@ -68,25 +159,79 @@ def test_reached_max_attempts(self, flow_state, expected): @pytest.mark.parametrize( "flow_state, expected", [ - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=0, expiration=datetime.now().timestamp()), - False), - (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=1, expiration=datetime.now().timestamp()), - False), - (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=0, expiration=datetime.now().timestamp()-100), - False), - (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=1, expiration=datetime.now().timestamp()-100), - False), - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=2, expiration=datetime.now().timestamp()+1000), - True), - (FlowState(tag=FlowStateTag.BEGIN, attempts_remaining=0, expiration=datetime.now().timestamp()+1000), - False), - (FlowState(tag=FlowStateTag.COMPLETE, attempts_remaining=-1, expiration=datetime.now().timestamp()+1000), - False), - (FlowState(tag=FlowStateTag.FAILURE, attempts_remaining=1, expiration=datetime.now().timestamp()+1000), - False), - (FlowState(tag=FlowStateTag.CONTINUE, attempts_remaining=1, expiration=datetime.now().timestamp()+1000), - True) - ] + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=0, + expiration=datetime.now().timestamp(), + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.BEGIN, + attempts_remaining=1, + expiration=datetime.now().timestamp(), + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.COMPLETE, + attempts_remaining=0, + expiration=datetime.now().timestamp() - 100, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.FAILURE, + attempts_remaining=1, + expiration=datetime.now().timestamp() - 100, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=2, + expiration=datetime.now().timestamp() + 1000, + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.BEGIN, + attempts_remaining=0, + expiration=datetime.now().timestamp() + 1000, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.COMPLETE, + attempts_remaining=-1, + expiration=datetime.now().timestamp() + 1000, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.FAILURE, + attempts_remaining=1, + expiration=datetime.now().timestamp() + 1000, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + expiration=datetime.now().timestamp() + 1000, + ), + True, + ), + ], ) def test_is_active(self, flow_state, expected): - assert flow_state.is_active() == expected \ No newline at end of file + assert flow_state.is_active() == expected diff --git a/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py b/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py index 925d88eb..caa65b29 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py @@ -4,20 +4,21 @@ from microsoft.agents.hosting.core.storage.storage_test_utils import MockStoreItem from microsoft.agents.hosting.core.oauth import FlowState, FlowStorageClient + class TestFlowStorageClient: @pytest.fixture def channel_id(self): return "__channel_id" - + @pytest.fixture def user_id(self): return "__user_id" - + @pytest.fixture def storage(self): return MemoryStorage() - + @pytest.fixture def client(self, channel_id, user_id, storage): return FlowStorageClient(channel_id, user_id, storage) @@ -50,15 +51,13 @@ async def test_init_fails_without_channel_id(self, user_id, storage): [ ("handler", "auth/__channel_id/__user_id/handler"), ("auth_handler", "auth/__channel_id/__user_id/auth_handler"), - ] + ], ) def test_key(self, client, auth_handler_id, expected): assert client.key(auth_handler_id) == expected @pytest.mark.asyncio - @pytest.mark.parametrize( - "auth_handler_id", ["handler", "auth_handler"] - ) + @pytest.mark.parametrize("auth_handler_id", ["handler", "auth_handler"]) async def test_read(self, mocker, user_id, channel_id, auth_handler_id): storage = mocker.AsyncMock() key = f"auth/{channel_id}/{user_id}/{auth_handler_id}" @@ -66,7 +65,9 @@ async def test_read(self, mocker, user_id, channel_id, auth_handler_id): client = FlowStorageClient(channel_id, user_id, storage) res = await client.read(auth_handler_id) assert res is storage.read.return_value[key] - storage.read.assert_called_once_with([client.key(auth_handler_id)], target_cls=FlowState) + storage.read.assert_called_once_with( + [client.key(auth_handler_id)], target_cls=FlowState + ) @pytest.mark.asyncio async def test_read_missing(self, mocker): @@ -75,12 +76,12 @@ async def test_read_missing(self, mocker): client = FlowStorageClient("__channel_id", "__user_id", storage) res = await client.read("non_existent_handler") assert res is None - storage.read.assert_called_once_with([client.key("non_existent_handler")], target_cls=FlowState) + storage.read.assert_called_once_with( + [client.key("non_existent_handler")], target_cls=FlowState + ) @pytest.mark.asyncio - @pytest.mark.parametrize( - "auth_handler_id", ["handler", "auth_handler"] - ) + @pytest.mark.parametrize("auth_handler_id", ["handler", "auth_handler"]) async def test_write(self, mocker, channel_id, user_id, auth_handler_id): storage = mocker.AsyncMock() storage.write.return_value = None @@ -88,12 +89,10 @@ async def test_write(self, mocker, channel_id, user_id, auth_handler_id): flow_state = mocker.Mock(spec=FlowState) flow_state.auth_handler_id = auth_handler_id await client.write(flow_state) - storage.write.assert_called_once_with({ client.key(auth_handler_id): flow_state }) + storage.write.assert_called_once_with({client.key(auth_handler_id): flow_state}) @pytest.mark.asyncio - @pytest.mark.parametrize( - "auth_handler_id", ["handler", "auth_handler"] - ) + @pytest.mark.parametrize("auth_handler_id", ["handler", "auth_handler"]) async def test_delete(self, mocker, channel_id, user_id, auth_handler_id): storage = mocker.AsyncMock() storage.delete.return_value = None @@ -105,18 +104,24 @@ async def test_delete(self, mocker, channel_id, user_id, auth_handler_id): async def test_integration_with_memory_storage(self, channel_id, user_id): flow_state_alpha = FlowState(auth_handler_id="handler", flow_started=True) - flow_state_beta = FlowState(auth_handler_id="auth_handler", flow_started=True, user_token="token") - - storage = MemoryStorage({ - "some_data": MockStoreItem({"value": "test"}), - f"auth/{channel_id}/{user_id}/handler": flow_state_alpha, - f"auth/{channel_id}/{user_id}/auth_handler": flow_state_beta, - }) - baseline = MemoryStorage({ - "some_data": MockStoreItem({"value": "test"}), - f"auth/{channel_id}/{user_id}/handler": flow_state_alpha, - f"fauth/{channel_id}/{user_id}/auth_handler": flow_state_beta, - }) + flow_state_beta = FlowState( + auth_handler_id="auth_handler", flow_started=True, user_token="token" + ) + + storage = MemoryStorage( + { + "some_data": MockStoreItem({"value": "test"}), + f"auth/{channel_id}/{user_id}/handler": flow_state_alpha, + f"auth/{channel_id}/{user_id}/auth_handler": flow_state_beta, + } + ) + baseline = MemoryStorage( + { + "some_data": MockStoreItem({"value": "test"}), + f"auth/{channel_id}/{user_id}/handler": flow_state_alpha, + f"fauth/{channel_id}/{user_id}/auth_handler": flow_state_beta, + } + ) # helpers async def read_check(*args, **kwargs): @@ -136,20 +141,30 @@ async def delete_both(*args, **kwargs): new_flow_state_alpha = FlowState(auth_handler_id="handler") flow_state_chi = FlowState(auth_handler_id="chi") - + await client.write(new_flow_state_alpha) await client.write(flow_state_chi) - await baseline.write({f"auth/{channel_id}/{user_id}/handler": new_flow_state_alpha.model_copy()}) - await baseline.write({f"auth/{channel_id}/{user_id}/chi": flow_state_chi.model_copy()}) - - await write_both({f"auth/{channel_id}/{user_id}/handler": new_flow_state_alpha.model_copy()}) - await write_both({f"auth/{channel_id}/{user_id}/auth_handler": flow_state_beta.model_copy()}) + await baseline.write( + {f"auth/{channel_id}/{user_id}/handler": new_flow_state_alpha.model_copy()} + ) + await baseline.write( + {f"auth/{channel_id}/{user_id}/chi": flow_state_chi.model_copy()} + ) + + await write_both( + {f"auth/{channel_id}/{user_id}/handler": new_flow_state_alpha.model_copy()} + ) + await write_both( + {f"auth/{channel_id}/{user_id}/auth_handler": flow_state_beta.model_copy()} + ) await write_both({"other_data": MockStoreItem({"value": "more"})}) await delete_both(["some_data"]) await read_check([f"auth/{channel_id}/{user_id}/handler"], target_cls=FlowState) - await read_check([f"auth/{channel_id}/{user_id}/auth_handler"], target_cls=FlowState) + await read_check( + [f"auth/{channel_id}/{user_id}/auth_handler"], target_cls=FlowState + ) await read_check([f"auth/{channel_id}/{user_id}/chi"], target_cls=FlowState) await read_check(["other_data"], target_cls=MockStoreItem) await read_check(["some_data"], target_cls=MockStoreItem) diff --git a/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py index f4cd727d..c99d8886 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py @@ -13,14 +13,17 @@ OAuthFlow, FlowErrorTag, FlowStateTag, - FlowResponse + FlowResponse, ) from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase -from microsoft.agents.hosting.core.connector.user_token_client_base import UserTokenClientBase +from microsoft.agents.hosting.core.connector.user_token_client_base import ( + UserTokenClientBase, +) # test constants from .tools.testing_oauth import * + class TestOAuthFlowUtils: def create_user_token_client(self, mocker, get_token_return=None): @@ -29,23 +32,34 @@ def create_user_token_client(self, mocker, get_token_return=None): user_token_client.user_token = mocker.Mock(spec=UserTokenBase) user_token_client.user_token.get_token = mocker.AsyncMock() user_token_client.user_token.sign_out = mocker.AsyncMock() - + return_value = TokenResponse() if get_token_return: return_value = TokenResponse(token=get_token_return) user_token_client.user_token.get_token.return_value = return_value return user_token_client - + @pytest.fixture def user_token_client(self, mocker): return self.create_user_token_client(mocker, get_token_return=RES_TOKEN) - def create_activity(self, mocker, activity_type=ActivityTypes.message, name="a", value=None, text="a"): + def create_activity( + self, + mocker, + activity_type=ActivityTypes.message, + name="a", + value=None, + text="a", + ): # def conv_ref(): # return mocker.MagicMock(spec=ConversationReference) mock_conversation_ref = mocker.MagicMock(ConversationReference) - mocker.patch.object(Activity, "get_conversation_reference", return_value=mocker.MagicMock(ConversationReference)) + mocker.patch.object( + Activity, + "get_conversation_reference", + return_value=mocker.MagicMock(ConversationReference), + ) # mocker.patch.object(ConversationReference, "create", return_value=conv_ref()) return Activity( type=activity_type, @@ -55,7 +69,7 @@ def create_activity(self, mocker, activity_type=ActivityTypes.message, name="a", # get_conversation_reference=mocker.Mock(return_value=conv_ref), relates_to=mocker.MagicMock(ConversationReference), value=value, - text=text + text=text, ) @pytest.fixture(params=FLOW_STATES.ALL()) @@ -69,8 +83,14 @@ def sample_failed_flow_state(self, request): @pytest.fixture(params=FLOW_STATES.INACTIVE()) def sample_inactive_flow_state(self, request): return request.param.model_copy() - - @pytest.fixture(params=[ flow_state for flow_state in FLOW_STATES.INACTIVE() if flow_state.tag != FlowStateTag.COMPLETE]) + + @pytest.fixture( + params=[ + flow_state + for flow_state in FLOW_STATES.INACTIVE() + if flow_state.tag != FlowStateTag.COMPLETE + ] + ) def sample_inactive_flow_state_not_completed(self, request): return request.param.model_copy() @@ -81,7 +101,7 @@ def sample_active_flow_state(self, request): @pytest.fixture def flow(self, sample_flow_state, user_token_client): return OAuthFlow(sample_flow_state, user_token_client) - + class TestOAuthFlow(TestOAuthFlowUtils): @@ -89,12 +109,9 @@ def test_init_no_user_token_client(self, sample_flow_state): with pytest.raises(ValueError): OAuthFlow(sample_flow_state, None) - @pytest.mark.parametrize("missing_value", [ - "connection", - "ms_app_id", - "channel_id", - "user_id" - ]) + @pytest.mark.parametrize( + "missing_value", ["connection", "ms_app_id", "channel_id", "user_id"] + ) def test_init_errors(self, missing_value, user_token_client): flow_state = FLOW_STATES.STARTED_FLOW.model_copy() flow_state.__setattr__(missing_value, None) @@ -110,7 +127,7 @@ def test_init_with_state(self, sample_flow_state, user_token_client): def test_flow_state_prop_copy(self, flow): flow_state = flow.flow_state - flow_state.user_id = (flow_state.user_id + "_copy") + flow_state.user_id = flow_state.user_id + "_copy" assert flow.flow_state.user_id == USER_ID assert flow_state.user_id == f"{USER_ID}_copy" @@ -124,7 +141,7 @@ async def test_get_user_token_success(self, sample_flow_state, user_token_client # test token_response = await flow.get_user_token() token = token_response.token - + # verify assert token == RES_TOKEN expected_final_flow_state.expiration = flow.flow_state.expiration @@ -133,9 +150,9 @@ async def test_get_user_token_success(self, sample_flow_state, user_token_client user_id=USER_ID, connection_name=ABS_OAUTH_CONNECTION_NAME, channel_id=CHANNEL_ID, - code=None + code=None, ) - + @pytest.mark.asyncio async def test_get_user_token_failure(self, mocker, sample_flow_state): # setup @@ -145,7 +162,7 @@ async def test_get_user_token_failure(self, mocker, sample_flow_state): # test token_response = await flow.get_user_token() - + # verify assert token_response == TokenResponse() assert flow.flow_state == expected_final_flow_state @@ -153,7 +170,7 @@ async def test_get_user_token_failure(self, mocker, sample_flow_state): user_id=USER_ID, connection_name=ABS_OAUTH_CONNECTION_NAME, channel_id=CHANNEL_ID, - code=None + code=None, ) @pytest.mark.asyncio @@ -171,12 +188,14 @@ async def test_sign_out(self, sample_flow_state, user_token_client): user_token_client.user_token.sign_out.assert_called_once_with( user_id=USER_ID, connection_name=ABS_OAUTH_CONNECTION_NAME, - channel_id=CHANNEL_ID + channel_id=CHANNEL_ID, ) assert flow.flow_state == expected_flow_state @pytest.mark.asyncio - async def test_begin_flow_easy_case(self, mocker, sample_flow_state, user_token_client): + async def test_begin_flow_easy_case( + self, mocker, sample_flow_state, user_token_client + ): # setup flow = OAuthFlow(sample_flow_state, user_token_client) activity = mocker.Mock(spec=Activity) @@ -199,26 +218,34 @@ async def test_begin_flow_easy_case(self, mocker, sample_flow_state, user_token_ user_id=USER_ID, connection_name=ABS_OAUTH_CONNECTION_NAME, channel_id=CHANNEL_ID, - code=None + code=None, ) @pytest.mark.asyncio - async def test_begin_flow_long_case(self, mocker, sample_flow_state, user_token_client): + async def test_begin_flow_long_case( + self, mocker, sample_flow_state, user_token_client + ): # mock # tes = mocker.Mock(TokenExchangeState) # tes.get_encoded_state = mocker.Mock(return_value="encoded_state") - mocker.patch.object(TokenExchangeState, "get_encoded_state", return_value="encoded_state") + mocker.patch.object( + TokenExchangeState, "get_encoded_state", return_value="encoded_state" + ) dummy_sign_in_resource = SignInResource( sign_in_link="https://example.com/signin", token_exchange_state=mocker.Mock( - TokenExchangeState, get_encoded_state=mocker.Mock(return_value="encoded_state") - ) + TokenExchangeState, + get_encoded_state=mocker.Mock(return_value="encoded_state"), + ), + ) + user_token_client.user_token.get_token = mocker.AsyncMock( + return_value=TokenResponse() ) - user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) user_token_client.agent_sign_in.get_sign_in_resource = mocker.AsyncMock( - return_value=dummy_sign_in_resource) + return_value=dummy_sign_in_resource + ) activity = self.create_activity(mocker) - + # setup flow = OAuthFlow(sample_flow_state, user_token_client) expected_flow_state = sample_flow_state @@ -232,7 +259,9 @@ async def test_begin_flow_long_case(self, mocker, sample_flow_state, user_token_ # verify flow_state flow_state = flow.flow_state - expected_flow_state.expiration = flow_state.expiration # robrandao: TODO -> ignore this for now + expected_flow_state.expiration = ( + flow_state.expiration + ) # robrandao: TODO -> ignore this for now assert flow_state == response.flow_state assert flow_state == expected_flow_state @@ -243,7 +272,9 @@ async def test_begin_flow_long_case(self, mocker, sample_flow_state, user_token_ # robrandao: TODO more assertions on sign_in_resource @pytest.mark.asyncio - async def test_continue_flow_not_active(self, mocker, sample_inactive_flow_state, user_token_client): + async def test_continue_flow_not_active( + self, mocker, sample_inactive_flow_state, user_token_client + ): # setup activity = mocker.Mock() flow = OAuthFlow(sample_inactive_flow_state, user_token_client) @@ -259,12 +290,20 @@ async def test_continue_flow_not_active(self, mocker, sample_inactive_flow_state assert flow_response.flow_state == flow_state assert not flow_response.token_response - async def helper_continue_flow_failure(self, active_flow_state, user_token_client, activity, flow_error_tag): + async def helper_continue_flow_failure( + self, active_flow_state, user_token_client, activity, flow_error_tag + ): # setup flow = OAuthFlow(active_flow_state, user_token_client) expected_flow_state = active_flow_state - expected_flow_state.tag = FlowStateTag.CONTINUE if active_flow_state.attempts_remaining > 1 else FlowStateTag.FAILURE - expected_flow_state.attempts_remaining = active_flow_state.attempts_remaining - 1 + expected_flow_state.tag = ( + FlowStateTag.CONTINUE + if active_flow_state.attempts_remaining > 1 + else FlowStateTag.FAILURE + ) + expected_flow_state.attempts_remaining = ( + active_flow_state.attempts_remaining - 1 + ) # test flow_response = await flow.continue_flow(activity) @@ -276,7 +315,9 @@ async def helper_continue_flow_failure(self, active_flow_state, user_token_clien assert flow_response.token_response == TokenResponse() assert flow_response.flow_error_tag == flow_error_tag - async def helper_continue_flow_success(self, active_flow_state, user_token_client, activity): + async def helper_continue_flow_success( + self, active_flow_state, user_token_client, activity + ): # setup flow = OAuthFlow(active_flow_state, user_token_client) expected_flow_state = active_flow_state @@ -287,7 +328,9 @@ async def helper_continue_flow_success(self, active_flow_state, user_token_clien # test flow_response = await flow.continue_flow(activity) flow_state = flow.flow_state - expected_flow_state.expiration = flow_state.expiration # robrandao: TODO -> ignore this for now + expected_flow_state.expiration = ( + flow_state.expiration + ) # robrandao: TODO -> ignore this for now # verify assert flow_response.flow_state == flow_state @@ -297,121 +340,184 @@ async def helper_continue_flow_success(self, active_flow_state, user_token_clien @pytest.mark.asyncio @pytest.mark.parametrize("magic_code", ["magic", "123", "", "1239453"]) - async def test_continue_flow_active_message_magic_format_error(self, mocker, sample_active_flow_state, user_token_client, magic_code): + async def test_continue_flow_active_message_magic_format_error( + self, mocker, sample_active_flow_state, user_token_client, magic_code + ): # setup activity = self.create_activity(mocker, ActivityTypes.message, text=magic_code) - await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.MAGIC_FORMAT) + await self.helper_continue_flow_failure( + sample_active_flow_state, + user_token_client, + activity, + FlowErrorTag.MAGIC_FORMAT, + ) user_token_client.assert_not_called() @pytest.mark.asyncio - async def test_continue_flow_active_message_magic_code_error(self, mocker, sample_active_flow_state, user_token_client): + async def test_continue_flow_active_message_magic_code_error( + self, mocker, sample_active_flow_state, user_token_client + ): # setup - user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) + user_token_client.user_token.get_token = mocker.AsyncMock( + return_value=TokenResponse() + ) activity = self.create_activity(mocker, ActivityTypes.message, text="123456") - await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.MAGIC_CODE_INCORRECT) + await self.helper_continue_flow_failure( + sample_active_flow_state, + user_token_client, + activity, + FlowErrorTag.MAGIC_CODE_INCORRECT, + ) user_token_client.user_token.get_token.assert_called_once_with( user_id=sample_active_flow_state.user_id, connection_name=sample_active_flow_state.connection, channel_id=sample_active_flow_state.channel_id, - code="123456" + code="123456", ) @pytest.mark.asyncio - async def test_continue_flow_active_message_success(self, mocker, sample_active_flow_state, user_token_client): + async def test_continue_flow_active_message_success( + self, mocker, sample_active_flow_state, user_token_client + ): # setup activity = self.create_activity(mocker, ActivityTypes.message, text="123456") - await self.helper_continue_flow_success(sample_active_flow_state, user_token_client, activity) + await self.helper_continue_flow_success( + sample_active_flow_state, user_token_client, activity + ) user_token_client.user_token.get_token.assert_called_once_with( user_id=sample_active_flow_state.user_id, connection_name=sample_active_flow_state.connection, channel_id=sample_active_flow_state.channel_id, - code="123456" + code="123456", ) @pytest.mark.asyncio - async def test_continue_flow_active_sign_in_verify_state_error(self, mocker, sample_active_flow_state, user_token_client): + async def test_continue_flow_active_sign_in_verify_state_error( + self, mocker, sample_active_flow_state, user_token_client + ): # setup - user_token_client.user_token.get_token = mocker.AsyncMock(return_value=TokenResponse()) - activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/verifyState", value={ - "state": "magic_code" - }) - await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.OTHER) + user_token_client.user_token.get_token = mocker.AsyncMock( + return_value=TokenResponse() + ) + activity = self.create_activity( + mocker, + ActivityTypes.invoke, + name="signin/verifyState", + value={"state": "magic_code"}, + ) + await self.helper_continue_flow_failure( + sample_active_flow_state, user_token_client, activity, FlowErrorTag.OTHER + ) user_token_client.user_token.get_token.assert_called_once_with( user_id=sample_active_flow_state.user_id, connection_name=sample_active_flow_state.connection, channel_id=sample_active_flow_state.channel_id, - code="magic_code" + code="magic_code", ) @pytest.mark.asyncio - async def test_continue_flow_active_sign_in_verify_success(self, mocker, sample_active_flow_state, user_token_client): - activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/verifyState", value={ - "state": "magic_code" - }) - await self.helper_continue_flow_success(sample_active_flow_state, user_token_client, activity) + async def test_continue_flow_active_sign_in_verify_success( + self, mocker, sample_active_flow_state, user_token_client + ): + activity = self.create_activity( + mocker, + ActivityTypes.invoke, + name="signin/verifyState", + value={"state": "magic_code"}, + ) + await self.helper_continue_flow_success( + sample_active_flow_state, user_token_client, activity + ) user_token_client.user_token.get_token.assert_called_once_with( user_id=sample_active_flow_state.user_id, connection_name=sample_active_flow_state.connection, channel_id=sample_active_flow_state.channel_id, - code="magic_code" + code="magic_code", ) @pytest.mark.asyncio - async def test_continue_flow_active_sign_in_token_exchange_error(self, mocker, sample_active_flow_state, user_token_client): + async def test_continue_flow_active_sign_in_token_exchange_error( + self, mocker, sample_active_flow_state, user_token_client + ): token_exchange_request = {} - user_token_client.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse()) - activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/tokenExchange", value=token_exchange_request) - await self.helper_continue_flow_failure(sample_active_flow_state, user_token_client, activity, FlowErrorTag.OTHER) + user_token_client.user_token.exchange_token = mocker.AsyncMock( + return_value=TokenResponse() + ) + activity = self.create_activity( + mocker, + ActivityTypes.invoke, + name="signin/tokenExchange", + value=token_exchange_request, + ) + await self.helper_continue_flow_failure( + sample_active_flow_state, user_token_client, activity, FlowErrorTag.OTHER + ) user_token_client.user_token.exchange_token.assert_called_once_with( user_id=sample_active_flow_state.user_id, connection_name=sample_active_flow_state.connection, channel_id=sample_active_flow_state.channel_id, - body=token_exchange_request + body=token_exchange_request, ) @pytest.mark.asyncio - async def test_continue_flow_active_sign_in_token_exchange_success(self, mocker, sample_active_flow_state, user_token_client): + async def test_continue_flow_active_sign_in_token_exchange_success( + self, mocker, sample_active_flow_state, user_token_client + ): token_exchange_request = {} - user_token_client.user_token.exchange_token = mocker.AsyncMock(return_value=TokenResponse(token=RES_TOKEN)) - activity = self.create_activity(mocker, ActivityTypes.invoke, name="signin/tokenExchange", value=token_exchange_request) - await self.helper_continue_flow_success(sample_active_flow_state, user_token_client, activity) + user_token_client.user_token.exchange_token = mocker.AsyncMock( + return_value=TokenResponse(token=RES_TOKEN) + ) + activity = self.create_activity( + mocker, + ActivityTypes.invoke, + name="signin/tokenExchange", + value=token_exchange_request, + ) + await self.helper_continue_flow_success( + sample_active_flow_state, user_token_client, activity + ) user_token_client.user_token.exchange_token.assert_called_once_with( user_id=sample_active_flow_state.user_id, connection_name=sample_active_flow_state.connection, channel_id=sample_active_flow_state.channel_id, - body=token_exchange_request + body=token_exchange_request, ) @pytest.mark.asyncio - async def test_continue_flow_invalid_invoke_name(self, mocker, sample_active_flow_state, user_token_client): - with pytest.raises(ValueError): - activity = self.create_activity(mocker, ActivityTypes.invoke, name="other", value={}) + async def test_continue_flow_invalid_invoke_name( + self, mocker, sample_active_flow_state, user_token_client + ): + with pytest.raises(ValueError): + activity = self.create_activity( + mocker, ActivityTypes.invoke, name="other", value={} + ) flow = OAuthFlow(sample_active_flow_state, user_token_client) await flow.continue_flow(activity) @pytest.mark.asyncio - async def test_continue_flow_invalid_activity_type(self, mocker, sample_active_flow_state, user_token_client): - with pytest.raises(ValueError): - activity = self.create_activity(mocker, ActivityTypes.command, name="other", value={}) + async def test_continue_flow_invalid_activity_type( + self, mocker, sample_active_flow_state, user_token_client + ): + with pytest.raises(ValueError): + activity = self.create_activity( + mocker, ActivityTypes.command, name="other", value={} + ) flow = OAuthFlow(sample_active_flow_state, user_token_client) await flow.continue_flow(activity) @pytest.mark.asyncio - async def test_begin_or_continue_flow_not_started_flow( - self, - mocker - ): + async def test_begin_or_continue_flow_not_started_flow(self, mocker): # setup sample_flow_state = FLOW_STATES.NOT_STARTED_FLOW.model_copy() expected_response = FlowResponse( - flow_state = sample_flow_state, - token_response = TokenResponse(token=sample_flow_state.user_token), + flow_state=sample_flow_state, + token_response=TokenResponse(token=sample_flow_state.user_token), ) mocker.patch.object(OAuthFlow, "begin_flow", return_value=expected_response) activity_mock = mocker.Mock() flow = OAuthFlow(sample_flow_state, mocker.Mock()) - + # test actual_response = await flow.begin_or_continue_flow(activity_mock) @@ -427,8 +533,8 @@ async def test_begin_or_continue_flow_inactive_flow( ): # setup expected_response = FlowResponse( - flow_state = sample_inactive_flow_state_not_completed, - token_response = TokenResponse(), + flow_state=sample_inactive_flow_state_not_completed, + token_response=TokenResponse(), ) mocker.patch.object(OAuthFlow, "begin_flow", return_value=expected_response) @@ -450,13 +556,13 @@ async def test_begin_or_continue_flow_active_flow( ): # setup expected_response = FlowResponse( - flow_state = sample_active_flow_state, - token_response = TokenResponse(token=sample_active_flow_state.user_token), + flow_state=sample_active_flow_state, + token_response=TokenResponse(token=sample_active_flow_state.user_token), ) mocker.patch.object(OAuthFlow, "continue_flow", return_value=expected_response) flow = OAuthFlow(sample_active_flow_state, mocker.Mock()) - + # test activity_mock = mocker.Mock() actual_response = await flow.begin_or_continue_flow(activity_mock) @@ -466,10 +572,7 @@ async def test_begin_or_continue_flow_active_flow( OAuthFlow.continue_flow.assert_called_once_with(activity_mock) @pytest.mark.asyncio - async def test_begin_or_continue_flow_stale_flow_state( - self, - mocker - ): + async def test_begin_or_continue_flow_stale_flow_state(self, mocker): flow_state = FLOW_STATES.ACTIVE_EXP_FLOW.model_copy() expected_response = FlowResponse() @@ -482,14 +585,11 @@ async def test_begin_or_continue_flow_stale_flow_state( OAuthFlow.begin_flow.assert_called_once_with(None) @pytest.mark.asyncio - async def test_begin_or_continue_flow_completed_flow_state( - self, - mocker - ): + async def test_begin_or_continue_flow_completed_flow_state(self, mocker): flow_state = FLOW_STATES.COMPLETED_FLOW.model_copy() expected_response = FlowResponse( - flow_state = flow_state, - token_response = TokenResponse(token=flow_state.user_token) + flow_state=flow_state, + token_response=TokenResponse(token=flow_state.user_token), ) mocker.patch.object(OAuthFlow, "begin_flow", return_value=None) mocker.patch.object(OAuthFlow, "continue_flow", return_value=None) @@ -499,4 +599,4 @@ async def test_begin_or_continue_flow_completed_flow_state( assert actual_response == expected_response OAuthFlow.begin_flow.assert_not_called() - OAuthFlow.continue_flow.assert_not_called() \ No newline at end of file + OAuthFlow.continue_flow.assert_not_called() diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/testing_adapter.py b/libraries/microsoft-agents-hosting-core/tests/tools/testing_adapter.py index fb0fa9b0..e79486b5 100644 --- a/libraries/microsoft-agents-hosting-core/tests/tools/testing_adapter.py +++ b/libraries/microsoft-agents-hosting-core/tests/tools/testing_adapter.py @@ -27,6 +27,7 @@ from .mock_user_token_client import MockUserTokenClient + class TestingAdapter(ChannelAdapter): """ A mock adapter that can be used for unit testing of agent logic. diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py b/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py index ac184d46..9340f995 100644 --- a/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py +++ b/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py @@ -207,7 +207,7 @@ def __init__( storage=storage, auth_handlers=auth_handlers, connection_manager=connection_manager, - service_url="a" + service_url="a", ) # Configure each auth handler with mock OAuth flow behavior diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py b/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py index 9a9142d3..1dd8d799 100644 --- a/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py +++ b/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py @@ -13,145 +13,168 @@ "ms_app_id": MS_APP_ID, "channel_id": CHANNEL_ID, "user_id": USER_ID, - "connection": ABS_OAUTH_CONNECTION_NAME + "connection": ABS_OAUTH_CONNECTION_NAME, } + class FLOW_STATES: NOT_STARTED_FLOW = FlowState( - **DEF_ARGS, - tag=FlowStateTag.NOT_STARTED, - attempts_remaining=1, - user_token="____", - expiration=datetime.now().timestamp() + 1000000 - ) + **DEF_ARGS, + tag=FlowStateTag.NOT_STARTED, + attempts_remaining=1, + user_token="____", + expiration=datetime.now().timestamp() + 1000000, + ) STARTED_FLOW = FlowState( - **DEF_ARGS, - tag=FlowStateTag.BEGIN, - attempts_remaining=1, - user_token="____", - expiration=datetime.now().timestamp() + 1000000 - ) + **DEF_ARGS, + tag=FlowStateTag.BEGIN, + attempts_remaining=1, + user_token="____", + expiration=datetime.now().timestamp() + 1000000, + ) STARTED_FLOW_ONE_RETRY = FlowState( - **DEF_ARGS, - tag=FlowStateTag.BEGIN, - attempts_remaining=2, - user_token="____", - expiration=datetime.now().timestamp() + 1000000 - ) + **DEF_ARGS, + tag=FlowStateTag.BEGIN, + attempts_remaining=2, + user_token="____", + expiration=datetime.now().timestamp() + 1000000, + ) ACTIVE_FLOW = FlowState( - **DEF_ARGS, - tag=FlowStateTag.CONTINUE, - attempts_remaining=2, - user_token="__token", - expiration=datetime.now().timestamp() + 1000000 - ) + **DEF_ARGS, + tag=FlowStateTag.CONTINUE, + attempts_remaining=2, + user_token="__token", + expiration=datetime.now().timestamp() + 1000000, + ) ACTIVE_FLOW_ONE_RETRY = FlowState( - **DEF_ARGS, - tag=FlowStateTag.CONTINUE, - attempts_remaining=1, - user_token="__token", - expiration=datetime.now().timestamp() + 1000000 - ) + **DEF_ARGS, + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + user_token="__token", + expiration=datetime.now().timestamp() + 1000000, + ) ACTIVE_EXP_FLOW = FlowState( - **DEF_ARGS, - tag=FlowStateTag.CONTINUE, - attempts_remaining=2, - user_token="__token", - expiration=datetime.now().timestamp() - ) + **DEF_ARGS, + tag=FlowStateTag.CONTINUE, + attempts_remaining=2, + user_token="__token", + expiration=datetime.now().timestamp(), + ) COMPLETED_FLOW = FlowState( - **DEF_ARGS, - tag=FlowStateTag.COMPLETE, - attempts_remaining=2, - user_token="test_token", - expiration=datetime.now().timestamp() + 1000000 - ) + **DEF_ARGS, + tag=FlowStateTag.COMPLETE, + attempts_remaining=2, + user_token="test_token", + expiration=datetime.now().timestamp() + 1000000, + ) FAIL_BY_ATTEMPTS_FLOW = FlowState( - **DEF_ARGS, - tag=FlowStateTag.FAILURE, - attempts_remaining=0, - expiration=datetime.now().timestamp() + 1000000 - ) + **DEF_ARGS, + tag=FlowStateTag.FAILURE, + attempts_remaining=0, + expiration=datetime.now().timestamp() + 1000000, + ) FAIL_BY_EXP_FLOW = FlowState( - **DEF_ARGS, - tag=FlowStateTag.FAILURE, - attempts_remaining=2, - expiration=0 - ) + **DEF_ARGS, tag=FlowStateTag.FAILURE, attempts_remaining=2, expiration=0 + ) @staticmethod def clone_state_list(lst): - return [ flow_state.model_copy() for flow_state in lst ] + return [flow_state.model_copy() for flow_state in lst] @staticmethod def ALL(): - return FLOW_STATES.clone_state_list([ - FLOW_STATES.STARTED_FLOW, - FLOW_STATES.STARTED_FLOW_ONE_RETRY, - FLOW_STATES.ACTIVE_FLOW, - FLOW_STATES.ACTIVE_FLOW_ONE_RETRY, - FLOW_STATES.ACTIVE_EXP_FLOW, - FLOW_STATES.COMPLETED_FLOW, - FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, - FLOW_STATES.FAIL_BY_EXP_FLOW - ]) + return FLOW_STATES.clone_state_list( + [ + FLOW_STATES.STARTED_FLOW, + FLOW_STATES.STARTED_FLOW_ONE_RETRY, + FLOW_STATES.ACTIVE_FLOW, + FLOW_STATES.ACTIVE_FLOW_ONE_RETRY, + FLOW_STATES.ACTIVE_EXP_FLOW, + FLOW_STATES.COMPLETED_FLOW, + FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, + FLOW_STATES.FAIL_BY_EXP_FLOW, + ] + ) @staticmethod def FAILED(): - return FLOW_STATES.clone_state_list([ - FLOW_STATES.ACTIVE_EXP_FLOW, - FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, - FLOW_STATES.FAIL_BY_EXP_FLOW - ]) + return FLOW_STATES.clone_state_list( + [ + FLOW_STATES.ACTIVE_EXP_FLOW, + FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, + FLOW_STATES.FAIL_BY_EXP_FLOW, + ] + ) @staticmethod def ACTIVE(): - return FLOW_STATES.clone_state_list([ - FLOW_STATES.STARTED_FLOW, - FLOW_STATES.STARTED_FLOW_ONE_RETRY, - FLOW_STATES.ACTIVE_FLOW, - FLOW_STATES.ACTIVE_FLOW_ONE_RETRY, - ]) + return FLOW_STATES.clone_state_list( + [ + FLOW_STATES.STARTED_FLOW, + FLOW_STATES.STARTED_FLOW_ONE_RETRY, + FLOW_STATES.ACTIVE_FLOW, + FLOW_STATES.ACTIVE_FLOW_ONE_RETRY, + ] + ) @staticmethod def INACTIVE(): - return FLOW_STATES.clone_state_list([ - FLOW_STATES.ACTIVE_EXP_FLOW, - FLOW_STATES.COMPLETED_FLOW, - FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, - FLOW_STATES.FAIL_BY_EXP_FLOW - ]) + return FLOW_STATES.clone_state_list( + [ + FLOW_STATES.ACTIVE_EXP_FLOW, + FLOW_STATES.COMPLETED_FLOW, + FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, + FLOW_STATES.FAIL_BY_EXP_FLOW, + ] + ) + def flow_key(channel_id, user_id, handler_id): return f"auth/{channel_id}/{user_id}/{handler_id}" + def update_flow_state_handler(flow_state, handler): flow_state = flow_state.model_copy() flow_state.auth_handler_id = handler return flow_state - + + STORAGE_SAMPLE_DICT = { "user_id": MockStoreItem({"id": "123"}), "session_id": MockStoreItem({"id": "abc"}), - flow_key("webchat", "Alice", "graph"): update_flow_state_handler(FLOW_STATES.COMPLETED_FLOW.model_copy(), "graph"), - flow_key("webchat", "Alice", "github"): update_flow_state_handler(FLOW_STATES.ACTIVE_FLOW_ONE_RETRY.model_copy(), "github"), - flow_key("teams", "Alice", "graph"): update_flow_state_handler(FLOW_STATES.STARTED_FLOW.model_copy(), "graph"), - flow_key("webchat", "Bob", "graph"): update_flow_state_handler(FLOW_STATES.ACTIVE_EXP_FLOW.model_copy(), "graph"), - flow_key("teams", "Bob", "slack"): update_flow_state_handler(FLOW_STATES.STARTED_FLOW.model_copy(), "slack"), - flow_key("webchat", "Chuck", "github"): update_flow_state_handler(FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW.model_copy(), "github"), + flow_key("webchat", "Alice", "graph"): update_flow_state_handler( + FLOW_STATES.COMPLETED_FLOW.model_copy(), "graph" + ), + flow_key("webchat", "Alice", "github"): update_flow_state_handler( + FLOW_STATES.ACTIVE_FLOW_ONE_RETRY.model_copy(), "github" + ), + flow_key("teams", "Alice", "graph"): update_flow_state_handler( + FLOW_STATES.STARTED_FLOW.model_copy(), "graph" + ), + flow_key("webchat", "Bob", "graph"): update_flow_state_handler( + FLOW_STATES.ACTIVE_EXP_FLOW.model_copy(), "graph" + ), + flow_key("teams", "Bob", "slack"): update_flow_state_handler( + FLOW_STATES.STARTED_FLOW.model_copy(), "slack" + ), + flow_key("webchat", "Chuck", "github"): update_flow_state_handler( + FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW.model_copy(), "github" + ), } + def STORAGE_INIT_DATA(): data = STORAGE_SAMPLE_DICT.copy() for key, value in data.items(): data[key] = value.model_copy() if isinstance(value, FlowState) else value return data + def update_data_with_flow_state(data, channel_id, user_id, auth_handler_id, flow_state): data = data.copy() key = f"auth/{channel_id}/{user_id}/{auth_handler_id}" data[key] = flow_state.model_copy() - return data \ No newline at end of file + return data From b38212bf305320fcf1233e34bbca5036eedd219a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Brand=C3=A3o?= Date: Fri, 22 Aug 2025 15:09:18 -0700 Subject: [PATCH 32/32] Added pytest-mock as a dependency --- .azdo/ci-pr.yaml | 2 +- .github/workflows/python-package.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.azdo/ci-pr.yaml b/.azdo/ci-pr.yaml index 01189e5c..a1bf1c4a 100644 --- a/.azdo/ci-pr.yaml +++ b/.azdo/ci-pr.yaml @@ -26,7 +26,7 @@ steps: - script: | python -m pip install --upgrade pip - python -m pip install flake8 pytest black pytest-asyncio build setuptools-git-versioning + python -m pip install flake8 pytest pytest-mock black pytest-asyncio build setuptools-git-versioning if [ -f requirements.txt ]; then pip install -r requirements.txt; fi displayName: 'Install dependencies' diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index f50d8d14..ac26ed78 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -30,7 +30,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest black pytest-asyncio build setuptools-git-versioning + python -m pip install flake8 pytest pytest-mock black pytest-asyncio build setuptools-git-versioning if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Check format with black run: |