diff --git a/.azdo/ci-pr.yaml b/.azdo/ci-pr.yaml index 01189e5c..a1bf1c4a 100644 --- a/.azdo/ci-pr.yaml +++ b/.azdo/ci-pr.yaml @@ -26,7 +26,7 @@ steps: - script: | python -m pip install --upgrade pip - python -m pip install flake8 pytest black pytest-asyncio build setuptools-git-versioning + python -m pip install flake8 pytest pytest-mock black pytest-asyncio build setuptools-git-versioning if [ -f requirements.txt ]; then pip install -r requirements.txt; fi displayName: 'Install dependencies' diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index f50d8d14..ac26ed78 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -30,7 +30,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest black pytest-asyncio build setuptools-git-versioning + python -m pip install flake8 pytest pytest-mock black pytest-asyncio build setuptools-git-versioning if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Check format with black run: | diff --git a/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py b/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py index 682c534b..00d6aa91 100644 --- a/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py +++ b/libraries/microsoft-agents-activity/microsoft/agents/activity/token_response.py @@ -23,3 +23,6 @@ class TokenResponse(AgentsModel): token: NonEmptyString = None expiration: NonEmptyString = None channel_id: NonEmptyString = None + + def __bool__(self): + return bool(self.token) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py index 5d95f4e6..f5d07cef 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/__init__.py @@ -1,6 +1,5 @@ from .activity_handler import ActivityHandler from .agent import Agent -from .oauth_flow import OAuthFlow from .card_factory import CardFactory from .channel_adapter import ChannelAdapter from .channel_api_handler_protocol import ChannelApiHandlerProtocol @@ -20,12 +19,11 @@ from .app.route import Route, RouteHandler from .app.typing_indicator import TypingIndicator -# OAuth -from .app.oauth.authorization import ( +# App Auth +from .app.oauth import ( Authorization, AuthorizationHandlers, AuthHandler, - SignInState, ) # App State @@ -44,6 +42,16 @@ from .authorization.jwt_token_validator import JwtTokenValidator from .authorization.auth_types import AuthTypes +# OAuth +from .oauth import ( + FlowState, + FlowStateTag, + FlowErrorTag, + FlowResponse, + FlowStorageClient, + OAuthFlow, +) + # Client API from .client.agent_conversation_reference import AgentConversationReference from .client.channel_factory_protocol import ChannelFactoryProtocol @@ -88,7 +96,6 @@ __all__ = [ "ActivityHandler", "Agent", - "OAuthFlow", "CardFactory", "ChannelAdapter", "ChannelApiHandlerProtocol", @@ -155,4 +162,10 @@ "StoreItem", "Storage", "MemoryStorage", + "FlowState", + "FlowStateTag", + "FlowErrorTag", + "FlowResponse", + "FlowStorageClient", + "OAuthFlow", ] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/__init__.py index 124de7c5..4089c3fb 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/__init__.py @@ -13,12 +13,11 @@ from .route import Route, RouteHandler from .typing_indicator import TypingIndicator -# OAuth -from .oauth.authorization import ( +# Auth +from .oauth import ( Authorization, - AuthorizationHandlers, AuthHandler, - SignInState, + AuthorizationHandlers, ) # App State @@ -49,7 +48,6 @@ "TurnState", "TempState", "Authorization", - "AuthorizationHandlers", "AuthHandler", - "SignInState", + "AuthorizationHandlers", ] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py index 142f311c..6084a36c 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/agent_application.py @@ -1,800 +1,897 @@ -""" -Copyright (c) Microsoft Corporation. All rights reserved. -Licensed under the MIT License. -""" - -from __future__ import annotations -import logging -from copy import copy -from functools import partial - -import re -from typing import ( - Any, - Awaitable, - Callable, - Dict, - Generic, - List, - Optional, - Pattern, - TypeVar, - Union, - cast, -) - -from microsoft.agents.hosting.core.authorization import Connections - -from microsoft.agents.hosting.core import Agent, TurnContext -from microsoft.agents.activity import ( - Activity, - ActivityTypes, - ConversationUpdateTypes, - MessageReactionTypes, - MessageUpdateTypes, - InvokeResponse, -) - -from .app_error import ApplicationError -from .app_options import ApplicationOptions - -# from .auth import AuthManager, OAuth, OAuthOptions -from .route import Route, RouteHandler -from .state import TurnState -from ..channel_service_adapter import ChannelServiceAdapter -from .oauth import Authorization, SignInState -from .typing_indicator import TypingIndicator - -logger = logging.getLogger(__name__) - -StateT = TypeVar("StateT", bound=TurnState) -IN_SIGN_IN_KEY = "__InSignInFlow__" - - -class AgentApplication(Agent, Generic[StateT]): - """ - AgentApplication class for routing and processing incoming requests. - - The AgentApplication object replaces the traditional ActivityHandler that - a bot would use. It supports a simpler fluent style of authoring bots - versus the inheritance based approach used by the ActivityHandler class. - - Additionally, it has built-in support for calling into the SDK's AI system - and can be used to create bots that leverage Large Language Models (LLM) - and other AI capabilities. - """ - - typing: TypingIndicator - - _options: ApplicationOptions - _adapter: Optional[ChannelServiceAdapter] = None - _auth: Optional[Authorization] = None - _internal_before_turn: List[Callable[[TurnContext, StateT], Awaitable[bool]]] = [] - _internal_after_turn: List[Callable[[TurnContext, StateT], Awaitable[bool]]] = [] - _routes: List[Route[StateT]] = [] - _error: Optional[Callable[[TurnContext, Exception], Awaitable[None]]] = None - _turn_state_factory: Optional[Callable[[TurnContext], StateT]] = None - - def __init__( - self, - options: ApplicationOptions = None, - *, - connection_manager: Connections = None, - authorization: Authorization = None, - **kwargs, - ) -> None: - """ - Creates a new AgentApplication instance. - """ - self.typing = TypingIndicator() - self._routes = [] - - configuration = kwargs - - logger.debug(f"Initializing AgentApplication with options: {options}") - logger.debug( - f"Initializing AgentApplication with configuration: {configuration}" - ) - - if not options: - # TODO: consolidate configuration story - # Take the options from the kwargs and create an ApplicationOptions instance - option_kwargs = dict( - filter( - lambda x: x[0] in ApplicationOptions.__dataclass_fields__, - kwargs.items(), - ) - ) - options = ApplicationOptions(**option_kwargs) - - self._options = options - - if not self._options.storage: - logger.error( - "ApplicationOptions.storage is required and was not configured.", - stack_info=True, - ) - raise ApplicationError( - """ - The `ApplicationOptions.storage` property is required and was not configured. - """ - ) - - if options.long_running_messages and ( - not options.adapter or not options.bot_app_id - ): - logger.error( - "ApplicationOptions.long_running_messages requires an adapter and bot_app_id.", - stack_info=True, - ) - raise ApplicationError( - """ - The `ApplicationOptions.long_running_messages` property is unavailable because - no adapter or `bot_app_id` was configured. - """ - ) - - if options.adapter: - self._adapter = options.adapter - - self._turn_state_factory = ( - options.turn_state_factory - or kwargs.get("turn_state_factory", None) - or partial(TurnState.with_storage, self._options.storage) - ) - - # TODO: decide how to initialize the Authorization (params vs options vs kwargs) - if authorization: - self._auth = authorization - else: - auth_options = { - key: value - for key, value in configuration.items() - if key not in ["storage", "connection_manager", "handlers"] - } - self._auth = Authorization( - storage=self._options.storage, - connection_manager=connection_manager, - handlers=options.authorization_handlers, - **auth_options, - ) - - @property - def adapter(self) -> ChannelServiceAdapter: - """ - The bot's adapter. - """ - - if not self._adapter: - logger.error( - "AgentApplication.adapter(): self._adapter is not configured.", - stack_info=True, - ) - raise ApplicationError( - """ - The AgentApplication.adapter property is unavailable because it was - not configured when creating the AgentApplication. - """ - ) - - return self._adapter - - @property - def auth(self): - """ - The application's authentication manager - """ - if not self._auth: - logger.error( - "AgentApplication.auth(): self._auth is not configured.", - stack_info=True, - ) - raise ApplicationError( - """ - The `AgentApplication.auth` property is unavailable because - no Auth options were configured. - """ - ) - - return self._auth - - @property - def options(self) -> ApplicationOptions: - """ - The application's configured options. - """ - return self._options - - def activity( - self, - activity_type: Union[str, ActivityTypes, List[Union[str, ActivityTypes]]], - *, - auth_handlers: Optional[List[str]] = None, - ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: - """ - Registers a new activity event listener. This method can be used as either - a decorator or a method. - - ```python - # Use this method as a decorator - @app.activity("event") - async def on_event(context: TurnContext, state: TurnState): - print("hello world!") - return True - ``` - - #### Args: - - `type`: The type of the activity - """ - - def __selector(context: TurnContext): - return activity_type == context.activity.type - - def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: - logger.debug( - f"Registering activity handler for route handler {func.__name__} with type: {activity_type} with auth handlers: {auth_handlers}" - ) - self._routes.append( - Route[StateT](__selector, func, auth_handlers=auth_handlers) - ) - return func - - return __call - - def message( - self, - select: Union[str, Pattern[str], List[Union[str, Pattern[str]]]], - *, - auth_handlers: Optional[List[str]] = None, - ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: - """ - Registers a new message activity event listener. This method can be used as either - a decorator or a method. - - ```python - # Use this method as a decorator - @app.message("hi") - async def on_hi_message(context: TurnContext, state: TurnState): - print("hello!") - return True - - #### Args: - - `select`: a string or regex pattern - """ - - def __selector(context: TurnContext): - if context.activity.type != ActivityTypes.message: - return False - - text = context.activity.text if context.activity.text else "" - if isinstance(select, Pattern): - hits = re.fullmatch(select, text) - return hits is not None - - return text == select - - def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: - logger.debug( - f"Registering message handler for route handler {func.__name__} with select: {select} with auth handlers: {auth_handlers}" - ) - self._routes.append( - Route[StateT](__selector, func, auth_handlers=auth_handlers) - ) - return func - - return __call - - def conversation_update( - self, - type: ConversationUpdateTypes, - *, - auth_handlers: Optional[List[str]] = None, - ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: - """ - Registers a new message activity event listener. This method can be used as either - a decorator or a method. - - ```python - # Use this method as a decorator - @app.conversation_update("channelCreated") - async def on_channel_created(context: TurnContext, state: TurnState): - print("a new channel was created!") - return True - - ``` - - #### Args: - - `type`: a string or regex pattern - """ - - def __selector(context: TurnContext): - if context.activity.type != ActivityTypes.conversation_update: - return False - - if type == "membersAdded": - if isinstance(context.activity.members_added, List): - return len(context.activity.members_added) > 0 - return False - - if type == "membersRemoved": - if isinstance(context.activity.members_removed, List): - return len(context.activity.members_removed) > 0 - return False - - if isinstance(context.activity.channel_data, object): - data = vars(context.activity.channel_data) - return data["event_type"] == type - - return False - - def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: - logger.debug( - f"Registering conversation update handler for route handler {func.__name__} with type: {type} with auth handlers: {auth_handlers}" - ) - self._routes.append( - Route[StateT](__selector, func, auth_handlers=auth_handlers) - ) - return func - - return __call - - def message_reaction( - self, type: MessageReactionTypes, *, auth_handlers: Optional[List[str]] = None - ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: - """ - Registers a new message activity event listener. This method can be used as either - a decorator or a method. - - ```python - # Use this method as a decorator - @app.message_reaction("reactionsAdded") - async def on_reactions_added(context: TurnContext, state: TurnState): - print("reactions was added!") - return True - ``` - - #### Args: - - `type`: a string or regex pattern - """ - - def __selector(context: TurnContext): - if context.activity.type != ActivityTypes.message_reaction: - return False - - if type == "reactionsAdded": - if isinstance(context.activity.reactions_added, List): - return len(context.activity.reactions_added) > 0 - return False - - if type == "reactionsRemoved": - if isinstance(context.activity.reactions_removed, List): - return len(context.activity.reactions_removed) > 0 - return False - - return False - - def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: - logger.debug( - f"Registering message reaction handler for route handler {func.__name__} with type: {type} with auth handlers: {auth_handlers}" - ) - self._routes.append( - Route[StateT](__selector, func, auth_handlers=auth_handlers) - ) - return func - - return __call - - def message_update( - self, type: MessageUpdateTypes, *, auth_handlers: Optional[List[str]] = None - ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: - """ - Registers a new message activity event listener. This method can be used as either - a decorator or a method. - - ```python - # Use this method as a decorator - @app.message_update("editMessage") - async def on_edit_message(context: TurnContext, state: TurnState): - print("message was edited!") - return True - ``` - - #### Args: - - `type`: a string or regex pattern - """ - - def __selector(context: TurnContext): - if type == "editMessage": - if ( - context.activity.type == ActivityTypes.message_update - and isinstance(context.activity.channel_data, dict) - ): - data = context.activity.channel_data - return data["event_type"] == type - return False - - if type == "softDeleteMessage": - if ( - context.activity.type == ActivityTypes.message_delete - and isinstance(context.activity.channel_data, dict) - ): - data = context.activity.channel_data - return data["event_type"] == type - return False - - if type == "undeleteMessage": - if ( - context.activity.type == ActivityTypes.message_update - and isinstance(context.activity.channel_data, dict) - ): - data = context.activity.channel_data - return data["event_type"] == type - return False - return False - - def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: - logger.debug( - f"Registering message update handler for route handler {func.__name__} with type: {type} with auth handlers: {auth_handlers}" - ) - self._routes.append( - Route[StateT](__selector, func, auth_handlers=auth_handlers) - ) - return func - - return __call - - def handoff(self, *, auth_handlers: Optional[List[str]] = None) -> Callable[ - [Callable[[TurnContext, StateT, str], Awaitable[None]]], - Callable[[TurnContext, StateT, str], Awaitable[None]], - ]: - """ - Registers a handler to handoff conversations from one copilot to another. - ```python - # Use this method as a decorator - @app.handoff - async def on_handoff( - context: TurnContext, state: TurnState, continuation: str - ): - print(query) - ``` - """ - - def __selector(context: TurnContext) -> bool: - return ( - context.activity.type == ActivityTypes.invoke - and context.activity.name == "handoff/action" - ) - - def __call( - func: Callable[[TurnContext, StateT, str], Awaitable[None]], - ) -> Callable[[TurnContext, StateT, str], Awaitable[None]]: - async def __handler(context: TurnContext, state: StateT): - if not context.activity.value: - return False - await func(context, state, context.activity.value["continuation"]) - await context.send_activity( - Activity( - type=ActivityTypes.invoke_response, - value=InvokeResponse(status=200), - ) - ) - return True - - logger.debug( - f"Registering handoff handler for route handler {func.__name__} with auth handlers: {auth_handlers}" - ) - - self._routes.append( - Route[StateT](__selector, __handler, True, auth_handlers) - ) - self._routes = sorted(self._routes, key=lambda route: not route.is_invoke) - return func - - return __call - - def on_sign_in_success( - self, func: Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]] - ) -> Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]]: - """ - Registers a new event listener that will be executed when a user successfully signs in. - - ```python - # Use this method as a decorator - @app.on_sign_in_success - async def sign_in_success(context: TurnContext, state: TurnState): - print("hello world!") - return True - ``` - """ - - if self._auth: - logger.debug( - f"Registering sign-in success handler for route handler {func.__name__}" - ) - self._auth.on_sign_in_success(func) - else: - logger.error( - f"Failed to register sign-in success handler for route handler {func.__name__}", - stack_info=True, - ) - raise ApplicationError( - """ - The `AgentApplication.on_sign_in_success` method is unavailable because - no Auth options were configured. - """ - ) - return func - - def on_sign_in_failure( - self, func: Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]] - ) -> Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]]: - """ - Registers a new event listener that will be executed when a user fails to sign in. - - ```python - # Use this method as a decorator - @app.on_sign_in_failure - async def sign_in_failure(context: TurnContext, state: TurnState): - print("hello world!") - return True - ``` - """ - - if self._auth: - logger.debug( - f"Registering sign-in failure handler for route handler {func.__name__}" - ) - self._auth.on_sign_in_failure(func) - else: - logger.error( - f"Failed to register sign-in failure handler for route handler {func.__name__}", - stack_info=True, - ) - raise ApplicationError( - """ - The `AgentApplication.on_sign_in_failure` method is unavailable because - no Auth options were configured. - """ - ) - return func - - def error( - self, func: Callable[[TurnContext, Exception], Awaitable[None]] - ) -> Callable[[TurnContext, Exception], Awaitable[None]]: - """ - Registers an error handler that will be called anytime - the app throws an Exception - - ```python - # Use this method as a decorator - @app.error - async def on_error(context: TurnContext, err: Exception): - print(err.message) - ``` - """ - - logger.debug(f"Registering the error handler {func.__name__} ") - self._error = func - - if self._adapter: - logger.debug( - f"Registering for adapter {self._adapter.__class__.__name__} the error handler {func.__name__} " - ) - self._adapter.on_turn_error = func - - return func - - def turn_state_factory(self, func: Callable[[TurnContext], Awaitable[StateT]]): - """ - Custom Turn State Factory - """ - logger.debug(f"Setting custom turn state factory: {func.__name__}") - self._turn_state_factory = func - return func - - async def on_turn(self, context: TurnContext): - logger.debug( - f"AgentApplication.on_turn(): Processing turn for context: {context.activity.id}" - ) - await self._start_long_running_call(context, self._on_turn) - - async def _on_turn(self, context: TurnContext): - try: - await self._start_typing(context) - - self._remove_mentions(context) - - logger.debug("Initializing turn state") - turn_state = await self._initialize_state(context) - - sign_in_state = turn_state.get_value( - Authorization.SIGN_IN_STATE_KEY, target_cls=SignInState - ) - logger.debug( - f"Sign-in state: {sign_in_state} for context: {context.activity.id}" - ) - - if self._auth and sign_in_state and not sign_in_state.completed: - flow_state = self._auth.get_flow_state(sign_in_state.handler_id) - logger.debug("Flow state: %s", flow_state) - if flow_state.flow_started: - logger.debug("Continuing sign-in flow") - token_response = await self._auth.begin_or_continue_flow( - context, turn_state, sign_in_state.handler_id - ) - saved_activity = sign_in_state.continuation_activity.model_copy() - if token_response and token_response.token: - new_context = copy(context) - new_context.activity = saved_activity - logger.info( - "Resending continuation activity %s", saved_activity.text - ) - await self.on_turn(new_context) - turn_state.delete_value(Authorization.SIGN_IN_STATE_KEY) - await turn_state.save(context) - return - - logger.debug("Running before turn middleware") - if not await self._run_before_turn_middleware(context, turn_state): - return - - logger.debug("Running file downloads") - await self._handle_file_downloads(context, turn_state) - - logger.debug("Running activity handlers") - await self._on_activity(context, turn_state) - - logger.debug("Running after turn middleware") - if not await self._run_after_turn_middleware(context, turn_state): - await turn_state.save(context) - return - except ApplicationError as err: - logger.error( - f"An application error occurred in the AgentApplication: {err}", - exc_info=True, - ) - await self._on_error(context, err) - finally: - logger.debug("Stopping typing indicator") - self.typing.stop() - - async def _start_typing(self, context: TurnContext): - if self._options.start_typing_timer: - logger.debug("Starting typing indicator for context") - await self.typing.start(context) - - def _remove_mentions(self, context: TurnContext): - if ( - self.options.remove_recipient_mention - and context.activity.type == ActivityTypes.message - ): - context.activity.text = context.remove_recipient_mention(context.activity) - - @staticmethod - def parse_env_vars_configuration(vars: Dict[str, Any]) -> dict: - """ - Parses environment variables and returns a dictionary with the relevant configuration. - """ - result = {} - for key, value in vars.items(): - levels = key.split("__") - current_level = result - last_level = None - for next_level in levels: - if next_level not in current_level: - current_level[next_level] = {} - last_level = current_level - current_level = current_level[next_level] - logger.debug(f"Using environment variable '{key}'") - last_level[levels[-1]] = value - - return { - "AGENT_APPLICATION": result["AGENT_APPLICATION"], - "COPILOT_STUDIO_AGENT": result["COPILOT_STUDIO_AGENT"], - "CONNECTIONS": result["CONNECTIONS"], - "CONNECTIONS_MAP": result["CONNECTIONS_MAP"], - } - - async def _initialize_state(self, context: TurnContext) -> StateT: - if self._turn_state_factory: - logger.debug("Using custom turn state factory") - turn_state = self._turn_state_factory() - else: - logger.debug("Using default turn state factory") - turn_state = TurnState.with_storage(self._options.storage) - await turn_state.load(context, self._options.storage) - - turn_state = cast(StateT, turn_state) - - logger.debug("Loading turn state from storage") - await turn_state.load(context, self._options.storage) - turn_state.temp.input = context.activity.text - return turn_state - - async def _run_before_turn_middleware(self, context: TurnContext, state: StateT): - for before_turn in self._internal_before_turn: - is_ok = await before_turn(context, state) - if not is_ok: - await state.save(context, self._options.storage) - return False - return True - - async def _handle_file_downloads(self, context: TurnContext, state: StateT): - if self._options.file_downloaders and len(self._options.file_downloaders) > 0: - input_files = state.temp.input_files if state.temp.input_files else [] - for file_downloader in self._options.file_downloaders: - logger.info( - f"Using file downloader: {file_downloader.__class__.__name__}" - ) - files = await file_downloader.download_files(context) - input_files.extend(files) - state.temp.input_files = input_files - - def _contains_non_text_attachments(self, context: TurnContext): - non_text_attachments = filter( - lambda a: not a.content_type.startswith("text/html"), - context.activity.attachments, - ) - return len(list(non_text_attachments)) > 0 - - async def _run_after_turn_middleware(self, context: TurnContext, state: StateT): - for after_turn in self._internal_after_turn: - is_ok = await after_turn(context, state) - if not is_ok: - await state.save(context, self._options.storage) - return False - return True - - async def _on_activity(self, context: TurnContext, state: StateT): - for route in self._routes: - if route.selector(context): - if not route.auth_handlers: - await route.handler(context, state) - else: - sign_in_complete = False - for auth_handler_id in route.auth_handlers: - token_response = await self._auth.begin_or_continue_flow( - context, state, auth_handler_id - ) - sign_in_complete = token_response and token_response.token - if not sign_in_complete: - break - if sign_in_complete: - await route.handler(context, state) - return - logger.warning( - f"No route found for activity type: {context.activity.type} with text: {context.activity.text}" - ) - - async def _start_long_running_call( - self, context: TurnContext, func: Callable[[TurnContext], Awaitable] - ): - if ( - self._adapter - and ActivityTypes.message == context.activity.type - and self._options.long_running_messages - ): - logger.debug( - f"Starting long running call for context: {context.activity.id} with function: {func.__name__}" - ) - return await self._adapter.continue_conversation( - reference=context.get_conversation_reference(context.activity), - callback=func, - bot_app_id=self.options.bot_app_id, - ) - - return await func(context) - - async def _on_error(self, context: TurnContext, err: ApplicationError) -> None: - if self._error: - logger.info( - f"Calling error handler {self._error.__name__} for error: {err}" - ) - return await self._error(context, err) - - logger.error( - f"An error occurred in the AgentApplication: {err}", - exc_info=True, - ) - logger.error(err) - raise err +""" +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. +""" + +from __future__ import annotations +import logging +from copy import copy +from functools import partial + +import re +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Generic, + List, + Optional, + Pattern, + TypeVar, + Union, + cast, +) + +from microsoft.agents.hosting.core.authorization import Connections + +from microsoft.agents.hosting.core import Agent, TurnContext +from microsoft.agents.activity import ( + Activity, + ActivityTypes, + ActionTypes, + ConversationUpdateTypes, + MessageReactionTypes, + MessageUpdateTypes, + InvokeResponse, + TokenResponse, + OAuthCard, + Attachment, + CardAction, +) + +from .. import CardFactory, MessageFactory +from .app_error import ApplicationError +from .app_options import ApplicationOptions + +# from .auth import AuthManager, OAuth, OAuthOptions +from .route import Route, RouteHandler +from .state import TurnState +from ..channel_service_adapter import ChannelServiceAdapter +from ..oauth import ( + FlowResponse, + FlowState, + FlowStateTag, +) +from .oauth import Authorization +from .typing_indicator import TypingIndicator + +logger = logging.getLogger(__name__) + +StateT = TypeVar("StateT", bound=TurnState) +IN_SIGN_IN_KEY = "__InSignInFlow__" + + +class AgentApplication(Agent, Generic[StateT]): + """ + AgentApplication class for routing and processing incoming requests. + + The AgentApplication object replaces the traditional ActivityHandler that + a bot would use. It supports a simpler fluent style of authoring bots + versus the inheritance based approach used by the ActivityHandler class. + + Additionally, it has built-in support for calling into the SDK's AI system + and can be used to create bots that leverage Large Language Models (LLM) + and other AI capabilities. + """ + + typing: TypingIndicator + + _options: ApplicationOptions + _adapter: Optional[ChannelServiceAdapter] = None + _auth: Optional[Authorization] = None + _internal_before_turn: List[Callable[[TurnContext, StateT], Awaitable[bool]]] = [] + _internal_after_turn: List[Callable[[TurnContext, StateT], Awaitable[bool]]] = [] + _routes: List[Route[StateT]] = [] + _error: Optional[Callable[[TurnContext, Exception], Awaitable[None]]] = None + _turn_state_factory: Optional[Callable[[TurnContext], StateT]] = None + + def __init__( + self, + options: ApplicationOptions = None, + *, + connection_manager: Connections = None, + authorization: Authorization = None, + **kwargs, + ) -> None: + """ + Creates a new AgentApplication instance. + """ + self.typing = TypingIndicator() + self._routes = [] + + configuration = kwargs + + logger.debug(f"Initializing AgentApplication with options: {options}") + logger.debug( + f"Initializing AgentApplication with configuration: {configuration}" + ) + + if not options: + # TODO: consolidate configuration story + # Take the options from the kwargs and create an ApplicationOptions instance + option_kwargs = dict( + filter( + lambda x: x[0] in ApplicationOptions.__dataclass_fields__, + kwargs.items(), + ) + ) + options = ApplicationOptions(**option_kwargs) + + self._options = options + + if not self._options.storage: + logger.error( + "ApplicationOptions.storage is required and was not configured.", + stack_info=True, + ) + raise ApplicationError( + """ + The `ApplicationOptions.storage` property is required and was not configured. + """ + ) + + if options.long_running_messages and ( + not options.adapter or not options.bot_app_id + ): + logger.error( + "ApplicationOptions.long_running_messages requires an adapter and bot_app_id.", + stack_info=True, + ) + raise ApplicationError( + """ + The `ApplicationOptions.long_running_messages` property is unavailable because + no adapter or `bot_app_id` was configured. + """ + ) + + if options.adapter: + self._adapter = options.adapter + + self._turn_state_factory = ( + options.turn_state_factory + or kwargs.get("turn_state_factory", None) + or partial(TurnState.with_storage, self._options.storage) + ) + + # TODO: decide how to initialize the Authorization (params vs options vs kwargs) + if authorization: + self._auth = authorization + else: + auth_options = { + key: value + for key, value in configuration.items() + if key not in ["storage", "connection_manager", "handlers"] + } + self._auth = Authorization( + storage=self._options.storage, + connection_manager=connection_manager, + handlers=options.authorization_handlers, + **auth_options, + ) + + @property + def adapter(self) -> ChannelServiceAdapter: + """ + The bot's adapter. + """ + + if not self._adapter: + logger.error( + "AgentApplication.adapter(): self._adapter is not configured.", + stack_info=True, + ) + raise ApplicationError( + """ + The AgentApplication.adapter property is unavailable because it was + not configured when creating the AgentApplication. + """ + ) + + return self._adapter + + @property + def auth(self): + """ + The application's authentication manager + """ + if not self._auth: + logger.error( + "AgentApplication.auth(): self._auth is not configured.", + stack_info=True, + ) + raise ApplicationError( + """ + The `AgentApplication.auth` property is unavailable because + no Auth options were configured. + """ + ) + + return self._auth + + @property + def options(self) -> ApplicationOptions: + """ + The application's configured options. + """ + return self._options + + def activity( + self, + activity_type: Union[str, ActivityTypes, List[Union[str, ActivityTypes]]], + *, + auth_handlers: Optional[List[str]] = None, + ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: + """ + Registers a new activity event listener. This method can be used as either + a decorator or a method. + + ```python + # Use this method as a decorator + @app.activity("event") + async def on_event(context: TurnContext, state: TurnState): + print("hello world!") + return True + ``` + + #### Args: + - `type`: The type of the activity + """ + + def __selector(context: TurnContext): + return activity_type == context.activity.type + + def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: + logger.debug( + f"Registering activity handler for route handler {func.__name__} with type: {activity_type} with auth handlers: {auth_handlers}" + ) + self._routes.append( + Route[StateT](__selector, func, auth_handlers=auth_handlers) + ) + return func + + return __call + + def message( + self, + select: Union[str, Pattern[str], List[Union[str, Pattern[str]]]], + *, + auth_handlers: Optional[List[str]] = None, + ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: + """ + Registers a new message activity event listener. This method can be used as either + a decorator or a method. + + ```python + # Use this method as a decorator + @app.message("hi") + async def on_hi_message(context: TurnContext, state: TurnState): + print("hello!") + return True + + #### Args: + - `select`: a string or regex pattern + """ + + def __selector(context: TurnContext): + if context.activity.type != ActivityTypes.message: + return False + + text = context.activity.text if context.activity.text else "" + if isinstance(select, Pattern): + hits = re.fullmatch(select, text) + return hits is not None + + return text == select + + def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: + logger.debug( + f"Registering message handler for route handler {func.__name__} with select: {select} with auth handlers: {auth_handlers}" + ) + self._routes.append( + Route[StateT](__selector, func, auth_handlers=auth_handlers) + ) + return func + + return __call + + def conversation_update( + self, + type: ConversationUpdateTypes, + *, + auth_handlers: Optional[List[str]] = None, + ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: + """ + Registers a new message activity event listener. This method can be used as either + a decorator or a method. + + ```python + # Use this method as a decorator + @app.conversation_update("channelCreated") + async def on_channel_created(context: TurnContext, state: TurnState): + print("a new channel was created!") + return True + + ``` + + #### Args: + - `type`: a string or regex pattern + """ + + def __selector(context: TurnContext): + if context.activity.type != ActivityTypes.conversation_update: + return False + + if type == "membersAdded": + if isinstance(context.activity.members_added, List): + return len(context.activity.members_added) > 0 + return False + + if type == "membersRemoved": + if isinstance(context.activity.members_removed, List): + return len(context.activity.members_removed) > 0 + return False + + if isinstance(context.activity.channel_data, object): + data = vars(context.activity.channel_data) + return data["event_type"] == type + + return False + + def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: + logger.debug( + f"Registering conversation update handler for route handler {func.__name__} with type: {type} with auth handlers: {auth_handlers}" + ) + self._routes.append( + Route[StateT](__selector, func, auth_handlers=auth_handlers) + ) + return func + + return __call + + def message_reaction( + self, type: MessageReactionTypes, *, auth_handlers: Optional[List[str]] = None + ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: + """ + Registers a new message activity event listener. This method can be used as either + a decorator or a method. + + ```python + # Use this method as a decorator + @app.message_reaction("reactionsAdded") + async def on_reactions_added(context: TurnContext, state: TurnState): + print("reactions was added!") + return True + ``` + + #### Args: + - `type`: a string or regex pattern + """ + + def __selector(context: TurnContext): + if context.activity.type != ActivityTypes.message_reaction: + return False + + if type == "reactionsAdded": + if isinstance(context.activity.reactions_added, List): + return len(context.activity.reactions_added) > 0 + return False + + if type == "reactionsRemoved": + if isinstance(context.activity.reactions_removed, List): + return len(context.activity.reactions_removed) > 0 + return False + + return False + + def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: + logger.debug( + f"Registering message reaction handler for route handler {func.__name__} with type: {type} with auth handlers: {auth_handlers}" + ) + self._routes.append( + Route[StateT](__selector, func, auth_handlers=auth_handlers) + ) + return func + + return __call + + def message_update( + self, type: MessageUpdateTypes, *, auth_handlers: Optional[List[str]] = None + ) -> Callable[[RouteHandler[StateT]], RouteHandler[StateT]]: + """ + Registers a new message activity event listener. This method can be used as either + a decorator or a method. + + ```python + # Use this method as a decorator + @app.message_update("editMessage") + async def on_edit_message(context: TurnContext, state: TurnState): + print("message was edited!") + return True + ``` + + #### Args: + - `type`: a string or regex pattern + """ + + def __selector(context: TurnContext): + if type == "editMessage": + if ( + context.activity.type == ActivityTypes.message_update + and isinstance(context.activity.channel_data, dict) + ): + data = context.activity.channel_data + return data["event_type"] == type + return False + + if type == "softDeleteMessage": + if ( + context.activity.type == ActivityTypes.message_delete + and isinstance(context.activity.channel_data, dict) + ): + data = context.activity.channel_data + return data["event_type"] == type + return False + + if type == "undeleteMessage": + if ( + context.activity.type == ActivityTypes.message_update + and isinstance(context.activity.channel_data, dict) + ): + data = context.activity.channel_data + return data["event_type"] == type + return False + return False + + def __call(func: RouteHandler[StateT]) -> RouteHandler[StateT]: + logger.debug( + f"Registering message update handler for route handler {func.__name__} with type: {type} with auth handlers: {auth_handlers}" + ) + self._routes.append( + Route[StateT](__selector, func, auth_handlers=auth_handlers) + ) + return func + + return __call + + def handoff(self, *, auth_handlers: Optional[List[str]] = None) -> Callable[ + [Callable[[TurnContext, StateT, str], Awaitable[None]]], + Callable[[TurnContext, StateT, str], Awaitable[None]], + ]: + """ + Registers a handler to handoff conversations from one copilot to another. + ```python + # Use this method as a decorator + @app.handoff + async def on_handoff( + context: TurnContext, state: TurnState, continuation: str + ): + print(query) + ``` + """ + + def __selector(context: TurnContext) -> bool: + return ( + context.activity.type == ActivityTypes.invoke + and context.activity.name == "handoff/action" + ) + + def __call( + func: Callable[[TurnContext, StateT, str], Awaitable[None]], + ) -> Callable[[TurnContext, StateT, str], Awaitable[None]]: + async def __handler(context: TurnContext, state: StateT): + if not context.activity.value: + return False + await func(context, state, context.activity.value["continuation"]) + await context.send_activity( + Activity( + type=ActivityTypes.invoke_response, + value=InvokeResponse(status=200), + ) + ) + return True + + logger.debug( + f"Registering handoff handler for route handler {func.__name__} with auth handlers: {auth_handlers}" + ) + + self._routes.append( + Route[StateT](__selector, __handler, True, auth_handlers) + ) + self._routes = sorted(self._routes, key=lambda route: not route.is_invoke) + return func + + return __call + + def on_sign_in_success( + self, func: Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]] + ) -> Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]]: + """ + Registers a new event listener that will be executed when a user successfully signs in. + + ```python + # Use this method as a decorator + @app.on_sign_in_success + async def sign_in_success(context: TurnContext, state: TurnState): + print("hello world!") + return True + ``` + """ + + if self._auth: + logger.debug( + f"Registering sign-in success handler for route handler {func.__name__}" + ) + self._auth.on_sign_in_success(func) + else: + logger.error( + f"Failed to register sign-in success handler for route handler {func.__name__}", + stack_info=True, + ) + raise ApplicationError( + """ + The `AgentApplication.on_sign_in_success` method is unavailable because + no Auth options were configured. + """ + ) + return func + + def on_sign_in_failure( + self, func: Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]] + ) -> Callable[[TurnContext, StateT, Optional[str]], Awaitable[None]]: + """ + Registers a new event listener that will be executed when a user fails to sign in. + + ```python + # Use this method as a decorator + @app.on_sign_in_failure + async def sign_in_failure(context: TurnContext, state: TurnState): + print("hello world!") + return True + ``` + """ + + if self._auth: + logger.debug( + f"Registering sign-in failure handler for route handler {func.__name__}" + ) + self._auth.on_sign_in_failure(func) + else: + logger.error( + f"Failed to register sign-in failure handler for route handler {func.__name__}", + stack_info=True, + ) + raise ApplicationError( + """ + The `AgentApplication.on_sign_in_failure` method is unavailable because + no Auth options were configured. + """ + ) + return func + + def error( + self, func: Callable[[TurnContext, Exception], Awaitable[None]] + ) -> Callable[[TurnContext, Exception], Awaitable[None]]: + """ + Registers an error handler that will be called anytime + the app throws an Exception + + ```python + # Use this method as a decorator + @app.error + async def on_error(context: TurnContext, err: Exception): + print(err.message) + ``` + """ + + logger.debug(f"Registering the error handler {func.__name__} ") + self._error = func + + if self._adapter: + logger.debug( + f"Registering for adapter {self._adapter.__class__.__name__} the error handler {func.__name__} " + ) + self._adapter.on_turn_error = func + + return func + + def turn_state_factory(self, func: Callable[[TurnContext], Awaitable[StateT]]): + """ + Custom Turn State Factory + """ + logger.debug(f"Setting custom turn state factory: {func.__name__}") + self._turn_state_factory = func + return func + + async def _handle_flow_response( + self, context: TurnContext, flow_response: FlowResponse + ) -> None: + """Handles CONTINUE and FAILURE flow responses, sending activities back.""" + flow_state: FlowState = flow_response.flow_state + + if flow_state.tag == FlowStateTag.BEGIN: + # Create the OAuth card + sign_in_resource = flow_response.sign_in_resource + o_card: Attachment = CardFactory.oauth_card( + OAuthCard( + text="Sign in", + connection_name=flow_state.connection, + buttons=[ + CardAction( + title="Sign in", + type=ActionTypes.signin, + value=sign_in_resource.sign_in_link, + channel_data=None, + ) + ], + token_exchange_resource=sign_in_resource.token_exchange_resource, + token_post_resource=sign_in_resource.token_post_resource, + ) + ) + # Send the card to the user + await context.send_activity(MessageFactory.attachment(o_card)) + elif flow_state.tag == FlowStateTag.FAILURE: + if flow_state.reached_max_attempts(): + await context.send_activity( + MessageFactory.text( + "Sign-in failed. Max retries reached. Please try again later." + ) + ) + elif flow_state.is_expired(): + await context.send_activity( + MessageFactory.text("Sign-in session expired. Please try again.") + ) + else: + logger.warning("Sign-in flow failed for unknown reasons.") + await context.send_activity("Sign-in failed. Please try again.") + + async def _on_turn_auth_intercept( + self, context: TurnContext, turn_state: TurnState + ) -> bool: + """Intercepts the turn to check for active authentication flows.""" + logger.debug( + "Checking for active sign-in flow for context: %s with activity type %s", + context.activity.id, + context.activity.type, + ) + prev_flow_state = await self._auth.get_active_flow_state(context) + if prev_flow_state: + logger.debug( + "Previous flow state: %s", + { + "user_id": prev_flow_state.user_id, + "connection": prev_flow_state.connection, + "channel_id": prev_flow_state.channel_id, + "auth_handler_id": prev_flow_state.auth_handler_id, + "tag": prev_flow_state.tag, + "expiration": prev_flow_state.expiration, + }, + ) + # proceed if there is an existing flow to continue + # new flows should be initiated in _on_activity + # this can be reorganized later... but it works for now + if ( + prev_flow_state + and ( + prev_flow_state.tag == FlowStateTag.NOT_STARTED + or prev_flow_state.is_active() + ) + and context.activity.type in [ActivityTypes.message, ActivityTypes.invoke] + ): + + logger.debug("Sign-in flow is active for context: %s", context.activity.id) + + flow_response: FlowResponse = await self._auth.begin_or_continue_flow( + context, turn_state, prev_flow_state.auth_handler_id + ) + + await self._handle_flow_response(context, flow_response) + + new_flow_state: FlowState = flow_response.flow_state + token_response: TokenResponse = flow_response.token_response + saved_activity: Activity = new_flow_state.continuation_activity.model_copy() + + if token_response: + new_context = copy(context) + new_context.activity = saved_activity + logger.info("Resending continuation activity %s", saved_activity.text) + await self.on_turn(new_context) + await turn_state.save(context) + return True # early return from _on_turn + return False # continue _on_turn + + async def on_turn(self, context: TurnContext): + logger.debug( + f"AgentApplication.on_turn(): Processing turn for context: {context.activity.id}" + ) + await self._start_long_running_call(context, self._on_turn) + + async def _on_turn(self, context: TurnContext): + # robrandao: TODO + try: + await self._start_typing(context) + + self._remove_mentions(context) + + logger.debug("Initializing turn state") + turn_state = await self._initialize_state(context) + + if self._auth and await self._on_turn_auth_intercept(context, turn_state): + return + + logger.debug("Running before turn middleware") + if not await self._run_before_turn_middleware(context, turn_state): + return + + logger.debug("Running file downloads") + await self._handle_file_downloads(context, turn_state) + + logger.debug("Running activity handlers") + await self._on_activity(context, turn_state) + + logger.debug("Running after turn middleware") + if not await self._run_after_turn_middleware(context, turn_state): + await turn_state.save(context) + return + except ApplicationError as err: + logger.error( + f"An application error occurred in the AgentApplication: {err}", + exc_info=True, + ) + await self._on_error(context, err) + finally: + self.typing.stop() + + async def _start_typing(self, context: TurnContext): + if self._options.start_typing_timer: + await self.typing.start(context) + + def _remove_mentions(self, context: TurnContext): + if ( + self.options.remove_recipient_mention + and context.activity.type == ActivityTypes.message + ): + context.activity.text = context.remove_recipient_mention(context.activity) + + @staticmethod + def parse_env_vars_configuration(vars: Dict[str, Any]) -> dict: + """ + Parses environment variables and returns a dictionary with the relevant configuration. + """ + result = {} + for key, value in vars.items(): + levels = key.split("__") + current_level = result + last_level = None + for next_level in levels: + if next_level not in current_level: + current_level[next_level] = {} + last_level = current_level + current_level = current_level[next_level] + logger.debug(f"Using environment variable '{key}'") + last_level[levels[-1]] = value + + return { + "AGENT_APPLICATION": result["AGENT_APPLICATION"], + "COPILOT_STUDIO_AGENT": result["COPILOT_STUDIO_AGENT"], + "CONNECTIONS": result["CONNECTIONS"], + "CONNECTIONS_MAP": result["CONNECTIONS_MAP"], + } + + async def _initialize_state(self, context: TurnContext) -> StateT: + if self._turn_state_factory: + logger.debug("Using custom turn state factory") + turn_state = self._turn_state_factory() + else: + logger.debug("Using default turn state factory") + turn_state = TurnState.with_storage(self._options.storage) + await turn_state.load(context, self._options.storage) + + turn_state = cast(StateT, turn_state) + + logger.debug("Loading turn state from storage") + await turn_state.load(context, self._options.storage) + turn_state.temp.input = context.activity.text + return turn_state + + async def _run_before_turn_middleware(self, context: TurnContext, state: StateT): + for before_turn in self._internal_before_turn: + is_ok = await before_turn(context, state) + if not is_ok: + await state.save(context, self._options.storage) + return False + return True + + async def _handle_file_downloads(self, context: TurnContext, state: StateT): + if self._options.file_downloaders and len(self._options.file_downloaders) > 0: + input_files = state.temp.input_files if state.temp.input_files else [] + for file_downloader in self._options.file_downloaders: + logger.info( + f"Using file downloader: {file_downloader.__class__.__name__}" + ) + files = await file_downloader.download_files(context) + input_files.extend(files) + state.temp.input_files = input_files + + def _contains_non_text_attachments(self, context: TurnContext): + non_text_attachments = filter( + lambda a: not a.content_type.startswith("text/html"), + context.activity.attachments, + ) + return len(list(non_text_attachments)) > 0 + + async def _run_after_turn_middleware(self, context: TurnContext, state: StateT): + for after_turn in self._internal_after_turn: + is_ok = await after_turn(context, state) + if not is_ok: + await state.save(context, self._options.storage) + return False + return True + + async def _on_activity(self, context: TurnContext, state: StateT): + for route in self._routes: + if route.selector(context): + if not route.auth_handlers: + await route.handler(context, state) + else: + sign_in_complete = False + for auth_handler_id in route.auth_handlers: + logger.debug( + "Beginning or continuing flow for auth handler %s", + auth_handler_id, + ) + flow_response: FlowResponse = ( + await self._auth.begin_or_continue_flow( + context, state, auth_handler_id + ) + ) + await self._handle_flow_response(context, flow_response) + logger.debug( + "Flow response flow_state.tag: %s", + flow_response.flow_state.tag, + ) + sign_in_complete = ( + flow_response.flow_state.tag == FlowStateTag.COMPLETE + ) + if not sign_in_complete: + break + + if sign_in_complete: + await route.handler(context, state) + return + logger.warning( + f"No route found for activity type: {context.activity.type} with text: {context.activity.text}" + ) + + async def _start_long_running_call( + self, context: TurnContext, func: Callable[[TurnContext], Awaitable] + ): + if ( + self._adapter + and ActivityTypes.message == context.activity.type + and self._options.long_running_messages + ): + logger.debug( + f"Starting long running call for context: {context.activity.id} with function: {func.__name__}" + ) + return await self._adapter.continue_conversation( + reference=context.get_conversation_reference(context.activity), + callback=func, + bot_app_id=self.options.bot_app_id, + ) + + return await func(context) + + async def _on_error(self, context: TurnContext, err: ApplicationError) -> None: + if self._error: + logger.info( + f"Calling error handler {self._error.__name__} for error: {err}" + ) + return await self._error(context, err) + + logger.error( + f"An error occurred in the AgentApplication: {err}", + exc_info=True, + ) + logger.error(err) + raise err diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/app_options.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/app_options.py index c8d125cb..e0d871de 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/app_options.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/app_options.py @@ -9,7 +9,7 @@ from logging import Logger from typing import Callable, List, Optional -from microsoft.agents.hosting.core.app.oauth.authorization import AuthorizationHandlers +from microsoft.agents.hosting.core.app.oauth import AuthHandler from microsoft.agents.hosting.core.storage import Storage # from .auth import AuthOptions @@ -84,7 +84,7 @@ class ApplicationOptions: If not provided, the default `TurnState` will be used. """ - authorization_handlers: Optional[AuthorizationHandlers] = None + authorization_handlers: Optional[dict[str, AuthHandler]] = None """ Optional. Authorization handler for OAuth flows. If not provided, no OAuth flows will be supported. diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py index ff280c7f..7c962a43 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/__init__.py @@ -1,8 +1,8 @@ -from .authorization import ( - Authorization, - AuthorizationHandlers, - AuthHandler, - SignInState, -) +from .authorization import Authorization +from .auth_handler import AuthHandler, AuthorizationHandlers -__all__ = ["Authorization", "AuthorizationHandlers", "AuthHandler", "SignInState"] +__all__ = [ + "Authorization", + "AuthHandler", + "AuthorizationHandlers", +] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py new file mode 100644 index 00000000..ddde6e9a --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/auth_handler.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import logging +from typing import Dict + +logger = logging.getLogger(__name__) + + +class AuthHandler: + """ + Interface defining an authorization handler for OAuth flows. + """ + + def __init__( + self, + name: str = None, + title: str = None, + text: str = None, + abs_oauth_connection_name: str = None, + obo_connection_name: str = None, + **kwargs, + ): + """ + Initializes a new instance of AuthHandler. + + Args: + name: The name of the OAuth connection. + auto: Whether to automatically start the OAuth flow. + title: Title for the OAuth card. + text: Text for the OAuth button. + """ + self.name = name or kwargs.get("NAME") + self.title = title or kwargs.get("TITLE") + self.text = text or kwargs.get("TEXT") + self.abs_oauth_connection_name = abs_oauth_connection_name or kwargs.get( + "AZUREBOTOAUTHCONNECTIONNAME" + ) + self.obo_connection_name = obo_connection_name or kwargs.get( + "OBOCONNECTIONNAME" + ) + logger.debug( + f"AuthHandler initialized: name={self.name}, title={self.title}, text={self.text} abs_connection_name={self.abs_oauth_connection_name} obo_connection_name={self.obo_connection_name}" + ) + + +# # Type alias for authorization handlers dictionary +AuthorizationHandlers = Dict[str, AuthHandler] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py index b49ee253..145ac895 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/app/oauth/authorization.py @@ -1,431 +1,398 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -from __future__ import annotations -import logging -import jwt -from typing import Dict, Optional, Callable, Awaitable - -from microsoft.agents.hosting.core.authorization import ( - Connections, - AccessTokenProviderBase, -) -from microsoft.agents.hosting.core.storage import Storage -from microsoft.agents.activity import TokenResponse, Activity -from microsoft.agents.hosting.core.storage import StoreItem -from pydantic import BaseModel - -from ...turn_context import TurnContext -from ...app.state.turn_state import TurnState -from ...oauth_flow import OAuthFlow, FlowState -from ...state.user_state import UserState - -logger = logging.getLogger(__name__) - - -class SignInState(StoreItem, BaseModel): - """ - Interface defining the sign-in state for OAuth flows. - """ - - continuation_activity: Optional[Activity] = None - handler_id: Optional[str] = None - completed: Optional[bool] = False - - def store_item_to_json(self) -> dict: - return self.model_dump(exclude_unset=True) - - @staticmethod - def from_json_to_store_item(json_data: dict) -> "StoreItem": - return SignInState.model_validate(json_data) - - -class AuthHandler: - """ - Interface defining an authorization handler for OAuth flows. - """ - - def __init__( - self, - name: str = None, - title: str = None, - text: str = None, - abs_oauth_connection_name: str = None, - obo_connection_name: str = None, - **kwargs, - ): - """ - Initializes a new instance of AuthHandler. - - Args: - name: The name of the OAuth connection. - auto: Whether to automatically start the OAuth flow. - title: Title for the OAuth card. - text: Text for the OAuth button. - """ - self.name = name or kwargs.get("NAME") - self.title = title or kwargs.get("TITLE") - self.text = text or kwargs.get("TEXT") - self.abs_oauth_connection_name = abs_oauth_connection_name or kwargs.get( - "AZUREBOTOAUTHCONNECTIONNAME" - ) - self.obo_connection_name = obo_connection_name or kwargs.get( - "OBOCONNECTIONNAME" - ) - self.flow: OAuthFlow = None - logger.debug( - f"AuthHandler initialized: name={self.name}, title={self.title}, text={self.text} abs_connection_name={self.abs_oauth_connection_name} obo_connection_name={self.obo_connection_name}" - ) - - -# Type alias for authorization handlers dictionary -AuthorizationHandlers = Dict[str, AuthHandler] - - -class Authorization: - """ - Class responsible for managing authorization and OAuth flows. - """ - - SIGN_IN_STATE_KEY = f"{UserState.__name__}.__SIGNIN_STATE_" - - def __init__( - self, - storage: Storage, - connection_manager: Connections, - auth_handlers: AuthorizationHandlers = None, - auto_signin: bool = None, - **kwargs, - ): - """ - Creates a new instance of Authorization. - - Args: - storage: The storage system to use for state management. - auth_handlers: Configuration for OAuth providers. - - Raises: - ValueError: If storage is None or no auth handlers are provided. - """ - if storage is None: - logger.error("Storage is required for Authorization") - raise ValueError("Storage is required for Authorization") - - user_state = UserState(storage) - self._connection_manager = connection_manager - - auth_configuration: Dict = kwargs.get("AGENTAPPLICATION", {}).get( - "USERAUTHORIZATION", {} - ) - - self._auto_signin = ( - auto_signin - if auto_signin is not None - else auth_configuration.get("AUTOSIGNIN", False) - ) - - handlers_config: Dict[str, Dict] = auth_configuration.get("HANDLERS") - if not auth_handlers and handlers_config: - auth_handlers = { - handler_name: AuthHandler( - name=handler_name, **config.get("SETTINGS", {}) - ) - for handler_name, config in handlers_config.items() - } - - self._auth_handlers = auth_handlers or {} - self._sign_in_handler: Optional[ - Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] - ] = None - self._sign_in_failed_handler: Optional[ - Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] - ] = None - - # Configure each auth handler - for auth_handler in self._auth_handlers.values(): - # Create OAuth flow with configuration - messages_config = {} - if auth_handler.title: - messages_config["card_title"] = auth_handler.title - if auth_handler.text: - messages_config["button_text"] = auth_handler.text - - logger.debug(f"Configuring OAuth flow for handler: {auth_handler.name}") - auth_handler.flow = OAuthFlow( - storage=storage, - abs_oauth_connection_name=auth_handler.abs_oauth_connection_name, - messages_configuration=messages_config if messages_config else None, - ) - - async def get_token( - self, context: TurnContext, auth_handler_id: Optional[str] = None - ) -> TokenResponse: - """ - Gets the token for a specific auth handler. - - Args: - context: The context object for the current turn. - auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. - - Returns: - The token response from the OAuth provider. - """ - auth_handler = self.resolver_handler(auth_handler_id) - if auth_handler.flow is None: - logger.error("OAuth flow is not configured for the auth handler") - raise ValueError("OAuth flow is not configured for the auth handler") - - return await auth_handler.flow.get_user_token(context) - - async def exchange_token( - self, - context: TurnContext, - scopes: list[str], - auth_handler_id: Optional[str] = None, - ) -> TokenResponse: - """ - Exchanges a token for another token with different scopes. - - Args: - context: The context object for the current turn. - scopes: The scopes to request for the new token. - auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. - - Returns: - The token response from the OAuth provider. - """ - auth_handler = self.resolver_handler(auth_handler_id) - if not auth_handler.flow: - logger.error("OAuth flow is not configured for the auth handler") - raise ValueError("OAuth flow is not configured for the auth handler") - - token_response = await auth_handler.flow.get_user_token(context) - - if self._is_exchangeable(token_response.token if token_response else None): - return await self._handle_obo(token_response.token, scopes, auth_handler_id) - - return token_response - - def _is_exchangeable(self, token: Optional[str]) -> bool: - """ - Checks if a token is exchangeable (has api:// audience). - - Args: - token: The token to check. - - Returns: - True if the token is exchangeable, False otherwise. - """ - if not token: - return False - - try: - # Decode without verification to check the audience - payload = jwt.decode(token, options={"verify_signature": False}) - aud = payload.get("aud") - return isinstance(aud, str) and aud.startswith("api://") - except Exception: - logger.exception("Failed to decode token to check audience") - return False - - async def _handle_obo( - self, token: str, scopes: list[str], handler_id: str = None - ) -> TokenResponse: - """ - Handles On-Behalf-Of token exchange. - - Args: - context: The context object for the current turn. - token: The original token. - scopes: The scopes to request. - - Returns: - The new token response. - """ - if not self._connection_manager: - logger.error("Connection manager is not configured", stack_info=True) - raise ValueError("Connection manager is not configured") - - auth_handler = self.resolver_handler(handler_id) - if auth_handler.flow is None: - logger.error("OAuth flow is not configured for the auth handler") - raise ValueError("OAuth flow is not configured for the auth handler") - - # Use the flow's OBO method to exchange the token - token_provider: AccessTokenProviderBase = ( - self._connection_manager.get_connection(auth_handler.obo_connection_name) - ) - logger.info("Attempting to exchange token on behalf of user") - token = await token_provider.aquire_token_on_behalf_of( - scopes=scopes, - user_assertion=token, - ) - return TokenResponse( - token=token, - scopes=scopes, # Expiration can be set based on the token provider's response - ) - - def get_flow_state(self, auth_handler_id: Optional[str] = None) -> FlowState: - """ - Gets the current state of the OAuth flow. - - Args: - auth_handler_id: Optional ID of the auth handler to check, defaults to first handler. - - Returns: - The flow state object. - """ - flow = self.resolver_handler(auth_handler_id).flow - if flow is None: - # Return a default FlowState if no flow is configured - return FlowState() - - # Return flow state if available - return flow.flow_state or FlowState() - - async def begin_or_continue_flow( - self, - context: TurnContext, - turn_state: TurnState, - auth_handler_id: str, - sec_route: bool = True, - ) -> TokenResponse: - """ - Begins or continues an OAuth flow. - - Args: - context: The context object for the current turn. - state: The state object for the current turn. - auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. - - Returns: - The token response from the OAuth provider. - """ - auth_handler = self.resolver_handler(auth_handler_id) - # Get or initialize sign-in state - sign_in_state = turn_state.get_value( - self.SIGN_IN_STATE_KEY, target_cls=SignInState - ) - if sign_in_state is None: - sign_in_state = SignInState( - continuation_activity=None, handler_id=None, completed=False - ) - - flow = auth_handler.flow - if flow is None: - logger.error("OAuth flow is not configured for the auth handler") - raise ValueError("OAuth flow is not configured for the auth handler") - - logger.info( - "Beginning or continuing OAuth flow for handler: %s", auth_handler_id - ) - token_response = await flow.get_user_token(context) - if token_response and token_response.token: - logger.debug("Token obtained successfully") - return token_response - - # Get the current flow state - flow_state = await flow._get_flow_state(context) - - if not flow_state.flow_started: - logger.info("Starting new OAuth flow for handler: %s", auth_handler_id) - token_response = await flow.begin_flow(context) - if sec_route: - sign_in_state.continuation_activity = context.activity - sign_in_state.handler_id = auth_handler_id - turn_state.set_value(self.SIGN_IN_STATE_KEY, sign_in_state) - else: - logger.info( - "Continuing existing OAuth flow for handler: %s", auth_handler_id - ) - token_response = await flow.continue_flow(context) - # Check if sign-in was successful and call handler if configured - if token_response and token_response.token: - if self._sign_in_handler: - logger.info("Sign-in successful, calling sign-in handler") - await self._sign_in_handler(context, turn_state, auth_handler_id) - if sec_route: - turn_state.delete_value(self.SIGN_IN_STATE_KEY) - else: - if self._sign_in_failed_handler: - logger.warning( - "Sign-in failed, calling sign-in failed handler", - stack_info=True, - ) - await self._sign_in_failed_handler( - context, turn_state, auth_handler_id - ) - - await turn_state.save(context) - return token_response - - def resolver_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: - """ - Resolves the auth handler to use based on the provided ID. - - Args: - auth_handler_id: Optional ID of the auth handler to resolve, defaults to first handler. - - Returns: - The resolved auth handler. - """ - if auth_handler_id: - if auth_handler_id not in self._auth_handlers: - logger.error(f"Auth handler '{auth_handler_id}' not found") - raise ValueError(f"Auth handler '{auth_handler_id}' not found") - return self._auth_handlers[auth_handler_id] - - # Return the first handler if no ID specified - first_key = next(iter(self._auth_handlers)) - return self._auth_handlers[first_key] - - async def sign_out( - self, - context: TurnContext, - state: TurnState, - auth_handler_id: Optional[str] = None, - ) -> None: - """ - Signs out the current user. - This method clears the user's token and resets the OAuth state. - - Args: - context: The context object for the current turn. - state: The state object for the current turn. - auth_handler_id: Optional ID of the auth handler to use for sign out. - """ - if auth_handler_id is None: - # Sign out from all handlers - for handler_key, auth_handler in self._auth_handlers.items(): - if auth_handler.flow: - logger.info(f"Signing out from handler: {handler_key}") - await auth_handler.flow.sign_out(context) - else: - # Sign out from specific handler - auth_handler = self.resolver_handler(auth_handler_id) - if auth_handler.flow: - logger.info(f"Signing out from handler: {auth_handler_id}") - await auth_handler.flow.sign_out(context) - - def on_sign_in_success( - self, - handler: Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]], - ) -> None: - """ - Sets a handler to be called when sign-in is successfully completed. - - Args: - handler: The handler function to call on successful sign-in. - """ - self._sign_in_handler = handler - - def on_sign_in_failure( - self, - handler: Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]], - ) -> None: - """ - Sets a handler to be called when sign-in fails. - Args: - handler: The handler function to call on sign-in failure. - """ - self._sign_in_failed_handler = handler +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from __future__ import annotations +import logging +import jwt +from typing import Dict, Optional, Callable, Awaitable, AsyncIterator +from collections.abc import Iterable +from contextlib import asynccontextmanager + +from microsoft.agents.hosting.core.authorization import ( + Connections, + AccessTokenProviderBase, +) +from microsoft.agents.hosting.core.storage import Storage, MemoryStorage +from microsoft.agents.activity import TokenResponse +from microsoft.agents.hosting.core.connector.client import UserTokenClient + +from ...turn_context import TurnContext +from ...oauth import OAuthFlow, FlowResponse, FlowState, FlowStateTag, FlowStorageClient +from ..state.turn_state import TurnState +from .auth_handler import AuthHandler + +logger = logging.getLogger(__name__) + + +class Authorization: + """ + Class responsible for managing authorization and OAuth flows. + Handles multiple OAuth providers and manages the complete authentication lifecycle. + """ + + def __init__( + self, + storage: Storage, + connection_manager: Connections, + auth_handlers: dict[str, AuthHandler] = None, + auto_signin: bool = None, + use_cache: bool = False, + **kwargs, + ): + """ + Creates a new instance of Authorization. + + Args: + storage: The storage system to use for state management. + auth_handlers: Configuration for OAuth providers. + + Raises: + ValueError: If storage is None or no auth handlers are provided. + """ + if not storage: + raise ValueError("Storage is required for Authorization") + + self._storage = storage + self._connection_manager = connection_manager + + auth_configuration: Dict = kwargs.get("AGENTAPPLICATION", {}).get( + "USERAUTHORIZATION", {} + ) + + handlers_config: Dict[str, Dict] = auth_configuration.get("HANDLERS") + if not auth_handlers and handlers_config: + auth_handlers = { + handler_name: AuthHandler( + name=handler_name, **config.get("SETTINGS", {}) + ) + for handler_name, config in handlers_config.items() + } + + self._auth_handlers = auth_handlers or {} + self._sign_in_success_handler: Optional[ + Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] + ] = lambda *args: None + self._sign_in_failure_handler: Optional[ + Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] + ] = lambda *args: None + + def _ids_from_context(self, context: TurnContext) -> tuple[str, str]: + """Checks and returns IDs necessary to load a new or existing flow. + + Raises a ValueError if channel ID or user ID are missing. + """ + if ( + not context.activity.channel_id + or not context.activity.from_property + or not context.activity.from_property.id + ): + raise ValueError("Channel ID and User ID are required") + + return context.activity.channel_id, context.activity.from_property.id + + async def _load_flow( + self, context: TurnContext, auth_handler_id: str = "" + ) -> tuple[OAuthFlow, FlowStorageClient]: + """Loads the OAuth flow for a specific auth handler. + + Args: + context: The context object for the current turn. + auth_handler_id: The ID of the auth handler to use. + + Returns: + The OAuthFlow returned corresponds to the flow associated with the + chosen handler, and the channel and user info found in the context. + The FlowStorageClient corresponds to the same channel and user info. + """ + user_token_client: UserTokenClient = context.turn_state.get( + context.adapter.USER_TOKEN_CLIENT_KEY + ) + + # resolve handler id + auth_handler: AuthHandler = self.resolve_handler(auth_handler_id) + auth_handler_id = auth_handler.name + + channel_id, user_id = self._ids_from_context(context) + + ms_app_id = context.turn_state.get(context.adapter.AGENT_IDENTITY_KEY).claims[ + "aud" + ] + + # try to load existing state + flow_storage_client = FlowStorageClient(channel_id, user_id, self._storage) + logger.info("Loading OAuth flow state from storage") + flow_state: FlowState = await flow_storage_client.read(auth_handler_id) + + if not flow_state: + logger.info("No existing flow state found, creating new flow state") + flow_state = FlowState( + channel_id=channel_id, + user_id=user_id, + auth_handler_id=auth_handler_id, + connection=auth_handler.abs_oauth_connection_name, + ms_app_id=ms_app_id, + ) + await flow_storage_client.write(flow_state) + + flow = OAuthFlow(flow_state, user_token_client) + return flow, flow_storage_client + + @asynccontextmanager + async def open_flow( + self, context: TurnContext, auth_handler_id: str = "" + ) -> AsyncIterator[OAuthFlow]: + """Loads an OAuth flow and saves changes the changes to storage if any are made. + + Args: + context: The context object for the current turn. + auth_handler_id: ID of the auth handler to use. + If none provided, uses the first handler. + + Yields: + OAuthFlow: + The OAuthFlow instance loaded from storage or newly created + if not yet present in storage. + """ + if not context: + logger.error("No context provided to open_flow") + raise ValueError("context is required") + + flow, flow_storage_client = await self._load_flow(context, auth_handler_id) + yield flow + logger.info("Saving OAuth flow state to storage") + await flow_storage_client.write(flow.flow_state) + + async def get_token( + self, context: TurnContext, auth_handler_id: str + ) -> TokenResponse: + """ + Gets the token for a specific auth handler. + + Args: + context: The context object for the current turn. + auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. + + Returns: + The token response from the OAuth provider. + """ + logger.info("Getting token for auth handler: %s", auth_handler_id) + async with self.open_flow(context, auth_handler_id) as flow: + return await flow.get_user_token() + + async def exchange_token( + self, + context: TurnContext, + scopes: list[str], + auth_handler_id: Optional[str] = None, + ) -> TokenResponse: + """ + Exchanges a token for another token with different scopes. + + Args: + context: The context object for the current turn. + scopes: The scopes to request for the new token. + auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. + + Returns: + The token response from the OAuth provider. + """ + logger.info("Exchanging token for scopes: %s", scopes) + async with self.open_flow(context, auth_handler_id) as flow: + token_response = await flow.get_user_token() + + if token_response and self._is_exchangeable(token_response.token): + logger.debug("Token is exchangeable, performing OBO flow") + return await self._handle_obo(token_response.token, scopes, auth_handler_id) + + return TokenResponse() + + def _is_exchangeable(self, token: str) -> bool: + """ + Checks if a token is exchangeable (has api:// audience). + + Args: + token: The token to check. + + Returns: + True if the token is exchangeable, False otherwise. + """ + try: + # Decode without verification to check the audience + payload = jwt.decode(token, options={"verify_signature": False}) + aud = payload.get("aud") + return isinstance(aud, str) and aud.startswith("api://") + except Exception: + logger.error("Failed to decode token to check audience") + return False + + async def _handle_obo( + self, token: str, scopes: list[str], handler_id: str = None + ) -> TokenResponse: + """ + Handles On-Behalf-Of token exchange. + + Args: + context: The context object for the current turn. + token: The original token. + scopes: The scopes to request. + + Returns: + The new token response. + + """ + auth_handler = self.resolve_handler(handler_id) + token_provider: AccessTokenProviderBase = ( + self._connection_manager.get_connection(auth_handler.obo_connection_name) + ) + + logger.info("Attempting to exchange token on behalf of user") + new_token = await token_provider.aquire_token_on_behalf_of( + scopes=scopes, + user_assertion=token, + ) + return TokenResponse( + token=new_token, + scopes=scopes, # Expiration can be set based on the token provider's response + ) + + async def get_active_flow_state(self, context: TurnContext) -> Optional[FlowState]: + """Gets the first active flow state for the current context.""" + logger.debug("Getting active flow state") + channel_id, user_id = self._ids_from_context(context) + flow_storage_client = FlowStorageClient(channel_id, user_id, self._storage) + for auth_handler_id in self._auth_handlers.keys(): + flow_state = await flow_storage_client.read(auth_handler_id) + if flow_state and flow_state.is_active(): + return flow_state + return None + + async def begin_or_continue_flow( + self, + context: TurnContext, + turn_state: TurnState, + auth_handler_id: str = "", + ) -> FlowResponse: + """Begins or continues an OAuth flow. + + Args: + context: The context object for the current turn. + turn_state: The state object for the current turn. + auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. + + Returns: + The token response from the OAuth provider. + + """ + if not auth_handler_id: + auth_handler_id = self.resolve_handler().name + + logger.debug("Beginning or continuing OAuth flow") + async with self.open_flow(context, auth_handler_id) as flow: + prev_tag = flow.flow_state.tag + flow_response: FlowResponse = await flow.begin_or_continue_flow( + context.activity + ) + + flow_state: FlowState = flow_response.flow_state + + if ( + flow_state.tag == FlowStateTag.COMPLETE + and prev_tag != FlowStateTag.COMPLETE + ): + logger.debug("Calling Authorization sign in success handler") + self._sign_in_success_handler( + context, turn_state, flow_state.auth_handler_id + ) + elif flow_state.tag == FlowStateTag.FAILURE: + logger.debug("Calling Authorization sign in failure handler") + self._sign_in_failure_handler( + context, + turn_state, + flow_state.auth_handler_id, + flow_response.flow_error_tag, + ) + + return flow_response + + def resolve_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: + """Resolves the auth handler to use based on the provided ID. + + Args: + auth_handler_id: Optional ID of the auth handler to resolve, defaults to first handler. + + Returns: + The resolved auth handler. + """ + if auth_handler_id: + if auth_handler_id not in self._auth_handlers: + logger.error("Auth handler '%s' not found", auth_handler_id) + raise ValueError(f"Auth handler '{auth_handler_id}' not found") + return self._auth_handlers[auth_handler_id] + + # Return the first handler if no ID specified + return next(iter(self._auth_handlers.values())) + + async def _sign_out( + self, + context: TurnContext, + auth_handler_ids: Iterable[str], + ) -> None: + """Signs out from the specified auth handlers. + + Args: + context: The context object for the current turn. + auth_handler_ids: Iterable of auth handler IDs to sign out from. + + Deletes the associated flow states from storage. + """ + for auth_handler_id in auth_handler_ids: + flow, flow_storage_client = await self._load_flow(context, auth_handler_id) + # ensure that the id is valid + self.resolve_handler(auth_handler_id) + logger.info("Signing out from handler: %s", auth_handler_id) + await flow.sign_out() + await flow_storage_client.delete(auth_handler_id) + + async def sign_out( + self, + context: TurnContext, + auth_handler_id: Optional[str] = None, + ) -> None: + """ + Signs out the current user. + This method clears the user's token and resets the OAuth state. + + Args: + context: The context object for the current turn. + auth_handler_id: Optional ID of the auth handler to use for sign out. If None, + signs out from all the handlers. + + Deletes the associated flow state(s) from storage. + """ + if auth_handler_id: + await self._sign_out(context, [auth_handler_id]) + else: + await self._sign_out(context, self._auth_handlers.keys()) + + def on_sign_in_success( + self, + handler: Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]], + ) -> None: + """ + Sets a handler to be called when sign-in is successfully completed. + + Args: + handler: The handler function to call on successful sign-in. + """ + self._sign_in_success_handler = handler + + def on_sign_in_failure( + self, + handler: Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]], + ) -> None: + """ + Sets a handler to be called when sign-in fails. + Args: + handler: The handler function to call on sign-in failure. + """ + self._sign_in_failure_handler = handler diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py new file mode 100644 index 00000000..79858343 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/__init__.py @@ -0,0 +1,12 @@ +from .flow_state import FlowState, FlowStateTag, FlowErrorTag +from .flow_storage_client import FlowStorageClient +from .oauth_flow import OAuthFlow, FlowResponse + +__all__ = [ + "FlowState", + "FlowStateTag", + "FlowErrorTag", + "FlowResponse", + "FlowStorageClient", + "OAuthFlow", +] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py new file mode 100644 index 00000000..1ac105df --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_state.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from datetime import datetime +from enum import Enum +from typing import Optional + +from pydantic import BaseModel + +from microsoft.agents.activity import Activity + +from ..storage import StoreItem + + +class FlowStateTag(Enum): + """Represents the top-level state of an OAuthFlow + + For instance, a flow can arrive at an error, but its + broader state may still be CONTINUE if the flow can + still progress + """ + + BEGIN = "begin" + CONTINUE = "continue" + NOT_STARTED = "not_started" + FAILURE = "failure" + COMPLETE = "complete" + + +class FlowErrorTag(Enum): + """Represents the various error states that can occur during an OAuthFlow""" + + NONE = "none" + MAGIC_FORMAT = "magic_format" + MAGIC_CODE_INCORRECT = "magic_code_incorrect" + OTHER = "other" + + +class FlowState(BaseModel, StoreItem): + """Represents the state of an OAuthFlow""" + + user_token: str = "" + + channel_id: str = "" + user_id: str = "" + ms_app_id: str = "" + connection: str = "" + auth_handler_id: str = "" + + expiration: float = 0 + continuation_activity: Optional[Activity] = None + attempts_remaining: int = 0 + tag: FlowStateTag = FlowStateTag.NOT_STARTED + + def store_item_to_json(self) -> dict: + return self.model_dump(mode="json", exclude_unset=True, by_alias=True) + + @staticmethod + def from_json_to_store_item(json_data: dict) -> "FlowState": + return FlowState.model_validate(json_data) + + def is_expired(self) -> bool: + return datetime.now().timestamp() >= self.expiration + + def reached_max_attempts(self) -> bool: + return self.attempts_remaining <= 0 + + def is_active(self) -> bool: + return ( + not self.is_expired() + and not self.reached_max_attempts() + and self.tag in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE] + ) + + def refresh(self): + if ( + self.tag + in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE, FlowStateTag.COMPLETE] + and self.is_expired() + ): + self.tag = FlowStateTag.NOT_STARTED diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py new file mode 100644 index 00000000..7ab03879 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/flow_storage_client.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from typing import Optional + +from ..storage import Storage +from .flow_state import FlowState + + +class DummyCache(Storage): + + async def read(self, keys: list[str], **kwargs) -> dict[str, FlowState]: + return {} + + async def write(self, changes: dict[str, FlowState]) -> None: + pass + + async def delete(self, keys: list[str]) -> None: + pass + + +# this could be generalized. Ideas: +# - CachedStorage class for two-tier storage +# - Namespaced/PrefixedStorage class for namespacing keying +# not generally thread or async safe (operations are not atomic) +class FlowStorageClient: + """Wrapper around Storage that manages sign-in state specific to each user and channel. + + Uses the activity's channel_id and from.id to create a key prefix for storage operations. + """ + + def __init__( + self, + channel_id: str, + user_id: str, + storage: Storage, + cache_class: type[Storage] = None, + ): + """ + Args: + channel_id: used to create the prefix + user_id: used to create the prefix + storage: the backing storage + cache_class: the cache class to use (defaults to DummyCache, which performs no caching). + This cache's lifetime is tied to the FlowStorageClient instance. + """ + + if not user_id or not channel_id: + raise ValueError( + "FlowStorageClient.__init__(): channel_id and user_id must be set." + ) + + self._base_key = f"auth/{channel_id}/{user_id}/" + self._storage = storage + if cache_class is None: + cache_class = DummyCache + self._cache = cache_class() + + @property + def base_key(self) -> str: + """Returns the prefix used for flow state storage isolation.""" + return self._base_key + + def key(self, auth_handler_id: str) -> str: + """Creates a storage key for a specific sign-in handler.""" + return f"{self._base_key}{auth_handler_id}" + + async def read(self, auth_handler_id: str) -> Optional[FlowState]: + """Reads the flow state for a specific authentication handler.""" + key: str = self.key(auth_handler_id) + data = await self._cache.read([key], target_cls=FlowState) + if key not in data: + data = await self._storage.read([key], target_cls=FlowState) + if key not in data: + return None + await self._cache.write({key: data[key]}) + return FlowState.model_validate(data.get(key)) + + async def write(self, value: FlowState) -> None: + """Saves the flow state for a specific authentication handler.""" + key: str = self.key(value.auth_handler_id) + cached_state = await self._cache.read([key], target_cls=FlowState) + if not cached_state or cached_state != value: + await self._cache.write({key: value}) + await self._storage.write({key: value}) + + async def delete(self, auth_handler_id: str) -> None: + """Deletes the flow state for a specific authentication handler.""" + key: str = self.key(auth_handler_id) + cached_state = await self._cache.read([key], target_cls=FlowState) + if cached_state: + await self._cache.delete([key]) + await self._storage.delete([key]) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py new file mode 100644 index 00000000..b5f0d5bb --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth/oauth_flow.py @@ -0,0 +1,341 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from __future__ import annotations + +import logging + +from pydantic import BaseModel +from datetime import datetime +from typing import Optional + +from microsoft.agents.activity import ( + Activity, + ActivityTypes, + TokenExchangeState, + TokenResponse, + SignInResource, +) + +from ..connector.client import UserTokenClient +from .flow_state import FlowState, FlowStateTag, FlowErrorTag + +logger = logging.getLogger(__name__) + + +class FlowResponse(BaseModel): + """Represents the response for a flow operation.""" + + flow_state: FlowState = FlowState() + flow_error_tag: FlowErrorTag = FlowErrorTag.NONE + token_response: Optional[TokenResponse] = None + sign_in_resource: Optional[SignInResource] = None + continuation_activity: Optional[Activity] = None + + +class OAuthFlow: + """ + Manages the OAuth flow. + + This class is responsible for managing the entire OAuth flow, including + obtaining user tokens, signing out users, and handling token exchanges. + + Contract with other classes (usage of other classes is enforced in unit tests): + TurnContext.activity.channel_id + TurnContext.activity.from_property.id + + UserTokenClient: user_token.get_token(), user_token.sign_out() + """ + + def __init__( + self, flow_state: FlowState, user_token_client: UserTokenClient, **kwargs + ): + """ + Arguments: + flow_state: The state of the flow. + user_token_client: The user token client to use for token operations. + + Keyword Arguments: + flow_duration: The duration of the flow in milliseconds (default: 60000). + max_attempts: The maximum number of attempts for the flow + set when starting a flow (default: 3). + """ + if not flow_state or not user_token_client: + raise ValueError( + "OAuthFlow.__init__(): flow_state and user_token_client are required" + ) + + if ( + not flow_state.connection + or not flow_state.ms_app_id + or not flow_state.channel_id + or not flow_state.user_id + ): + raise ValueError( + "OAuthFlow.__init__: flow_state must have ms_app_id, channel_id, user_id, connection defined" + ) + + logger.debug("Initializing OAuthFlow with flow state: %s", flow_state) + + self._flow_state = flow_state.model_copy() + + self._abs_oauth_connection_name = self._flow_state.connection + self._ms_app_id = self._flow_state.ms_app_id + self._channel_id = self._flow_state.channel_id + self._user_id = self._flow_state.user_id + + self._user_token_client = user_token_client + + self._default_flow_duration = kwargs.get( + "default_flow_duration", 10 * 60 + ) # default to 10 minutes + self._max_attempts = kwargs.get("max_attempts", 3) # defaults to 3 max attempts + + logger.debug( + "OAuthFlow initialized with connection: %s, ms_app_id: %s, channel_id: %s, user_id: %s", + self._abs_oauth_connection_name, + self._ms_app_id, + self._channel_id, + self._user_id, + ) + logger.debug( + "Default flow duration: %d ms, Max attempts: %d", + self._default_flow_duration, + self._max_attempts, + ) + + @property + def flow_state(self) -> FlowState: + return self._flow_state.model_copy() + + async def get_user_token(self, magic_code: str = None) -> TokenResponse: + """Get the user token based on the context. + + Args: + magic_code (str, optional): Defaults to None. The magic code for user authentication. + + Returns: + TokenResponse + The user token response. + + Notes: + flow_state.user_token is updated with the latest token. + """ + logger.info( + "Getting user token for user_id: %s, connection: %s", + self._user_id, + self._abs_oauth_connection_name, + ) + token_response: TokenResponse = ( + await self._user_token_client.user_token.get_token( + user_id=self._user_id, + connection_name=self._abs_oauth_connection_name, + channel_id=self._channel_id, + code=magic_code, + ) + ) + if token_response: + logger.info("User token obtained successfully: %s", token_response) + self._flow_state.user_token = token_response.token + self._flow_state.expiration = ( + datetime.now().timestamp() + self._default_flow_duration + ) + + return token_response + + async def sign_out(self) -> None: + """Sign out the user. + + Sets the flow state tag to NOT_STARTED + Resets the flow state user_token field + """ + logger.info( + "Signing out user_id: %s from connection: %s", + self._user_id, + self._abs_oauth_connection_name, + ) + await self._user_token_client.user_token.sign_out( + user_id=self._user_id, + connection_name=self._abs_oauth_connection_name, + channel_id=self._channel_id, + ) + self._flow_state.user_token = "" + self._flow_state.tag = FlowStateTag.NOT_STARTED + + def _use_attempt(self) -> None: + """Decrements the remaining attempts for the flow, checking for failure.""" + self._flow_state.attempts_remaining -= 1 + if self._flow_state.attempts_remaining <= 0: + self._flow_state.tag = FlowStateTag.FAILURE + logger.debug( + "Using an attempt for the OAuth flow. Attempts remaining after use: %d", + self._flow_state.attempts_remaining, + ) + + async def begin_flow(self, activity: Activity) -> FlowResponse: + """Begins the OAuthFlow. + + Args: + activity: The activity that initiated the flow. + + Returns: + The response containing the flow state and sign-in resource if applicable. + + Notes: + The flow state is reset if a token is not obtained from cache. + """ + token_response = await self.get_user_token() + if token_response: + return FlowResponse( + flow_state=self._flow_state, token_response=token_response + ) + + logger.debug("Starting new OAuth flow") + self._flow_state.tag = FlowStateTag.BEGIN + self._flow_state.expiration = ( + datetime.now().timestamp() + self._default_flow_duration + ) + + self._flow_state.attempts_remaining = self._max_attempts + self._flow_state.user_token = "" + self._flow_state.continuation_activity = activity.model_copy() + + token_exchange_state = TokenExchangeState( + connection_name=self._abs_oauth_connection_name, + conversation=activity.get_conversation_reference(), + relates_to=activity.relates_to, + ms_app_id=self._ms_app_id, + ) + + sign_in_resource = ( + await self._user_token_client.agent_sign_in.get_sign_in_resource( + state=token_exchange_state.get_encoded_state() + ) + ) + + logger.debug("Sign-in resource obtained successfully: %s", sign_in_resource) + + return FlowResponse( + flow_state=self._flow_state, sign_in_resource=sign_in_resource + ) + + async def _continue_from_message( + self, activity: Activity + ) -> tuple[TokenResponse, FlowErrorTag]: + """Handles the continuation of the flow from a message activity.""" + magic_code: str = activity.text + if magic_code and magic_code.isdigit() and len(magic_code) == 6: + token_response: TokenResponse = await self.get_user_token(magic_code) + + if token_response: + return token_response, FlowErrorTag.NONE + else: + return token_response, FlowErrorTag.MAGIC_CODE_INCORRECT + else: + return TokenResponse(), FlowErrorTag.MAGIC_FORMAT + + async def _continue_from_invoke_verify_state( + self, activity: Activity + ) -> TokenResponse: + """Handles the continuation of the flow from an invoke activity for verifying state.""" + token_verify_state = activity.value + magic_code: str = token_verify_state.get("state") + token_response: TokenResponse = await self.get_user_token(magic_code) + return token_response + + async def _continue_from_invoke_token_exchange( + self, activity: Activity + ) -> TokenResponse: + """Handles the continuation of the flow from an invoke activity for token exchange.""" + token_exchange_request = activity.value + token_response = await self._user_token_client.user_token.exchange_token( + user_id=self._user_id, + connection_name=self._abs_oauth_connection_name, + channel_id=self._channel_id, + body=token_exchange_request, + ) + return token_response + + async def continue_flow(self, activity: Activity) -> FlowResponse: + """Continues the OAuth flow based on the incoming activity. + + Args: + activity: The incoming activity to continue the flow with. + + Returns: + A FlowResponse object containing the updated flow state and any token response. + + """ + logger.debug("Continuing auth flow...") + + if not self._flow_state.is_active(): + logger.debug("OAuth flow is not active, cannot continue") + self._flow_state.tag = FlowStateTag.FAILURE + return FlowResponse( + flow_state=self._flow_state.model_copy(), token_response=None + ) + + flow_error_tag = FlowErrorTag.NONE + if activity.type == ActivityTypes.message: + token_response, flow_error_tag = await self._continue_from_message(activity) + elif ( + activity.type == ActivityTypes.invoke + and activity.name == "signin/verifyState" + ): + token_response = await self._continue_from_invoke_verify_state(activity) + elif ( + activity.type == ActivityTypes.invoke + and activity.name == "signin/tokenExchange" + ): + token_response = await self._continue_from_invoke_token_exchange(activity) + else: + raise ValueError(f"Unknown activity type {activity.type}") + + if not token_response and flow_error_tag == FlowErrorTag.NONE: + flow_error_tag = FlowErrorTag.OTHER + + if flow_error_tag != FlowErrorTag.NONE: + logger.debug("Flow error occurred: %s", flow_error_tag) + self._flow_state.tag = FlowStateTag.CONTINUE + self._use_attempt() + else: + self._flow_state.tag = FlowStateTag.COMPLETE + self._flow_state.expiration = ( + datetime.now().timestamp() + self._default_flow_duration + ) + self._flow_state.user_token = token_response.token + logger.debug( + "OAuth flow completed successfully, got TokenResponse: %s", + token_response, + ) + + return FlowResponse( + flow_state=self._flow_state.model_copy(), + flow_error_tag=flow_error_tag, + token_response=token_response, + continuation_activity=self._flow_state.continuation_activity, + ) + + async def begin_or_continue_flow(self, activity: Activity) -> FlowResponse: + """Begins a new OAuth flow or continues an existing one based on the activity. + + Args: + activity: The incoming activity to begin or continue the flow with. + + Returns: + A FlowResponse object containing the updated flow state and any token response. + """ + self._flow_state.refresh() + if self._flow_state.tag == FlowStateTag.COMPLETE: # robrandao: TODO -> test + logger.debug("OAuth flow has already been completed, nothing to do") + return FlowResponse( + flow_state=self._flow_state.model_copy(), + token_response=TokenResponse(token=self._flow_state.user_token), + ) + + if self._flow_state.is_active(): + logger.debug("Active flow, continuing...") + return await self.continue_flow(activity) + + logger.debug("No active flow, beginning new flow...") + return await self.begin_flow(activity) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth_flow.py deleted file mode 100644 index b54a4f75..00000000 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/oauth_flow.py +++ /dev/null @@ -1,397 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -from __future__ import annotations - -from datetime import datetime -from typing import Dict, Optional - -from microsoft.agents.hosting.core.connector.client import UserTokenClient -from microsoft.agents.activity import ( - ActionTypes, - ActivityTypes, - CardAction, - Attachment, - OAuthCard, - TokenExchangeState, - TokenResponse, - Activity, -) -from microsoft.agents.activity import ( - TurnContextProtocol as TurnContext, -) -from microsoft.agents.hosting.core.storage import StoreItem, Storage -from pydantic import BaseModel - -from .message_factory import MessageFactory -from .card_factory import CardFactory - - -class FlowState(StoreItem, BaseModel): - flow_started: bool = False - user_token: str = "" - flow_expires: float = 0 - abs_oauth_connection_name: Optional[str] = None - continuation_activity: Optional[Activity] = None - - def store_item_to_json(self) -> dict: - return self.model_dump() - - @staticmethod - def from_json_to_store_item(json_data: dict) -> "StoreItem": - return FlowState.model_validate(json_data) - - -class OAuthFlow: - """ - Manages the OAuth flow. - """ - - def __init__( - self, - storage: Storage, - abs_oauth_connection_name: str, - user_token_client: Optional[UserTokenClient] = None, - messages_configuration: dict[str, str] = None, - **kwargs, - ): - """ - Creates a new instance of OAuthFlow. - - Args: - user_state: The user state. - abs_oauth_connection_name: The OAuth connection name. - user_token_client: Optional user token client. - messages_configuration: Optional messages configuration for backward compatibility. - """ - if not abs_oauth_connection_name: - raise ValueError( - "OAuthFlow.__init__: connectionName expected but not found" - ) - - # Handle backward compatibility with messages_configuration - self.messages_configuration = messages_configuration or {} - - # Initialize properties - self.abs_oauth_connection_name = abs_oauth_connection_name - self.user_token_client = user_token_client - self.token_exchange_id: Optional[str] = None - - # Initialize state and flow state - self._storage = storage - self.flow_state = None - - async def get_user_token(self, context: TurnContext) -> TokenResponse: - """ - Retrieves the user token from the user token service. - - Args: - context: The turn context containing the activity information. - - Returns: - The user token response. - - Raises: - ValueError: If the channelId or from properties are not set in the activity. - """ - await self._initialize_token_client(context) - - if not context.activity.from_property: - raise ValueError("User ID is not set in the activity.") - - if not context.activity.channel_id: - raise ValueError("Channel ID is not set in the activity.") - - return await self.user_token_client.user_token.get_token( - user_id=context.activity.from_property.id, - connection_name=self.abs_oauth_connection_name, - channel_id=context.activity.channel_id, - ) - - async def begin_flow(self, context: TurnContext) -> TokenResponse: - """ - Begins the OAuth flow. - - Args: - context: The turn context. - - Returns: - A TokenResponse object. - """ - self.flow_state = FlowState() - - if not self.abs_oauth_connection_name: - raise ValueError("connectionName is not set") - - await self._initialize_token_client(context) - - activity = context.activity - - # Try to get existing token first - user_token = await self.user_token_client.user_token.get_token( - user_id=activity.from_property.id, - connection_name=self.abs_oauth_connection_name, - channel_id=activity.channel_id, - ) - - if user_token and user_token.token: - # Already have token, return it - self.flow_state.flow_started = False - self.flow_state.flow_expires = 0 - self.flow_state.abs_oauth_connection_name = self.abs_oauth_connection_name - await self._save_flow_state(context) - return user_token - - # No token, need to start sign-in flow - token_exchange_state = TokenExchangeState( - connection_name=self.abs_oauth_connection_name, - conversation=activity.get_conversation_reference(), - relates_to=activity.relates_to, - ms_app_id=context.turn_state.get(context.adapter.AGENT_IDENTITY_KEY).claims[ - "aud" - ], - ) - - sign_in_resource = ( - await self.user_token_client.agent_sign_in.get_sign_in_resource( - state=token_exchange_state.get_encoded_state(), - ) - ) - - # Create the OAuth card - o_card: Attachment = CardFactory.oauth_card( - OAuthCard( - text=self.messages_configuration.get("card_title", "Sign in"), - connection_name=self.abs_oauth_connection_name, - buttons=[ - CardAction( - title=self.messages_configuration.get("button_text", "Sign in"), - type=ActionTypes.signin, - value=sign_in_resource.sign_in_link, - channel_data=None, - ) - ], - token_exchange_resource=sign_in_resource.token_exchange_resource, - token_post_resource=sign_in_resource.token_post_resource, - ) - ) - - # Send the card to the user - await context.send_activity(MessageFactory.attachment(o_card)) - - # Update flow state - self.flow_state.flow_started = True - self.flow_state.flow_expires = datetime.now().timestamp() + 30000 - self.flow_state.abs_oauth_connection_name = self.abs_oauth_connection_name - await self._save_flow_state(context) - - # Return in-progress response - return TokenResponse() - - async def continue_flow(self, context: TurnContext) -> TokenResponse: - """ - Continues the OAuth flow. - - Args: - context: The turn context. - - Returns: - A TokenResponse object. - """ - await self._initialize_token_client(context) - - if ( - self.flow_state - and self.flow_state.flow_expires != 0 - and datetime.now().timestamp() > self.flow_state.flow_expires - ): - await context.send_activity( - MessageFactory.text( - self.messages_configuration.get( - "session_expired_messages", - "Sign-in session expired. Please try again.", - ) - ) - ) - return TokenResponse() - - cont_flow_activity = context.activity - - # Handle message type activities (typically when the user enters a code) - if cont_flow_activity.type == ActivityTypes.message: - magic_code = cont_flow_activity.text - - # Validate magic code format (6 digits) - if magic_code and magic_code.isdigit() and len(magic_code) == 6: - result = await self.user_token_client.user_token.get_token( - user_id=cont_flow_activity.from_property.id, - connection_name=self.abs_oauth_connection_name, - channel_id=cont_flow_activity.channel_id, - code=magic_code, - ) - - if result and result.token: - self.flow_state.flow_started = False - self.flow_state.flow_expires = 0 - self.flow_state.abs_oauth_connection_name = ( - self.abs_oauth_connection_name - ) - await self._save_flow_state(context) - return result - else: - await context.send_activity( - MessageFactory.text("Invalid code. Please try again.") - ) - self.flow_state.flow_started = True - self.flow_state.flow_expires = datetime.now().timestamp() + 30000 - await self._save_flow_state(context) - return TokenResponse() - else: - await context.send_activity( - MessageFactory.text( - "Invalid code format. Please enter a 6-digit code." - ) - ) - return TokenResponse() - - # Handle verify state invoke activity - if ( - cont_flow_activity.type == ActivityTypes.invoke - and cont_flow_activity.name == "signin/verifyState" - ): - token_verify_state = cont_flow_activity.value - magic_code = token_verify_state.get("state") - - result = await self.user_token_client.user_token.get_token( - user_id=cont_flow_activity.from_property.id, - connection_name=self.abs_oauth_connection_name, - channel_id=cont_flow_activity.channel_id, - code=magic_code, - ) - - if result and result.token: - self.flow_state.flow_started = False - self.flow_state.abs_oauth_connection_name = ( - self.abs_oauth_connection_name - ) - await self._save_flow_state(context) - return result - return TokenResponse() - - # Handle token exchange invoke activity - if ( - cont_flow_activity.type == ActivityTypes.invoke - and cont_flow_activity.name == "signin/tokenExchange" - ): - token_exchange_request = cont_flow_activity.value - - # Dedupe checks to prevent duplicate processing - token_exchange_id = token_exchange_request.get("id") - if self.token_exchange_id == token_exchange_id: - # Already processed this request - return TokenResponse() - - # Store this request ID - self.token_exchange_id = token_exchange_id - - # Exchange the token - user_token_resp = await self.user_token_client.user_token.exchange_token( - user_id=cont_flow_activity.from_property.id, - connection_name=self.abs_oauth_connection_name, - channel_id=cont_flow_activity.channel_id, - body=token_exchange_request, - ) - - if user_token_resp and user_token_resp.token: - self.flow_state.flow_started = False - await self._save_flow_state(context) - return user_token_resp - else: - self.flow_state.flow_started = True - return TokenResponse() - - return TokenResponse() - - async def sign_out(self, context: TurnContext) -> None: - """ - Signs the user out. - - Args: - context: The turn context. - """ - await self._initialize_token_client(context) - - await self.user_token_client.user_token.sign_out( - user_id=context.activity.from_property.id, - connection_name=self.abs_oauth_connection_name, - channel_id=context.activity.channel_id, - ) - - if self.flow_state: - self.flow_state.flow_expires = 0 - await self._save_flow_state(context) - - async def _get_flow_state(self, context: TurnContext) -> FlowState: - """ - Gets the user state. - - Args: - context: The turn context. - - Returns: - The user state. - """ - storage_key = self._get_storage_key(context) - - storage_result: Dict[str, FlowState] | None = await self._storage.read( - [storage_key], target_cls=FlowState - ) - if not storage_result or storage_key not in storage_result: - return FlowState() - return storage_result[storage_key] - - async def _save_flow_state(self, context: TurnContext) -> None: - """ - Saves the flow state to the user state. - Args: - context: The turn context. - """ - await self._storage.write({self._get_storage_key(context): self.flow_state}) - - async def _initialize_token_client(self, context: TurnContext) -> None: - """ - Initializes the user token client if not already set. - - Args: - context: The turn context. - """ - - # TODO: Change this to caching when the story is implemented, for now we're getting it from TurnContext (new with every request) - self.user_token_client = context.turn_state.get( - context.adapter.USER_TOKEN_CLIENT_KEY - ) - - def _get_storage_key(self, context: TurnContext) -> str: - """ - Gets the storage key for the flow state. - - Args: - context: The turn context. - - Returns: - The storage key. - """ - channel_id = context.activity.channel_id - if not channel_id: - raise ValueError("Channel ID is not set in the activity.") - user_id = ( - context.activity.from_property.id - if context.activity.from_property - else None - ) - if not user_id: - raise ValueError("User ID is not set in the activity.") - - return ( - f"oauth/{self.abs_oauth_connection_name}/{channel_id}/{user_id}/flowState" - ) diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage_test_utils.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage_test_utils.py index e095cbd6..280f63e7 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage_test_utils.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage_test_utils.py @@ -129,6 +129,7 @@ async def equals(self, other) -> bool: for key in self._key_history: if key not in self._memory: if len(await other.read([key], target_cls=MockStoreItem)) > 0: + breakpoint() return False # key should not exist in other continue @@ -138,6 +139,7 @@ async def equals(self, other) -> bool: res = await other.read([key], target_cls=target_cls) if key not in res or item != res[key]: + breakpoint() return False return True diff --git a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py index e7ee8b69..d80f2a9a 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_authorization.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_authorization.py @@ -1,213 +1,607 @@ -import pytest -from .tools.testing_authorization import ( - TestingAuthorization, - create_test_auth_handler, -) -from .tools.testing_utility import TestingUtility -import jwt -from unittest.mock import Mock, AsyncMock -from microsoft.agents.hosting.core import SignInState -from microsoft.agents.hosting.core.oauth_flow import FlowState - - -class TestAuthorization: - def setup_method(self): - self.turn_context = TestingUtility.create_empty_context() - - @pytest.mark.asyncio - async def test_get_token_single_handler(self): - """ - Test Authorization - get_token() with single Auth Handler - """ - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - } - ) - - token_res = await auth.get_token(self.turn_context) - auth_handler = auth.resolver_handler("auth-handler") - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" - - @pytest.mark.asyncio - async def test_get_token_multiple_handlers(self): - """ - Test Authorization - get_token() with multiple Auth Handlers - """ - auth_handlers = { - "auth-handler": create_test_auth_handler("test-auth-a"), - "auth-handler-obo": create_test_auth_handler("test-auth-b", obo=True), - "auth-handler-with-title": create_test_auth_handler( - "test-auth-c", title="test-title" - ), - "auth-handler-with-title-text": create_test_auth_handler( - "test-auth-d", title="test-title", text="test-text" - ), - } - auth = TestingAuthorization(auth_handlers=auth_handlers) - for id, auth_handler in auth_handlers.items(): - # test value propogation - token_res = await auth.get_token(self.turn_context, id) - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" - - @pytest.mark.asyncio - async def test_exchange_token_valid_token(self): - valid_token = jwt.encode({"aud": "api://botframework.test.api"}, "") - scopes = ["scope-a"] - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth", obo=True), - }, - token=valid_token, - ) - token_res = await auth.exchange_token(self.turn_context, scopes=scopes) - assert ( - token_res.token - == f"{auth.resolver_handler().obo_connection_name}-obo-token" - ) - - @pytest.mark.asyncio - async def test_exchange_token_invalid_token(self): - invalid_token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") - scopes = ["scope-a"] - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth"), - }, - token=invalid_token, - ) - token_res = await auth.exchange_token(self.turn_context, scopes=scopes) - assert token_res.token == invalid_token - - @pytest.mark.asyncio - async def test_get_flow_state_unavailable(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - } - ) - - assert auth.get_flow_state() == FlowState() - - @pytest.mark.asyncio - async def test_begin_or_continue_flow_not_started(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - }, - token=None, - ) - mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) - - token_res = await auth.begin_or_continue_flow( - self.turn_context, - mock_turn_state, - "auth-handler", - ) - # Test value propogation - auth_handler = auth.resolver_handler("auth-handler") - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" - - # Test function calls - auth_handler.flow._get_flow_state.assert_called_once() - auth_handler.flow.begin_flow.assert_called_once() - mock_turn_state.save.assert_called_once_with(self.turn_context) - mock_turn_state.set_value.assert_called_once_with( - auth.SIGN_IN_STATE_KEY, - SignInState( - continuation_activity=self.turn_context.activity, - handler_id="auth-handler", - ), - ) - - @pytest.mark.asyncio - async def test_begin_or_continue_flow_started(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - }, - token=None, - flow_started=True, - ) - mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) - token_res = await auth.begin_or_continue_flow( - self.turn_context, - mock_turn_state, - "auth-handler", - ) - - # Test value propogation - auth_handler = auth.resolver_handler("auth-handler") - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" - - # Test function calls - auth_handler.flow._get_flow_state.assert_called_once() - auth_handler.flow.continue_flow.assert_called_once() - mock_turn_state.save.assert_called_once_with(self.turn_context) - mock_turn_state.delete_value.assert_called_once_with(auth.SIGN_IN_STATE_KEY) - - @pytest.mark.asyncio - async def test_begin_or_continue_flow_started_sign_in_success(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - }, - token=None, - flow_started=True, - ) - mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) - auth.on_sign_in_success(AsyncMock()) - - token_res = await auth.begin_or_continue_flow( - self.turn_context, - mock_turn_state, - "auth-handler", - ) - - # Test value propogation - auth_handler = auth.resolver_handler("auth-handler") - assert token_res.connection_name == auth_handler.abs_oauth_connection_name - assert token_res.token == f"{auth_handler.abs_oauth_connection_name}-token" - - # Test function calls - auth_handler.flow._get_flow_state.assert_called_once() - auth_handler.flow.continue_flow.assert_called_once() - mock_turn_state.save.assert_called_once_with(self.turn_context) - mock_turn_state.delete_value.assert_called_once_with(auth.SIGN_IN_STATE_KEY) - auth._sign_in_handler.assert_called_once_with( - self.turn_context, mock_turn_state, "auth-handler" - ) - - @pytest.mark.asyncio - async def test_begin_or_continue_flow_started_sign_in_failure(self): - auth = TestingAuthorization( - auth_handlers={ - "auth-handler": create_test_auth_handler("test-auth-a"), - }, - token=None, - sign_in_failed=True, - ) - mock_turn_state = AsyncMock(get_value=Mock(return_value=SignInState())) - auth.on_sign_in_failure(AsyncMock()) - - token_res = await auth.begin_or_continue_flow( - self.turn_context, - mock_turn_state, - "auth-handler", - ) - - # Test value propogation - auth_handler = auth.resolver_handler("auth-handler") - assert not token_res - - # Test function calls - auth_handler.flow._get_flow_state.assert_called_once() - auth_handler.flow.continue_flow.assert_called_once() - mock_turn_state.save.assert_called_once_with(self.turn_context) - auth._sign_in_failed_handler.assert_called_once_with( - self.turn_context, mock_turn_state, "auth-handler" - ) +import pytest + +import jwt + +from microsoft.agents.activity import ActivityTypes, TokenResponse +from microsoft.agents.hosting.core import MemoryStorage +from microsoft.agents.hosting.core.storage.storage_test_utils import StorageBaseline +from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase +from microsoft.agents.hosting.core.connector.user_token_client_base import ( + UserTokenClientBase, +) + +from microsoft.agents.hosting.core.app.oauth import Authorization +from microsoft.agents.hosting.core.oauth import ( + FlowStorageClient, + FlowErrorTag, + FlowStateTag, + FlowResponse, + OAuthFlow, +) + +# test constants +from .tools.testing_oauth import * +from .tools.testing_authorization import ( + TestingConnectionManager as MockConnectionManager, + create_test_auth_handler, +) + + +class TestUtils: + + def create_context( + self, + mocker, + channel_id="__channel_id", + user_id="__user_id", + user_token_client=None, + ): + + if not user_token_client: + user_token_client = self.create_mock_user_token_client(mocker) + + turn_context = mocker.Mock() + turn_context.activity.channel_id = channel_id + turn_context.activity.from_property.id = user_id + turn_context.activity.type = ActivityTypes.message + turn_context.adapter.USER_TOKEN_CLIENT_KEY = "__user_token_client" + turn_context.adapter.AGENT_IDENTITY_KEY = "__agent_identity_key" + agent_identity = mocker.Mock() + agent_identity.claims = {"aud": MS_APP_ID} + turn_context.turn_state = { + "__user_token_client": user_token_client, + "__agent_identity_key": agent_identity, + } + return turn_context + + def create_mock_user_token_client( + self, + mocker, + token=RES_TOKEN, + ): + mock_user_token_client_class = mocker.Mock(spec=UserTokenClientBase) + mock_user_token_client_class.user_token = mocker.Mock(spec=UserTokenBase) + mock_user_token_client_class.user_token.get_token = mocker.AsyncMock( + return_value=TokenResponse() if not token else TokenResponse(token=token) + ) + mock_user_token_client_class.user_token.sign_out = mocker.AsyncMock() + return mock_user_token_client_class + + @pytest.fixture + def mock_user_token_client_class(self, mocker): + return self.create_mock_user_token_client(mocker) + + def create_mock_oauth_flow_class(self, mocker, token_response): + mock_oauth_flow_class = mocker.Mock(spec=OAuthFlow) + # mock_oauth_flow_class.get_user_token = mocker.AsyncMock(return_value=token_response) + # mock_oauth_flow_class.sign_out = mocker.AsyncMock() + mocker.patch.object(OAuthFlow, "get_user_token", return_value=token_response) + mocker.patch.object(OAuthFlow, "sign_out") + return mock_oauth_flow_class + + @pytest.fixture + def mock_oauth_flow_class(self, mocker): + return self.create_mock_oauth_flow_class(mocker, TokenResponse(token=RES_TOKEN)) + # mock_flow_class = mocker.Mock(spec=OAuthFlow) + + # # mocker.patch.object(OAuthFlow, "__init__", return_value=mock_flow_class) + # mock_flow_class.get_user_token = mocker.AsyncMock(return_value=TokenResponse(token=RES_TOKEN)) + # mock_flow_class.sign_out = mocker.AsyncMock() + # mocker.patch.object(OAuthFlow, "get_user_token") + + # return mock_flow_class + + @pytest.fixture + def turn_context(self, mocker): + return self.create_context(mocker, "__channel_id", "__user_id", "__connection") + + def create_user_token_client(self, mocker, get_token_return=""): + + user_token_client = mocker.Mock(spec=UserTokenClientBase) + user_token_client.user_token = mocker.Mock(spec=UserTokenBase) + user_token_client.user_token.get_token = mocker.AsyncMock() + user_token_client.user_token.sign_out = mocker.AsyncMock() + + return_value = TokenResponse() + if isinstance(get_token_return, TokenResponse): + return_value = get_token_return + elif get_token_return: + return_value = TokenResponse(token=get_token_return) + user_token_client.user_token.get_token.return_value = return_value + + return user_token_client + + @pytest.fixture + def user_token_client(self, mocker): + return self.create_user_token_client(mocker, get_token_return=RES_TOKEN) + + @pytest.fixture + def auth_handlers(self): + handlers = {} + for key in STORAGE_INIT_DATA().keys(): + if key.startswith("auth/"): + auth_handler_name = key[key.rindex("/") + 1 :] + handlers[auth_handler_name] = create_test_auth_handler( + auth_handler_name, True + ) + return handlers + + @pytest.fixture + def connection_manager(self): + return MockConnectionManager() + + @pytest.fixture + def auth(self, connection_manager, storage, auth_handlers): + return Authorization(storage, connection_manager, auth_handlers) + + +class TestAuthorizationUtils(TestUtils): + + def create_storage(self): + return MemoryStorage(STORAGE_INIT_DATA()) + + @pytest.fixture + def storage(self): + return self.create_storage() + + @pytest.fixture + def baseline_storage(self): + return StorageBaseline(STORAGE_INIT_DATA()) + + def patch_flow( + self, + mocker, + flow_response=None, + token=None, + ): + mocker.patch.object( + OAuthFlow, "get_user_token", return_value=TokenResponse(token=token) + ) + mocker.patch.object(OAuthFlow, "sign_out") + mocker.patch.object( + OAuthFlow, "begin_or_continue_flow", return_value=flow_response + ) + + +class TestAuthorization(TestAuthorizationUtils): + + def test_init_configuration_variants( + self, storage, connection_manager, auth_handlers + ): + """Test initialization of authorization with different configuration variants.""" + AGENTAPPLICATION = { + "USERAUTHORIZATION": { + "HANDLERS": { + handler_name: { + "SETTINGS": { + "title": handler.title, + "text": handler.text, + "abs_oauth_connection_name": handler.abs_oauth_connection_name, + "obo_connection_name": handler.obo_connection_name, + } + } + for handler_name, handler in auth_handlers.items() + } + } + } + auth_with_config_obj = Authorization( + storage, + connection_manager, + auth_handlers=None, + AGENTAPPLICATION=AGENTAPPLICATION, + ) + auth_with_handlers_list = Authorization( + storage, connection_manager, auth_handlers=auth_handlers + ) + for auth_handler_name in auth_handlers.keys(): + auth_handler_a = auth_with_config_obj.resolve_handler(auth_handler_name) + auth_handler_b = auth_with_handlers_list.resolve_handler(auth_handler_name) + + assert auth_handler_a.name == auth_handler_b.name + assert auth_handler_a.title == auth_handler_b.title + assert auth_handler_a.text == auth_handler_b.text + assert ( + auth_handler_a.abs_oauth_connection_name + == auth_handler_b.abs_oauth_connection_name + ) + assert ( + auth_handler_a.obo_connection_name == auth_handler_b.obo_connection_name + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "auth_handler_id, channel_id, user_id", + [["missing", "webchat", "Alice"], ["handler", "teams", "Bob"]], + ) + async def test_open_flow_value_error( + self, mocker, auth, auth_handler_id, channel_id, user_id + ): + """Test opening a flow with a missing auth handler.""" + context = self.create_context(mocker, channel_id, user_id) + with pytest.raises(ValueError): + async with auth.open_flow(context, auth_handler_id): + pass + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "auth_handler_id, channel_id, user_id", + [ + ["", "webchat", "Alice"], + ["graph", "teams", "Bob"], + ["slack", "webchat", "Chuck"], + ], + ) + async def test_open_flow_readonly( + self, + mocker, + storage, + connection_manager, + auth_handlers, + auth_handler_id, + channel_id, + user_id, + ): + """Test opening a flow and not modifying it.""" + # setup + context = self.create_context(mocker, channel_id, user_id) + auth = Authorization(storage, connection_manager, auth_handlers) + flow_storage_client = FlowStorageClient(channel_id, user_id, storage) + + # test + async with auth.open_flow(context, auth_handler_id) as flow: + expected_flow_state = flow.flow_state + + # verify + actual_flow_state = await flow_storage_client.read( + auth.resolve_handler(auth_handler_id).name + ) + assert actual_flow_state == expected_flow_state + + @pytest.mark.asyncio + async def test_open_flow_success_modified_complete_flow( + self, + mocker, + storage, + connection_manager, + mock_user_token_client_class, + auth_handlers, + ): + # setup + channel_id = "teams" + user_id = "Alice" + auth_handler_id = "graph" + + self.create_user_token_client(mocker, get_token_return=RES_TOKEN) + + context = self.create_context(mocker, channel_id, user_id) + context.activity.type = ActivityTypes.message + context.activity.text = "123456" + + auth = Authorization(storage, connection_manager, auth_handlers) + flow_storage_client = FlowStorageClient(channel_id, user_id, storage) + + # test + async with auth.open_flow(context, auth_handler_id) as flow: + expected_flow_state = flow.flow_state + expected_flow_state.tag = FlowStateTag.COMPLETE + expected_flow_state.user_token = RES_TOKEN + + flow_response = await flow.begin_or_continue_flow(context.activity) + res_flow_state = flow_response.flow_state + + # verify + actual_flow_state = await flow_storage_client.read(auth_handler_id) + expected_flow_state.expiration = ( + res_flow_state.expiration + ) # we won't check this for now + + assert res_flow_state == expected_flow_state + assert actual_flow_state == expected_flow_state + + @pytest.mark.asyncio + async def test_open_flow_success_modified_failure( + self, + mocker, + storage, + connection_manager, + auth_handlers, + ): + # setup + channel_id = "teams" + user_id = "Bob" + auth_handler_id = "slack" + + context = self.create_context(mocker, channel_id, user_id) + context.activity.text = "invalid_magic_code" + + auth = Authorization(storage, connection_manager, auth_handlers) + flow_storage_client = FlowStorageClient(channel_id, user_id, storage) + + # test + async with auth.open_flow(context, auth_handler_id) as flow: + expected_flow_state = flow.flow_state + expected_flow_state.tag = FlowStateTag.FAILURE + expected_flow_state.attempts_remaining = 0 + + flow_response = await flow.begin_or_continue_flow(context.activity) + res_flow_state = flow_response.flow_state + + # verify + actual_flow_state = await flow_storage_client.read(auth_handler_id) + expected_flow_state.expiration = ( + actual_flow_state.expiration + ) # we won't check this for now + + assert flow_response.flow_error_tag == FlowErrorTag.MAGIC_FORMAT + assert res_flow_state == expected_flow_state + assert actual_flow_state == expected_flow_state + + @pytest.mark.asyncio + async def test_open_flow_success_modified_signout( + self, mocker, storage, connection_manager, auth_handlers + ): + # setup + channel_id = "webchat" + user_id = "Alice" + auth_handler_id = "graph" + + context = self.create_context(mocker, channel_id, user_id) + + auth = Authorization(storage, connection_manager, auth_handlers) + flow_storage_client = FlowStorageClient(channel_id, user_id, storage) + + # test + async with auth.open_flow(context, auth_handler_id) as flow: + expected_flow_state = flow.flow_state + expected_flow_state.tag = FlowStateTag.NOT_STARTED + expected_flow_state.user_token = "" + + await flow.sign_out() + + # verify + actual_flow_state = await flow_storage_client.read(auth_handler_id) + expected_flow_state.expiration = ( + actual_flow_state.expiration + ) # we won't check this for now + assert actual_flow_state == expected_flow_state + + @pytest.mark.asyncio + async def test_get_token_success(self, mocker, auth): + mock_user_token_client_class = self.create_user_token_client( + mocker, get_token_return=TokenResponse(token="token") + ) + context = self.create_context( + mocker, + "__channel_id", + "__user_id", + user_token_client=mock_user_token_client_class, + ) + assert await auth.get_token(context, "slack") == TokenResponse(token="token") + mock_user_token_client_class.user_token.get_token.assert_called_once() + + @pytest.mark.asyncio + async def test_get_token_empty_response(self, mocker, auth): + mock_user_token_client_class = self.create_user_token_client( + mocker, get_token_return=TokenResponse() + ) + context = self.create_context( + mocker, + "__channel_id", + "__user_id", + user_token_client=mock_user_token_client_class, + ) + assert await auth.get_token(context, "graph") == TokenResponse() + mock_user_token_client_class.user_token.get_token.assert_called_once() + + @pytest.mark.asyncio + async def test_get_token_error( + self, turn_context, storage, connection_manager, auth_handlers + ): + auth = Authorization(storage, connection_manager, auth_handlers) + with pytest.raises(ValueError): + await auth.get_token(turn_context, "missing-handler") + + @pytest.mark.asyncio + async def test_exchange_token_no_token(self, mocker, turn_context, auth): + self.create_mock_oauth_flow_class(mocker, TokenResponse()) + res = await auth.exchange_token(turn_context, ["scope"], "github") + assert res == TokenResponse() + + @pytest.mark.asyncio + async def test_exchange_token_not_exchangeable(self, mocker, turn_context, auth): + token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") + self.create_mock_oauth_flow_class( + mocker, TokenResponse(connection_name="github", token=token) + ) + res = await auth.exchange_token(turn_context, ["scope"], "github") + assert res == TokenResponse() + + @pytest.mark.asyncio + async def test_exchange_token_valid_exchangeable(self, turn_context, mocker, auth): + token = jwt.encode({"aud": "api://botframework.test.api"}, "") + self.create_mock_oauth_flow_class( + mocker, TokenResponse(connection_name="github", token=token) + ) + mock_user_token_client_class = self.create_mock_user_token_client( + mocker, token=token + ) + mock_user_token_client_class.user_token.exchange_token = mocker.AsyncMock( + return_value=TokenResponse( + scopes=["scope"], token=token, connection_name="github" + ) + ) + res = await auth.exchange_token(turn_context, ["scope"], "github") + assert res == TokenResponse(token="github-obo-connection-obo-token") + + @pytest.mark.asyncio + async def test_get_active_flow_state(self, mocker, auth): + context = self.create_context(mocker, "webchat", "Alice") + actual_flow_state = await auth.get_active_flow_state(context) + assert ( + actual_flow_state + == STORAGE_SAMPLE_DICT[flow_key("webchat", "Alice", "github")] + ) + + @pytest.mark.asyncio + async def test_get_active_flow_state_missing(self, mocker, auth): + context = self.create_context(mocker, "__channel_id", "__user_id") + res = await auth.get_active_flow_state(context) + assert res is None + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_success(self, mocker, auth): + # robrandao: TODO -> lower priority -> more testing here + # setup + mocker.patch.object( + OAuthFlow, + "begin_or_continue_flow", + return_value=FlowResponse( + token_response=TokenResponse(token="token"), + flow_state=FlowState( + tag=FlowStateTag.COMPLETE, auth_handler_id="github" + ), + ), + ) + context = self.create_context(mocker, "webchat", "Alice") + + context.dummy_val = None + + def on_sign_in_success(context, turn_state, auth_handler_id): + context.dummy_val = auth_handler_id + + def on_sign_in_failure(context, turn_state, auth_handler_id, err): + context.dummy_val = str(err) + + # test + auth.on_sign_in_success(on_sign_in_success) + auth.on_sign_in_failure(on_sign_in_failure) + flow_response = await auth.begin_or_continue_flow(context, None, "github") + assert context.dummy_val == "github" + assert flow_response.token_response == TokenResponse(token="token") + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_already_completed(self, mocker, auth): + # robrandao: TODO -> lower priority -> more testing here + # setup + context = self.create_context(mocker, "webchat", "Alice") + + context.dummy_val = None + + def on_sign_in_success(context, turn_state, auth_handler_id): + context.dummy_val = auth_handler_id + + def on_sign_in_failure(context, turn_state, auth_handler_id, err): + context.dummy_val = str(err) + + # test + auth.on_sign_in_success(on_sign_in_success) + auth.on_sign_in_failure(on_sign_in_failure) + flow_response = await auth.begin_or_continue_flow(context, None, "graph") + assert context.dummy_val == None + assert flow_response.token_response == TokenResponse(token="test_token") + assert flow_response.continuation_activity is None + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_failure( + self, mocker, mock_oauth_flow_class, auth + ): + # robrandao: TODO -> lower priority -> more testing here + # setup + mocker.patch.object( + OAuthFlow, + "begin_or_continue_flow", + return_value=FlowResponse( + token_response=TokenResponse(token="token"), + flow_state=FlowState( + tag=FlowStateTag.FAILURE, auth_handler_id="github" + ), + flow_state_error=FlowErrorTag.MAGIC_FORMAT, + ), + ) + context = self.create_context(mocker, "webchat", "Alice") + + context.dummy_val = None + + def on_sign_in_success(context, turn_state, auth_handler_id): + context.dummy_val = auth_handler_id + + def on_sign_in_failure(context, turn_state, auth_handler_id, err): + context.dummy_val = str(err) + + # test + auth.on_sign_in_success(on_sign_in_success) + auth.on_sign_in_failure(on_sign_in_failure) + flow_response = await auth.begin_or_continue_flow(context, None, "github") + assert context.dummy_val == "FlowErrorTag.NONE" + assert flow_response.token_response == TokenResponse(token="token") + + @pytest.mark.parametrize("auth_handler_id", ["graph", "github"]) + def test_resolve_handler_specified(self, auth, auth_handlers, auth_handler_id): + assert auth.resolve_handler(auth_handler_id) == auth_handlers[auth_handler_id] + + def test_resolve_handler_error(self, auth): + with pytest.raises(ValueError): + auth.resolve_handler("missing-handler") + + def test_resolve_handler_first(self, auth, auth_handlers): + assert auth.resolve_handler() == next(iter(auth_handlers.values())) + + @pytest.mark.asyncio + async def test_sign_out_individual( + self, + mocker, + mock_user_token_client_class, + mock_oauth_flow_class, + storage, + baseline_storage, + connection_manager, + auth_handlers, + ): + # setup + storage_client = FlowStorageClient("teams", "Alice", storage) + context = self.create_context(mocker, "teams", "Alice") + auth = Authorization(storage, connection_manager, auth_handlers) + + # test + await auth.sign_out(context, "graph") + + # verify + assert ( + await storage.read([storage_client.key("graph")], target_cls=FlowState) + == {} + ) + OAuthFlow.sign_out.assert_called_once() + + @pytest.mark.asyncio + async def test_sign_out_all( + self, + mocker, + mock_user_token_client_class, + mock_oauth_flow_class, + turn_context, + storage, + baseline_storage, + connection_manager, + auth_handlers, + ): + # setup + storage_client = FlowStorageClient("webchat", "Alice", storage) + + auth = Authorization(storage, connection_manager, auth_handlers) + context = self.create_context(mocker, "webchat", "Alice") + await auth.sign_out(context) + + # verify + assert ( + await storage.read([storage_client.key("graph")], target_cls=FlowState) + == {} + ) + assert ( + await storage.read([storage_client.key("github")], target_cls=FlowState) + == {} + ) + assert ( + await storage.read([storage_client.key("slack")], target_cls=FlowState) + == {} + ) + OAuthFlow.sign_out.assert_called() # ignore red squiggly -> mocked diff --git a/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py b/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py new file mode 100644 index 00000000..84c669e6 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/tests/test_flow_state.py @@ -0,0 +1,237 @@ +from datetime import datetime + +import pytest + +from microsoft.agents.hosting.core.oauth.flow_state import FlowState, FlowStateTag + + +class TestFlowState: + + @pytest.mark.parametrize( + "original_flow_state, refresh_to_not_started", + [ + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=0, + expiration=datetime.now().timestamp(), + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.BEGIN, + attempts_remaining=1, + expiration=datetime.now().timestamp(), + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.COMPLETE, + attempts_remaining=0, + expiration=datetime.now().timestamp() - 100, + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + expiration=datetime.now().timestamp() + 1000, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.FAILURE, + attempts_remaining=-1, + expiration=datetime.now().timestamp(), + ), + False, + ), + ], + ) + def test_refresh(self, original_flow_state, refresh_to_not_started): + new_flow_state = original_flow_state.model_copy() + new_flow_state.refresh() + expected_flow_state = original_flow_state.model_copy() + if refresh_to_not_started: + expected_flow_state.tag = FlowStateTag.NOT_STARTED + assert new_flow_state == expected_flow_state + + @pytest.mark.parametrize( + "flow_state, expected", + [ + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=0, + expiration=datetime.now().timestamp(), + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.BEGIN, + attempts_remaining=1, + expiration=datetime.now().timestamp(), + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.COMPLETE, + attempts_remaining=0, + expiration=datetime.now().timestamp() - 100, + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + expiration=datetime.now().timestamp() + 1000, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.FAILURE, + attempts_remaining=-1, + expiration=datetime.now().timestamp() + 1000, + ), + False, + ), + ], + ) + def test_is_expired(self, flow_state, expected): + assert flow_state.is_expired() == expected + + @pytest.mark.parametrize( + "flow_state, expected", + [ + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=0, + expiration=datetime.now().timestamp(), + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.BEGIN, + attempts_remaining=1, + expiration=datetime.now().timestamp(), + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.COMPLETE, + attempts_remaining=0, + expiration=datetime.now().timestamp() - 100, + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + expiration=datetime.now().timestamp() - 100, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.FAILURE, + attempts_remaining=-1, + expiration=datetime.now().timestamp(), + ), + True, + ), + ], + ) + def test_reached_max_attempts(self, flow_state, expected): + assert flow_state.reached_max_attempts() == expected + + @pytest.mark.parametrize( + "flow_state, expected", + [ + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=0, + expiration=datetime.now().timestamp(), + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.BEGIN, + attempts_remaining=1, + expiration=datetime.now().timestamp(), + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.COMPLETE, + attempts_remaining=0, + expiration=datetime.now().timestamp() - 100, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.FAILURE, + attempts_remaining=1, + expiration=datetime.now().timestamp() - 100, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=2, + expiration=datetime.now().timestamp() + 1000, + ), + True, + ), + ( + FlowState( + tag=FlowStateTag.BEGIN, + attempts_remaining=0, + expiration=datetime.now().timestamp() + 1000, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.COMPLETE, + attempts_remaining=-1, + expiration=datetime.now().timestamp() + 1000, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.FAILURE, + attempts_remaining=1, + expiration=datetime.now().timestamp() + 1000, + ), + False, + ), + ( + FlowState( + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + expiration=datetime.now().timestamp() + 1000, + ), + True, + ), + ], + ) + def test_is_active(self, flow_state, expected): + assert flow_state.is_active() == expected diff --git a/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py b/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py new file mode 100644 index 00000000..caa65b29 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/tests/test_flow_storage_client.py @@ -0,0 +1,170 @@ +import pytest + +from microsoft.agents.hosting.core.storage import MemoryStorage +from microsoft.agents.hosting.core.storage.storage_test_utils import MockStoreItem +from microsoft.agents.hosting.core.oauth import FlowState, FlowStorageClient + + +class TestFlowStorageClient: + + @pytest.fixture + def channel_id(self): + return "__channel_id" + + @pytest.fixture + def user_id(self): + return "__user_id" + + @pytest.fixture + def storage(self): + return MemoryStorage() + + @pytest.fixture + def client(self, channel_id, user_id, storage): + return FlowStorageClient(channel_id, user_id, storage) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "channel_id, user_id", + [ + ("channel_id", "user_id"), + ("teams_id", "Bob"), + ("channel", "Alice"), + ], + ) + async def test_init_base_key(self, mocker, channel_id, user_id): + client = FlowStorageClient(channel_id, user_id, mocker.Mock()) + assert client.base_key == f"auth/{channel_id}/{user_id}/" + + @pytest.mark.asyncio + async def test_init_fails_without_user_id(self, channel_id, storage): + with pytest.raises(ValueError): + FlowStorageClient(channel_id, "", storage) + + @pytest.mark.asyncio + async def test_init_fails_without_channel_id(self, user_id, storage): + with pytest.raises(ValueError): + FlowStorageClient("", user_id, storage) + + @pytest.mark.parametrize( + "auth_handler_id, expected", + [ + ("handler", "auth/__channel_id/__user_id/handler"), + ("auth_handler", "auth/__channel_id/__user_id/auth_handler"), + ], + ) + def test_key(self, client, auth_handler_id, expected): + assert client.key(auth_handler_id) == expected + + @pytest.mark.asyncio + @pytest.mark.parametrize("auth_handler_id", ["handler", "auth_handler"]) + async def test_read(self, mocker, user_id, channel_id, auth_handler_id): + storage = mocker.AsyncMock() + key = f"auth/{channel_id}/{user_id}/{auth_handler_id}" + storage.read.return_value = {key: FlowState()} + client = FlowStorageClient(channel_id, user_id, storage) + res = await client.read(auth_handler_id) + assert res is storage.read.return_value[key] + storage.read.assert_called_once_with( + [client.key(auth_handler_id)], target_cls=FlowState + ) + + @pytest.mark.asyncio + async def test_read_missing(self, mocker): + storage = mocker.AsyncMock() + storage.read.return_value = {} + client = FlowStorageClient("__channel_id", "__user_id", storage) + res = await client.read("non_existent_handler") + assert res is None + storage.read.assert_called_once_with( + [client.key("non_existent_handler")], target_cls=FlowState + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize("auth_handler_id", ["handler", "auth_handler"]) + async def test_write(self, mocker, channel_id, user_id, auth_handler_id): + storage = mocker.AsyncMock() + storage.write.return_value = None + client = FlowStorageClient(channel_id, user_id, storage) + flow_state = mocker.Mock(spec=FlowState) + flow_state.auth_handler_id = auth_handler_id + await client.write(flow_state) + storage.write.assert_called_once_with({client.key(auth_handler_id): flow_state}) + + @pytest.mark.asyncio + @pytest.mark.parametrize("auth_handler_id", ["handler", "auth_handler"]) + async def test_delete(self, mocker, channel_id, user_id, auth_handler_id): + storage = mocker.AsyncMock() + storage.delete.return_value = None + client = FlowStorageClient(channel_id, user_id, storage) + await client.delete(auth_handler_id) + storage.delete.assert_called_once_with([client.key(auth_handler_id)]) + + @pytest.mark.asyncio + async def test_integration_with_memory_storage(self, channel_id, user_id): + + flow_state_alpha = FlowState(auth_handler_id="handler", flow_started=True) + flow_state_beta = FlowState( + auth_handler_id="auth_handler", flow_started=True, user_token="token" + ) + + storage = MemoryStorage( + { + "some_data": MockStoreItem({"value": "test"}), + f"auth/{channel_id}/{user_id}/handler": flow_state_alpha, + f"auth/{channel_id}/{user_id}/auth_handler": flow_state_beta, + } + ) + baseline = MemoryStorage( + { + "some_data": MockStoreItem({"value": "test"}), + f"auth/{channel_id}/{user_id}/handler": flow_state_alpha, + f"fauth/{channel_id}/{user_id}/auth_handler": flow_state_beta, + } + ) + + # helpers + async def read_check(*args, **kwargs): + res_storage = await storage.read(*args, **kwargs) + res_baseline = await baseline.read(*args, **kwargs) + assert res_storage == res_baseline + + async def write_both(*args, **kwargs): + await storage.write(*args, **kwargs) + await baseline.write(*args, **kwargs) + + async def delete_both(*args, **kwargs): + await storage.delete(*args, **kwargs) + await baseline.delete(*args, **kwargs) + + client = FlowStorageClient(channel_id, user_id, storage) + + new_flow_state_alpha = FlowState(auth_handler_id="handler") + flow_state_chi = FlowState(auth_handler_id="chi") + + await client.write(new_flow_state_alpha) + await client.write(flow_state_chi) + await baseline.write( + {f"auth/{channel_id}/{user_id}/handler": new_flow_state_alpha.model_copy()} + ) + await baseline.write( + {f"auth/{channel_id}/{user_id}/chi": flow_state_chi.model_copy()} + ) + + await write_both( + {f"auth/{channel_id}/{user_id}/handler": new_flow_state_alpha.model_copy()} + ) + await write_both( + {f"auth/{channel_id}/{user_id}/auth_handler": flow_state_beta.model_copy()} + ) + await write_both({"other_data": MockStoreItem({"value": "more"})}) + + await delete_both(["some_data"]) + + await read_check([f"auth/{channel_id}/{user_id}/handler"], target_cls=FlowState) + await read_check( + [f"auth/{channel_id}/{user_id}/auth_handler"], target_cls=FlowState + ) + await read_check([f"auth/{channel_id}/{user_id}/chi"], target_cls=FlowState) + await read_check(["other_data"], target_cls=MockStoreItem) + await read_check(["some_data"], target_cls=MockStoreItem) diff --git a/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py new file mode 100644 index 00000000..c99d8886 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/tests/test_oauth_flow.py @@ -0,0 +1,602 @@ +import pytest + +from microsoft.agents.activity import ( + Activity, + ActivityTypes, + TokenResponse, + SignInResource, + TokenExchangeState, + ConversationReference, + ChannelAccount, +) +from microsoft.agents.hosting.core.oauth import ( + OAuthFlow, + FlowErrorTag, + FlowStateTag, + FlowResponse, +) +from microsoft.agents.hosting.core.connector.user_token_base import UserTokenBase +from microsoft.agents.hosting.core.connector.user_token_client_base import ( + UserTokenClientBase, +) + +# test constants +from .tools.testing_oauth import * + + +class TestOAuthFlowUtils: + + def create_user_token_client(self, mocker, get_token_return=None): + + user_token_client = mocker.Mock(spec=UserTokenClientBase) + user_token_client.user_token = mocker.Mock(spec=UserTokenBase) + user_token_client.user_token.get_token = mocker.AsyncMock() + user_token_client.user_token.sign_out = mocker.AsyncMock() + + return_value = TokenResponse() + if get_token_return: + return_value = TokenResponse(token=get_token_return) + user_token_client.user_token.get_token.return_value = return_value + + return user_token_client + + @pytest.fixture + def user_token_client(self, mocker): + return self.create_user_token_client(mocker, get_token_return=RES_TOKEN) + + def create_activity( + self, + mocker, + activity_type=ActivityTypes.message, + name="a", + value=None, + text="a", + ): + # def conv_ref(): + # return mocker.MagicMock(spec=ConversationReference) + mock_conversation_ref = mocker.MagicMock(ConversationReference) + mocker.patch.object( + Activity, + "get_conversation_reference", + return_value=mocker.MagicMock(ConversationReference), + ) + # mocker.patch.object(ConversationReference, "create", return_value=conv_ref()) + return Activity( + type=activity_type, + name=name, + from_property=ChannelAccount(id=USER_ID), + channel_id=CHANNEL_ID, + # get_conversation_reference=mocker.Mock(return_value=conv_ref), + relates_to=mocker.MagicMock(ConversationReference), + value=value, + text=text, + ) + + @pytest.fixture(params=FLOW_STATES.ALL()) + def sample_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture(params=FLOW_STATES.FAILED()) + def sample_failed_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture(params=FLOW_STATES.INACTIVE()) + def sample_inactive_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture( + params=[ + flow_state + for flow_state in FLOW_STATES.INACTIVE() + if flow_state.tag != FlowStateTag.COMPLETE + ] + ) + def sample_inactive_flow_state_not_completed(self, request): + return request.param.model_copy() + + @pytest.fixture(params=FLOW_STATES.ACTIVE()) + def sample_active_flow_state(self, request): + return request.param.model_copy() + + @pytest.fixture + def flow(self, sample_flow_state, user_token_client): + return OAuthFlow(sample_flow_state, user_token_client) + + +class TestOAuthFlow(TestOAuthFlowUtils): + + def test_init_no_user_token_client(self, sample_flow_state): + with pytest.raises(ValueError): + OAuthFlow(sample_flow_state, None) + + @pytest.mark.parametrize( + "missing_value", ["connection", "ms_app_id", "channel_id", "user_id"] + ) + def test_init_errors(self, missing_value, user_token_client): + flow_state = FLOW_STATES.STARTED_FLOW.model_copy() + flow_state.__setattr__(missing_value, None) + with pytest.raises(ValueError): + OAuthFlow(flow_state, user_token_client) + flow_state.__setattr__(missing_value, "") + with pytest.raises(ValueError): + OAuthFlow(flow_state, user_token_client) + + def test_init_with_state(self, sample_flow_state, user_token_client): + flow = OAuthFlow(sample_flow_state, user_token_client) + assert flow.flow_state == sample_flow_state + + def test_flow_state_prop_copy(self, flow): + flow_state = flow.flow_state + flow_state.user_id = flow_state.user_id + "_copy" + assert flow.flow_state.user_id == USER_ID + assert flow_state.user_id == f"{USER_ID}_copy" + + @pytest.mark.asyncio + async def test_get_user_token_success(self, sample_flow_state, user_token_client): + # setup + flow = OAuthFlow(sample_flow_state, user_token_client) + expected_final_flow_state = sample_flow_state + expected_final_flow_state.user_token = RES_TOKEN + + # test + token_response = await flow.get_user_token() + token = token_response.token + + # verify + assert token == RES_TOKEN + expected_final_flow_state.expiration = flow.flow_state.expiration + assert flow.flow_state == expected_final_flow_state + user_token_client.user_token.get_token.assert_called_once_with( + user_id=USER_ID, + connection_name=ABS_OAUTH_CONNECTION_NAME, + channel_id=CHANNEL_ID, + code=None, + ) + + @pytest.mark.asyncio + async def test_get_user_token_failure(self, mocker, sample_flow_state): + # setup + user_token_client = self.create_user_token_client(mocker, get_token_return=None) + flow = OAuthFlow(sample_flow_state, user_token_client) + expected_final_flow_state = flow.flow_state + + # test + token_response = await flow.get_user_token() + + # verify + assert token_response == TokenResponse() + assert flow.flow_state == expected_final_flow_state + user_token_client.user_token.get_token.assert_called_once_with( + user_id=USER_ID, + connection_name=ABS_OAUTH_CONNECTION_NAME, + channel_id=CHANNEL_ID, + code=None, + ) + + @pytest.mark.asyncio + async def test_sign_out(self, sample_flow_state, user_token_client): + # setup + flow = OAuthFlow(sample_flow_state, user_token_client) + expected_flow_state = sample_flow_state + expected_flow_state.user_token = "" + expected_flow_state.tag = FlowStateTag.NOT_STARTED + + # test + await flow.sign_out() + + # verify + user_token_client.user_token.sign_out.assert_called_once_with( + user_id=USER_ID, + connection_name=ABS_OAUTH_CONNECTION_NAME, + channel_id=CHANNEL_ID, + ) + assert flow.flow_state == expected_flow_state + + @pytest.mark.asyncio + async def test_begin_flow_easy_case( + self, mocker, sample_flow_state, user_token_client + ): + # setup + flow = OAuthFlow(sample_flow_state, user_token_client) + activity = mocker.Mock(spec=Activity) + expected_flow_state = sample_flow_state + expected_flow_state.user_token = RES_TOKEN + + # test + response = await flow.begin_flow(activity) + + # verify + flow_state = flow.flow_state + expected_flow_state.expiration = flow_state.expiration + assert flow_state == expected_flow_state + + assert response.flow_state == flow_state + assert response.sign_in_resource is None # No sign-in resource in this case + assert response.flow_error_tag == FlowErrorTag.NONE + assert response.token_response.token == RES_TOKEN + user_token_client.user_token.get_token.assert_called_once_with( + user_id=USER_ID, + connection_name=ABS_OAUTH_CONNECTION_NAME, + channel_id=CHANNEL_ID, + code=None, + ) + + @pytest.mark.asyncio + async def test_begin_flow_long_case( + self, mocker, sample_flow_state, user_token_client + ): + # mock + # tes = mocker.Mock(TokenExchangeState) + # tes.get_encoded_state = mocker.Mock(return_value="encoded_state") + mocker.patch.object( + TokenExchangeState, "get_encoded_state", return_value="encoded_state" + ) + dummy_sign_in_resource = SignInResource( + sign_in_link="https://example.com/signin", + token_exchange_state=mocker.Mock( + TokenExchangeState, + get_encoded_state=mocker.Mock(return_value="encoded_state"), + ), + ) + user_token_client.user_token.get_token = mocker.AsyncMock( + return_value=TokenResponse() + ) + user_token_client.agent_sign_in.get_sign_in_resource = mocker.AsyncMock( + return_value=dummy_sign_in_resource + ) + activity = self.create_activity(mocker) + + # setup + flow = OAuthFlow(sample_flow_state, user_token_client) + expected_flow_state = sample_flow_state + expected_flow_state.user_token = "" + expected_flow_state.tag = FlowStateTag.BEGIN + expected_flow_state.attempts_remaining = 3 + expected_flow_state.continuation_activity = activity + + # test + response = await flow.begin_flow(activity) + + # verify flow_state + flow_state = flow.flow_state + expected_flow_state.expiration = ( + flow_state.expiration + ) # robrandao: TODO -> ignore this for now + assert flow_state == response.flow_state + assert flow_state == expected_flow_state + + # verify FlowResponse + assert response.sign_in_resource == dummy_sign_in_resource + assert response.flow_error_tag == FlowErrorTag.NONE + assert not response.token_response + # robrandao: TODO more assertions on sign_in_resource + + @pytest.mark.asyncio + async def test_continue_flow_not_active( + self, mocker, sample_inactive_flow_state, user_token_client + ): + # setup + activity = mocker.Mock() + flow = OAuthFlow(sample_inactive_flow_state, user_token_client) + expected_flow_state = sample_inactive_flow_state + expected_flow_state.tag = FlowStateTag.FAILURE + + # test + flow_response = await flow.continue_flow(activity) + flow_state = flow.flow_state + + # verify + assert flow_state == expected_flow_state + assert flow_response.flow_state == flow_state + assert not flow_response.token_response + + async def helper_continue_flow_failure( + self, active_flow_state, user_token_client, activity, flow_error_tag + ): + # setup + flow = OAuthFlow(active_flow_state, user_token_client) + expected_flow_state = active_flow_state + expected_flow_state.tag = ( + FlowStateTag.CONTINUE + if active_flow_state.attempts_remaining > 1 + else FlowStateTag.FAILURE + ) + expected_flow_state.attempts_remaining = ( + active_flow_state.attempts_remaining - 1 + ) + + # test + flow_response = await flow.continue_flow(activity) + flow_state = flow.flow_state + + # verify + assert flow_response.flow_state == flow_state + assert expected_flow_state == flow_state + assert flow_response.token_response == TokenResponse() + assert flow_response.flow_error_tag == flow_error_tag + + async def helper_continue_flow_success( + self, active_flow_state, user_token_client, activity + ): + # setup + flow = OAuthFlow(active_flow_state, user_token_client) + expected_flow_state = active_flow_state + expected_flow_state.tag = FlowStateTag.COMPLETE + expected_flow_state.user_token = RES_TOKEN + expected_flow_state.attempts_remaining = active_flow_state.attempts_remaining + + # test + flow_response = await flow.continue_flow(activity) + flow_state = flow.flow_state + expected_flow_state.expiration = ( + flow_state.expiration + ) # robrandao: TODO -> ignore this for now + + # verify + assert flow_response.flow_state == flow_state + assert expected_flow_state == flow_state + assert flow_response.token_response == TokenResponse(token=RES_TOKEN) + assert flow_response.flow_error_tag == FlowErrorTag.NONE + + @pytest.mark.asyncio + @pytest.mark.parametrize("magic_code", ["magic", "123", "", "1239453"]) + async def test_continue_flow_active_message_magic_format_error( + self, mocker, sample_active_flow_state, user_token_client, magic_code + ): + # setup + activity = self.create_activity(mocker, ActivityTypes.message, text=magic_code) + await self.helper_continue_flow_failure( + sample_active_flow_state, + user_token_client, + activity, + FlowErrorTag.MAGIC_FORMAT, + ) + user_token_client.assert_not_called() + + @pytest.mark.asyncio + async def test_continue_flow_active_message_magic_code_error( + self, mocker, sample_active_flow_state, user_token_client + ): + # setup + user_token_client.user_token.get_token = mocker.AsyncMock( + return_value=TokenResponse() + ) + activity = self.create_activity(mocker, ActivityTypes.message, text="123456") + await self.helper_continue_flow_failure( + sample_active_flow_state, + user_token_client, + activity, + FlowErrorTag.MAGIC_CODE_INCORRECT, + ) + user_token_client.user_token.get_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.connection, + channel_id=sample_active_flow_state.channel_id, + code="123456", + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_message_success( + self, mocker, sample_active_flow_state, user_token_client + ): + # setup + activity = self.create_activity(mocker, ActivityTypes.message, text="123456") + await self.helper_continue_flow_success( + sample_active_flow_state, user_token_client, activity + ) + user_token_client.user_token.get_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.connection, + channel_id=sample_active_flow_state.channel_id, + code="123456", + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_sign_in_verify_state_error( + self, mocker, sample_active_flow_state, user_token_client + ): + # setup + user_token_client.user_token.get_token = mocker.AsyncMock( + return_value=TokenResponse() + ) + activity = self.create_activity( + mocker, + ActivityTypes.invoke, + name="signin/verifyState", + value={"state": "magic_code"}, + ) + await self.helper_continue_flow_failure( + sample_active_flow_state, user_token_client, activity, FlowErrorTag.OTHER + ) + user_token_client.user_token.get_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.connection, + channel_id=sample_active_flow_state.channel_id, + code="magic_code", + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_sign_in_verify_success( + self, mocker, sample_active_flow_state, user_token_client + ): + activity = self.create_activity( + mocker, + ActivityTypes.invoke, + name="signin/verifyState", + value={"state": "magic_code"}, + ) + await self.helper_continue_flow_success( + sample_active_flow_state, user_token_client, activity + ) + user_token_client.user_token.get_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.connection, + channel_id=sample_active_flow_state.channel_id, + code="magic_code", + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_sign_in_token_exchange_error( + self, mocker, sample_active_flow_state, user_token_client + ): + token_exchange_request = {} + user_token_client.user_token.exchange_token = mocker.AsyncMock( + return_value=TokenResponse() + ) + activity = self.create_activity( + mocker, + ActivityTypes.invoke, + name="signin/tokenExchange", + value=token_exchange_request, + ) + await self.helper_continue_flow_failure( + sample_active_flow_state, user_token_client, activity, FlowErrorTag.OTHER + ) + user_token_client.user_token.exchange_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.connection, + channel_id=sample_active_flow_state.channel_id, + body=token_exchange_request, + ) + + @pytest.mark.asyncio + async def test_continue_flow_active_sign_in_token_exchange_success( + self, mocker, sample_active_flow_state, user_token_client + ): + token_exchange_request = {} + user_token_client.user_token.exchange_token = mocker.AsyncMock( + return_value=TokenResponse(token=RES_TOKEN) + ) + activity = self.create_activity( + mocker, + ActivityTypes.invoke, + name="signin/tokenExchange", + value=token_exchange_request, + ) + await self.helper_continue_flow_success( + sample_active_flow_state, user_token_client, activity + ) + user_token_client.user_token.exchange_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.connection, + channel_id=sample_active_flow_state.channel_id, + body=token_exchange_request, + ) + + @pytest.mark.asyncio + async def test_continue_flow_invalid_invoke_name( + self, mocker, sample_active_flow_state, user_token_client + ): + with pytest.raises(ValueError): + activity = self.create_activity( + mocker, ActivityTypes.invoke, name="other", value={} + ) + flow = OAuthFlow(sample_active_flow_state, user_token_client) + await flow.continue_flow(activity) + + @pytest.mark.asyncio + async def test_continue_flow_invalid_activity_type( + self, mocker, sample_active_flow_state, user_token_client + ): + with pytest.raises(ValueError): + activity = self.create_activity( + mocker, ActivityTypes.command, name="other", value={} + ) + flow = OAuthFlow(sample_active_flow_state, user_token_client) + await flow.continue_flow(activity) + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_not_started_flow(self, mocker): + # setup + sample_flow_state = FLOW_STATES.NOT_STARTED_FLOW.model_copy() + expected_response = FlowResponse( + flow_state=sample_flow_state, + token_response=TokenResponse(token=sample_flow_state.user_token), + ) + mocker.patch.object(OAuthFlow, "begin_flow", return_value=expected_response) + + activity_mock = mocker.Mock() + flow = OAuthFlow(sample_flow_state, mocker.Mock()) + + # test + actual_response = await flow.begin_or_continue_flow(activity_mock) + + # verify + assert actual_response is expected_response + OAuthFlow.begin_flow.assert_called_once_with(activity_mock) + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_inactive_flow( + self, + mocker, + sample_inactive_flow_state_not_completed, + ): + # setup + expected_response = FlowResponse( + flow_state=sample_inactive_flow_state_not_completed, + token_response=TokenResponse(), + ) + mocker.patch.object(OAuthFlow, "begin_flow", return_value=expected_response) + + flow = OAuthFlow(sample_inactive_flow_state_not_completed, mocker.Mock()) + + # test + activity_mock = mocker.Mock() + actual_response = await flow.begin_or_continue_flow(activity_mock) + + # verify + assert actual_response is expected_response + OAuthFlow.begin_flow.assert_called_once_with(activity_mock) + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_active_flow( + self, + mocker, + sample_active_flow_state, + ): + # setup + expected_response = FlowResponse( + flow_state=sample_active_flow_state, + token_response=TokenResponse(token=sample_active_flow_state.user_token), + ) + mocker.patch.object(OAuthFlow, "continue_flow", return_value=expected_response) + + flow = OAuthFlow(sample_active_flow_state, mocker.Mock()) + + # test + activity_mock = mocker.Mock() + actual_response = await flow.begin_or_continue_flow(activity_mock) + + # verify + assert actual_response is expected_response + OAuthFlow.continue_flow.assert_called_once_with(activity_mock) + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_stale_flow_state(self, mocker): + flow_state = FLOW_STATES.ACTIVE_EXP_FLOW.model_copy() + expected_response = FlowResponse() + + mocker.patch.object(OAuthFlow, "begin_flow", return_value=expected_response) + + flow = OAuthFlow(flow_state, mocker.Mock()) + actual_response = await flow.begin_or_continue_flow(None) + + assert actual_response is expected_response + OAuthFlow.begin_flow.assert_called_once_with(None) + + @pytest.mark.asyncio + async def test_begin_or_continue_flow_completed_flow_state(self, mocker): + flow_state = FLOW_STATES.COMPLETED_FLOW.model_copy() + expected_response = FlowResponse( + flow_state=flow_state, + token_response=TokenResponse(token=flow_state.user_token), + ) + mocker.patch.object(OAuthFlow, "begin_flow", return_value=None) + mocker.patch.object(OAuthFlow, "continue_flow", return_value=None) + + flow = OAuthFlow(flow_state, mocker.Mock()) + actual_response = await flow.begin_or_continue_flow(None) + + assert actual_response == expected_response + OAuthFlow.begin_flow.assert_not_called() + OAuthFlow.continue_flow.assert_not_called() diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py b/libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py new file mode 100644 index 00000000..c7ef14eb --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/tests/tools/mock_user_token_client.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import asyncio +import uuid +from datetime import datetime, timezone +from typing import Callable, List, Optional, Awaitable +from collections import deque + +from microsoft.agents.hosting.core.authorization import ClaimsIdentity +from microsoft.agents.activity import ( + Activity, + ActivityTypes, + ChannelAccount, + ConversationAccount, + ConversationReference, + Channels, + ResourceResponse, + RoleTypes, + InvokeResponse, +) +from microsoft.agents.hosting.core.channel_adapter import ChannelAdapter +from microsoft.agents.hosting.core.turn_context import TurnContext +from microsoft.agents.hosting.core.connector import UserTokenClient + +AgentCallbackHandler = Callable[[TurnContext], Awaitable] + + +# patch userTokenclient constructor +class MockUserTokenClient(UserTokenClient): + """A mock user token client for testing.""" + + def __init__(self): + self._store = {} + self._exchange_store = {} + self._throw_on_exchange = {} + + def add_user_token( + self, + connection_name: str, + channel_id: str, + user_id: str, + token: str, + magic_code: str = None, + ): + """Add a token for a user that can be retrieved during testing.""" + key = self._get_key(connection_name, channel_id, user_id) + self._store[key] = (token, magic_code) + + def add_exchangeable_token( + self, + connection_name: str, + channel_id: str, + user_id: str, + exchangeable_item: str, + token: str, + ): + """Add an exchangeable token for a user that can be exchanged during testing.""" + key = self._get_exchange_key( + connection_name, channel_id, user_id, exchangeable_item + ) + self._exchange_store[key] = token + + def throw_on_exchange_request( + self, + connection_name: str, + channel_id: str, + user_id: str, + exchangeable_item: str, + ): + """Add an instruction to throw an exception during exchange requests.""" + key = self._get_exchange_key( + connection_name, channel_id, user_id, exchangeable_item + ) + self._throw_on_exchange[key] = True + + def _get_key(self, connection_name: str, channel_id: str, user_id: str) -> str: + return f"{connection_name}:{channel_id}:{user_id}" + + def _get_exchange_key( + self, + connection_name: str, + channel_id: str, + user_id: str, + exchangeable_item: str, + ) -> str: + return f"{connection_name}:{channel_id}:{user_id}:{exchangeable_item}" diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/testing_adapter.py b/libraries/microsoft-agents-hosting-core/tests/tools/testing_adapter.py index b3574b8f..e79486b5 100644 --- a/libraries/microsoft-agents-hosting-core/tests/tools/testing_adapter.py +++ b/libraries/microsoft-agents-hosting-core/tests/tools/testing_adapter.py @@ -25,65 +25,7 @@ AgentCallbackHandler = Callable[[TurnContext], Awaitable] - -class MockUserTokenClient(UserTokenClient): - """A mock user token client for testing.""" - - def __init__(self): - self._store = {} - self._exchange_store = {} - self._throw_on_exchange = {} - - def add_user_token( - self, - connection_name: str, - channel_id: str, - user_id: str, - token: str, - magic_code: str = None, - ): - """Add a token for a user that can be retrieved during testing.""" - key = self._get_key(connection_name, channel_id, user_id) - self._store[key] = (token, magic_code) - - def add_exchangeable_token( - self, - connection_name: str, - channel_id: str, - user_id: str, - exchangeable_item: str, - token: str, - ): - """Add an exchangeable token for a user that can be exchanged during testing.""" - key = self._get_exchange_key( - connection_name, channel_id, user_id, exchangeable_item - ) - self._exchange_store[key] = token - - def throw_on_exchange_request( - self, - connection_name: str, - channel_id: str, - user_id: str, - exchangeable_item: str, - ): - """Add an instruction to throw an exception during exchange requests.""" - key = self._get_exchange_key( - connection_name, channel_id, user_id, exchangeable_item - ) - self._throw_on_exchange[key] = True - - def _get_key(self, connection_name: str, channel_id: str, user_id: str) -> str: - return f"{connection_name}:{channel_id}:{user_id}" - - def _get_exchange_key( - self, - connection_name: str, - channel_id: str, - user_id: str, - exchangeable_item: str, - ) -> str: - return f"{connection_name}:{channel_id}:{user_id}:{exchangeable_item}" +from .mock_user_token_client import MockUserTokenClient class TestingAdapter(ChannelAdapter): diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py b/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py index ac90b6c6..9340f995 100644 --- a/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py +++ b/libraries/microsoft-agents-hosting-core/tests/tools/testing_authorization.py @@ -1,247 +1,248 @@ -""" -Testing utilities for authorization functionality - -This module provides mock implementations and helper classes for testing authorization, -authentication, and token management scenarios. It includes test doubles for token -providers, connection managers, and authorization handlers that can be configured -to simulate various authentication states and flow conditions. -""" - -from microsoft.agents.hosting.core import ( - Connections, - AccessTokenProviderBase, - AuthHandler, - Authorization, - MemoryStorage, - oauth_flow, -) -from typing import Dict, Union -from microsoft.agents.hosting.core.authorization.agent_auth_configuration import ( - AgentAuthConfiguration, -) -from microsoft.agents.hosting.core.authorization.claims_identity import ClaimsIdentity - -from microsoft.agents.activity import TokenResponse - -from unittest.mock import Mock, AsyncMock - - -def create_test_auth_handler( - name: str, obo: bool = False, title: str = None, text: str = None -): - """ - Creates a test AuthHandler instance with standardized connection names. - - This helper function simplifies the creation of AuthHandler objects for testing - by automatically generating connection names based on the provided name and - optionally including On-Behalf-Of (OBO) connection configuration. - - Args: - name: Base name for the auth handler, used to generate connection names - obo: Whether to include On-Behalf-Of connection configuration - title: Optional title for the auth handler - text: Optional descriptive text for the auth handler - - Returns: - AuthHandler: Configured auth handler instance with test-friendly connection names - """ - return AuthHandler( - name, - abs_oauth_connection_name=f"{name}-abs-connection", - obo_connection_name=f"{name}-obo-connection" if obo else None, - title=title, - text=text, - ) - - -class TestingTokenProvider(AccessTokenProviderBase): - """ - Access token provider for unit tests. - - This test double simulates an access token provider that returns predictable - token values based on the provider name. It implements both standard token - acquisition and On-Behalf-Of (OBO) token flows for comprehensive testing - of authentication scenarios. - """ - - def __init__(self, name: str): - """ - Initialize the testing token provider. - - Args: - name: Identifier used to generate predictable token values - """ - self.name = name - - async def get_access_token( - self, resource_url: str, scopes: list[str], force_refresh: bool = False - ) -> str: - """ - Get an access token for the specified resource and scopes. - - Returns a predictable token string based on the provider name for testing. - - Args: (unused in test implementation) - resource_url: URL of the resource requiring authentication - scopes: List of OAuth scopes requested - force_refresh: Whether to force token refresh - - Returns: - str: Test token in format "{name}-token" - """ - return f"{self.name}-token" - - async def aquire_token_on_behalf_of( - self, scopes: list[str], user_assertion: str - ) -> str: - """ - Acquire a token on behalf of another user (OBO flow). - - Returns a predictable OBO token string for testing scenarios involving - delegated permissions and token exchange. - - Args: (unused in test implementation) - scopes: List of OAuth scopes requested for the OBO token - user_assertion: JWT token representing the user's identity - - Returns: - str: Test OBO token in format "{name}-obo-token" - """ - return f"{self.name}-obo-token" - - -class TestingConnectionManager(Connections): - """ - Connection manager for unit tests. - - This test double provides a simplified connection management interface that - returns TestingTokenProvider instances for all connection requests. It enables - testing of authorization flows without requiring actual OAuth configurations - or external authentication services. - """ - - def get_connection(self, connection_name: str) -> AccessTokenProviderBase: - """ - Get a token provider for the specified connection name. - - Args: - connection_name: Name of the OAuth connection - - Returns: - AccessTokenProviderBase: TestingTokenProvider configured with the connection name - """ - return TestingTokenProvider(connection_name) - - def get_default_connection(self) -> AccessTokenProviderBase: - """ - Get the default token provider. - - Returns: - AccessTokenProviderBase: TestingTokenProvider configured with "default" name - """ - return TestingTokenProvider("default") - - def get_token_provider( - self, claims_identity: ClaimsIdentity, service_url: str - ) -> AccessTokenProviderBase: - """ - Get a token provider based on claims identity and service URL. - - In this test implementation, returns the default connection regardless - of the provided parameters. - - Args: (unused in test implementation) - claims_identity: User's claims and identity information - service_url: URL of the service requiring authentication - - Returns: - AccessTokenProviderBase: The default TestingTokenProvider - """ - return self.get_default_connection() - - def get_default_connection_configuration(self) -> AgentAuthConfiguration: - """ - Get the default authentication configuration. - - Returns: - AgentAuthConfiguration: Empty configuration suitable for testing - """ - return AgentAuthConfiguration() - - -class TestingAuthorization(Authorization): - """ - Authorization system for comprehensive unit testing. - - This test double extends the Authorization class to provide a fully mocked - authorization environment suitable for testing various authentication scenarios. - It automatically configures auth handlers with mock OAuth flows that can simulate - different states like successful authentication, failed sign-in, or in-progress flows. - """ - - def __init__( - self, - auth_handlers: Dict[str, AuthHandler], - token: Union[str, None] = "default", - flow_started=False, - sign_in_failed=False, - ): - """ - Initialize the testing authorization system. - - Sets up a complete test authorization environment with memory storage, - test connection manager, and configures all provided auth handlers with - mock OAuth flows. - - Args: - auth_handlers: Dictionary mapping handler names to AuthHandler instances - token: Token value to use in mock responses. "default" uses auto-generated - tokens, None simulates no token available, or provide custom jwt token string - flow_started: Simulate OAuth flows that have already started - sign_in_failed: Simulate failed sign-in attempts - """ - # Initialize with test-friendly components - storage = MemoryStorage() - connection_manager = TestingConnectionManager() - super().__init__( - storage=storage, - auth_handlers=auth_handlers, - connection_manager=connection_manager, - ) - - # Configure each auth handler with mock OAuth flow behavior - for auth_handler in self._auth_handlers.values(): - # Create default token response for this auth handler - default_token = TokenResponse( - connection_name=auth_handler.abs_oauth_connection_name, - token=f"{auth_handler.abs_oauth_connection_name}-token", - ) - - # Determine token response based on configuration - if token == "default": - token_response = default_token - elif token: - token_response = TokenResponse( - connection_name=auth_handler.abs_oauth_connection_name, - token=token, - ) - else: - token_response = None - - # Mock the OAuth flow with configurable behavior - auth_handler.flow = Mock( - get_user_token=AsyncMock(return_value=token_response), - _get_flow_state=AsyncMock( - # sign-in failed requires flow to be started - return_value=oauth_flow.FlowState( - flow_started=(flow_started or sign_in_failed) - ) - ), - begin_flow=AsyncMock(return_value=default_token), - # Mock flow continuation with optional failure simulation - continue_flow=AsyncMock( - return_value=None if sign_in_failed else default_token - ), - ) - - auth_handler.flow.flow_state = None +""" +Testing utilities for authorization functionality + +This module provides mock implementations and helper classes for testing authorization, +authentication, and token management scenarios. It includes test doubles for token +providers, connection managers, and authorization handlers that can be configured +to simulate various authentication states and flow conditions. +""" + +from microsoft.agents.hosting.core import ( + Connections, + AccessTokenProviderBase, + AuthHandler, + Authorization, + MemoryStorage, + OAuthFlow, +) +from typing import Dict, Union +from microsoft.agents.hosting.core.authorization.agent_auth_configuration import ( + AgentAuthConfiguration, +) +from microsoft.agents.hosting.core.authorization.claims_identity import ClaimsIdentity + +from microsoft.agents.activity import TokenResponse + +from unittest.mock import Mock, AsyncMock + + +def create_test_auth_handler( + name: str, obo: bool = False, title: str = None, text: str = None +): + """ + Creates a test AuthHandler instance with standardized connection names. + + This helper function simplifies the creation of AuthHandler objects for testing + by automatically generating connection names based on the provided name and + optionally including On-Behalf-Of (OBO) connection configuration. + + Args: + name: Base name for the auth handler, used to generate connection names + obo: Whether to include On-Behalf-Of connection configuration + title: Optional title for the auth handler + text: Optional descriptive text for the auth handler + + Returns: + AuthHandler: Configured auth handler instance with test-friendly connection names + """ + return AuthHandler( + name, + abs_oauth_connection_name=f"{name}-abs-connection", + obo_connection_name=f"{name}-obo-connection" if obo else None, + title=title, + text=text, + ) + + +class TestingTokenProvider(AccessTokenProviderBase): + """ + Access token provider for unit tests. + + This test double simulates an access token provider that returns predictable + token values based on the provider name. It implements both standard token + acquisition and On-Behalf-Of (OBO) token flows for comprehensive testing + of authentication scenarios. + """ + + def __init__(self, name: str): + """ + Initialize the testing token provider. + + Args: + name: Identifier used to generate predictable token values + """ + self.name = name + + async def get_access_token( + self, resource_url: str, scopes: list[str], force_refresh: bool = False + ) -> str: + """ + Get an access token for the specified resource and scopes. + + Returns a predictable token string based on the provider name for testing. + + Args: (unused in test implementation) + resource_url: URL of the resource requiring authentication + scopes: List of OAuth scopes requested + force_refresh: Whether to force token refresh + + Returns: + str: Test token in format "{name}-token" + """ + return f"{self.name}-token" + + async def aquire_token_on_behalf_of( + self, scopes: list[str], user_assertion: str + ) -> str: + """ + Acquire a token on behalf of another user (OBO flow). + + Returns a predictable OBO token string for testing scenarios involving + delegated permissions and token exchange. + + Args: (unused in test implementation) + scopes: List of OAuth scopes requested for the OBO token + user_assertion: JWT token representing the user's identity + + Returns: + str: Test OBO token in format "{name}-obo-token" + """ + return f"{self.name}-obo-token" + + +class TestingConnectionManager(Connections): + """ + Connection manager for unit tests. + + This test double provides a simplified connection management interface that + returns TestingTokenProvider instances for all connection requests. It enables + testing of authorization flows without requiring actual OAuth configurations + or external authentication services. + """ + + def get_connection(self, connection_name: str) -> AccessTokenProviderBase: + """ + Get a token provider for the specified connection name. + + Args: + connection_name: Name of the OAuth connection + + Returns: + AccessTokenProviderBase: TestingTokenProvider configured with the connection name + """ + return TestingTokenProvider(connection_name) + + def get_default_connection(self) -> AccessTokenProviderBase: + """ + Get the default token provider. + + Returns: + AccessTokenProviderBase: TestingTokenProvider configured with "default" name + """ + return TestingTokenProvider("default") + + def get_token_provider( + self, claims_identity: ClaimsIdentity, service_url: str + ) -> AccessTokenProviderBase: + """ + Get a token provider based on claims identity and service URL. + + In this test implementation, returns the default connection regardless + of the provided parameters. + + Args: (unused in test implementation) + claims_identity: User's claims and identity information + service_url: URL of the service requiring authentication + + Returns: + AccessTokenProviderBase: The default TestingTokenProvider + """ + return self.get_default_connection() + + def get_default_connection_configuration(self) -> AgentAuthConfiguration: + """ + Get the default authentication configuration. + + Returns: + AgentAuthConfiguration: Empty configuration suitable for testing + """ + return AgentAuthConfiguration() + + +class TestingAuthorization(Authorization): + """ + Authorization system for comprehensive unit testing. + + This test double extends the Authorization class to provide a fully mocked + authorization environment suitable for testing various authentication scenarios. + It automatically configures auth handlers with mock OAuth flows that can simulate + different states like successful authentication, failed sign-in, or in-progress flows. + """ + + def __init__( + self, + auth_handlers: Dict[str, AuthHandler], + token: Union[str, None] = "default", + flow_started=False, + sign_in_failed=False, + ): + """ + Initialize the testing authorization system. + + Sets up a complete test authorization environment with memory storage, + test connection manager, and configures all provided auth handlers with + mock OAuth flows. + + Args: + auth_handlers: Dictionary mapping handler names to AuthHandler instances + token: Token value to use in mock responses. "default" uses auto-generated + tokens, None simulates no token available, or provide custom jwt token string + flow_started: Simulate OAuth flows that have already started + sign_in_failed: Simulate failed sign-in attempts + """ + # Initialize with test-friendly components + storage = MemoryStorage() + connection_manager = TestingConnectionManager() + super().__init__( + storage=storage, + auth_handlers=auth_handlers, + connection_manager=connection_manager, + service_url="a", + ) + + # Configure each auth handler with mock OAuth flow behavior + for auth_handler in self._auth_handlers.values(): + # Create default token response for this auth handler + default_token = TokenResponse( + connection_name=auth_handler.abs_oauth_connection_name, + token=f"{auth_handler.abs_oauth_connection_name}-token", + ) + + # Determine token response based on configuration + if token == "default": + token_response = default_token + elif token: + token_response = TokenResponse( + connection_name=auth_handler.abs_oauth_connection_name, + token=token, + ) + else: + token_response = None + + # Mock the OAuth flow with configurable behavior + auth_handler.flow = Mock( + get_user_token=AsyncMock(return_value=token_response), + _get_flow_state=AsyncMock( + # sign-in failed requires flow to be started + return_value=oauth_flow.FlowState( + flow_started=(flow_started or sign_in_failed) + ) + ), + begin_flow=AsyncMock(return_value=default_token), + # Mock flow continuation with optional failure simulation + continue_flow=AsyncMock( + return_value=None if sign_in_failed else default_token + ), + ) + + auth_handler.flow.flow_state = None diff --git a/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py b/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py new file mode 100644 index 00000000..1dd8d799 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/tests/tools/testing_oauth.py @@ -0,0 +1,180 @@ +from datetime import datetime + +from microsoft.agents.hosting.core.storage.storage_test_utils import MockStoreItem +from microsoft.agents.hosting.core.oauth.flow_state import FlowState, FlowStateTag + +MS_APP_ID = "__ms_app_id" +CHANNEL_ID = "__channel_id" +USER_ID = "__user_id" +ABS_OAUTH_CONNECTION_NAME = "__connection_name" +RES_TOKEN = "__res_token" + +DEF_ARGS = { + "ms_app_id": MS_APP_ID, + "channel_id": CHANNEL_ID, + "user_id": USER_ID, + "connection": ABS_OAUTH_CONNECTION_NAME, +} + + +class FLOW_STATES: + + NOT_STARTED_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.NOT_STARTED, + attempts_remaining=1, + user_token="____", + expiration=datetime.now().timestamp() + 1000000, + ) + + STARTED_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.BEGIN, + attempts_remaining=1, + user_token="____", + expiration=datetime.now().timestamp() + 1000000, + ) + STARTED_FLOW_ONE_RETRY = FlowState( + **DEF_ARGS, + tag=FlowStateTag.BEGIN, + attempts_remaining=2, + user_token="____", + expiration=datetime.now().timestamp() + 1000000, + ) + ACTIVE_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.CONTINUE, + attempts_remaining=2, + user_token="__token", + expiration=datetime.now().timestamp() + 1000000, + ) + ACTIVE_FLOW_ONE_RETRY = FlowState( + **DEF_ARGS, + tag=FlowStateTag.CONTINUE, + attempts_remaining=1, + user_token="__token", + expiration=datetime.now().timestamp() + 1000000, + ) + ACTIVE_EXP_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.CONTINUE, + attempts_remaining=2, + user_token="__token", + expiration=datetime.now().timestamp(), + ) + COMPLETED_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.COMPLETE, + attempts_remaining=2, + user_token="test_token", + expiration=datetime.now().timestamp() + 1000000, + ) + FAIL_BY_ATTEMPTS_FLOW = FlowState( + **DEF_ARGS, + tag=FlowStateTag.FAILURE, + attempts_remaining=0, + expiration=datetime.now().timestamp() + 1000000, + ) + + FAIL_BY_EXP_FLOW = FlowState( + **DEF_ARGS, tag=FlowStateTag.FAILURE, attempts_remaining=2, expiration=0 + ) + + @staticmethod + def clone_state_list(lst): + return [flow_state.model_copy() for flow_state in lst] + + @staticmethod + def ALL(): + return FLOW_STATES.clone_state_list( + [ + FLOW_STATES.STARTED_FLOW, + FLOW_STATES.STARTED_FLOW_ONE_RETRY, + FLOW_STATES.ACTIVE_FLOW, + FLOW_STATES.ACTIVE_FLOW_ONE_RETRY, + FLOW_STATES.ACTIVE_EXP_FLOW, + FLOW_STATES.COMPLETED_FLOW, + FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, + FLOW_STATES.FAIL_BY_EXP_FLOW, + ] + ) + + @staticmethod + def FAILED(): + return FLOW_STATES.clone_state_list( + [ + FLOW_STATES.ACTIVE_EXP_FLOW, + FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, + FLOW_STATES.FAIL_BY_EXP_FLOW, + ] + ) + + @staticmethod + def ACTIVE(): + return FLOW_STATES.clone_state_list( + [ + FLOW_STATES.STARTED_FLOW, + FLOW_STATES.STARTED_FLOW_ONE_RETRY, + FLOW_STATES.ACTIVE_FLOW, + FLOW_STATES.ACTIVE_FLOW_ONE_RETRY, + ] + ) + + @staticmethod + def INACTIVE(): + return FLOW_STATES.clone_state_list( + [ + FLOW_STATES.ACTIVE_EXP_FLOW, + FLOW_STATES.COMPLETED_FLOW, + FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW, + FLOW_STATES.FAIL_BY_EXP_FLOW, + ] + ) + + +def flow_key(channel_id, user_id, handler_id): + return f"auth/{channel_id}/{user_id}/{handler_id}" + + +def update_flow_state_handler(flow_state, handler): + flow_state = flow_state.model_copy() + flow_state.auth_handler_id = handler + return flow_state + + +STORAGE_SAMPLE_DICT = { + "user_id": MockStoreItem({"id": "123"}), + "session_id": MockStoreItem({"id": "abc"}), + flow_key("webchat", "Alice", "graph"): update_flow_state_handler( + FLOW_STATES.COMPLETED_FLOW.model_copy(), "graph" + ), + flow_key("webchat", "Alice", "github"): update_flow_state_handler( + FLOW_STATES.ACTIVE_FLOW_ONE_RETRY.model_copy(), "github" + ), + flow_key("teams", "Alice", "graph"): update_flow_state_handler( + FLOW_STATES.STARTED_FLOW.model_copy(), "graph" + ), + flow_key("webchat", "Bob", "graph"): update_flow_state_handler( + FLOW_STATES.ACTIVE_EXP_FLOW.model_copy(), "graph" + ), + flow_key("teams", "Bob", "slack"): update_flow_state_handler( + FLOW_STATES.STARTED_FLOW.model_copy(), "slack" + ), + flow_key("webchat", "Chuck", "github"): update_flow_state_handler( + FLOW_STATES.FAIL_BY_ATTEMPTS_FLOW.model_copy(), "github" + ), +} + + +def STORAGE_INIT_DATA(): + data = STORAGE_SAMPLE_DICT.copy() + for key, value in data.items(): + data[key] = value.model_copy() if isinstance(value, FlowState) else value + return data + + +def update_data_with_flow_state(data, channel_id, user_id, auth_handler_id, flow_state): + data = data.copy() + key = f"auth/{channel_id}/{user_id}/{auth_handler_id}" + data[key] = flow_state.model_copy() + return data