diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/_load_configuration.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/_load_configuration.py index f3c6afa3..2a598f76 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/_load_configuration.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/_load_configuration.py @@ -1,7 +1,7 @@ -from typing import Any, Dict +from typing import Any -def load_configuration_from_env(env_vars: Dict[str, Any]) -> dict: +def load_configuration_from_env(env_vars: dict[str, Any]) -> dict: """ Parses environment variables and returns a dictionary with the relevant configuration. """ @@ -18,6 +18,11 @@ def load_configuration_from_env(env_vars: Dict[str, Any]) -> dict: current_level = current_level[next_level] last_level[levels[-1]] = value + if result.get("CONNECTIONSMAP") and isinstance(result["CONNECTIONSMAP"], dict): + result["CONNECTIONSMAP"] = [ + conn for conn in result.get("CONNECTIONSMAP", {}).values() + ] + return { "AGENTAPPLICATION": result.get("AGENTAPPLICATION", {}), "CONNECTIONS": result.get("CONNECTIONS", {}), diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/activity.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/activity.py index 89730568..78e8ac7d 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/activity.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/activity.py @@ -20,6 +20,7 @@ from .text_highlight import TextHighlight from .semantic_action import SemanticAction from .agents_model import AgentsModel +from .role_types import RoleTypes from ._model_utils import pick_model, SkipNone from ._type_aliases import NonEmptyString @@ -648,3 +649,21 @@ def add_ai_metadata( self.entities = [] self.entities.append(ai_entity) + + def is_agentic_request(self) -> bool: + return self.recipient and self.recipient.role in [ + RoleTypes.agentic_identity, + RoleTypes.agentic_user, + ] + + def get_agentic_instance_id(self) -> Optional[str]: + """Gets the agent instance ID from the context if it's an agentic request.""" + if not self.is_agentic_request() or not self.recipient: + return None + return self.recipient.agentic_app_id + + def get_agentic_user(self) -> Optional[str]: + """Gets the agentic user (UPN) from the context if it's an agentic request.""" + if not self.is_agentic_request() or not self.recipient: + return None + return self.recipient.id diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/channel_account.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/channel_account.py index 13b973d9..bf1db20c 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/channel_account.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/channel_account.py @@ -26,6 +26,9 @@ class ChannelAccount(AgentsModel): name: str = None aad_object_id: NonEmptyString = None role: NonEmptyString = None + agentic_user_id: NonEmptyString = None + agentic_app_id: NonEmptyString = None + tenant_id: NonEmptyString = None @property def properties(self) -> dict[str, Any]: diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/channels.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/channels.py index dbb47e62..d6172a43 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/channels.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/channels.py @@ -10,6 +10,9 @@ class Channels(str, Enum): Ids of channels supported by ABS. """ + """Agents channel.""" + agents = "agents" + console = "console" """Console channel.""" diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/role_types.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/role_types.py index 8064c371..1008cb8a 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/role_types.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/role_types.py @@ -5,3 +5,5 @@ class RoleTypes(str, Enum): user = "user" agent = "bot" skill = "skill" + agentic_identity = "agenticAppInstance" + agentic_user = "agenticUser" diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/token_response.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/token_response.py index 00d6aa91..e4cbd232 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/token_response.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/token_response.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import jwt + from .agents_model import AgentsModel from ._type_aliases import NonEmptyString @@ -26,3 +28,19 @@ class TokenResponse(AgentsModel): def __bool__(self): return bool(self.token) + + def is_exchangeable(self) -> bool: + """ + Checks if a token is exchangeable (has api:// audience). + + :param token: The token to check. + :type token: str + :return: True if the token is exchangeable, False otherwise. + """ + try: + # Decode without verification to check the audience + payload = jwt.decode(self.token, options={"verify_signature": False}) + aud = payload.get("aud") + return isinstance(aud, str) and aud.startswith("api://") + except Exception: + return False diff --git a/libraries/microsoft-agents-authentication-msal/microsoft_agents/authentication/msal/msal_auth.py b/libraries/microsoft-agents-authentication-msal/microsoft_agents/authentication/msal/msal_auth.py index eca444dd..abeb718c 100644 --- a/libraries/microsoft-agents-authentication-msal/microsoft_agents/authentication/msal/msal_auth.py +++ b/libraries/microsoft-agents-authentication-msal/microsoft_agents/authentication/msal/msal_auth.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import jwt from typing import Optional from urllib.parse import urlparse, ParseResult as URI from msal import ( @@ -23,6 +24,18 @@ logger = logging.getLogger(__name__) +# this is deferred because jwt.decode is expensive and we don't want to do it unless we +# have logging.DEBUG enabled +class _DeferredLogOfBlueprintId: + def __init__(self, jwt_token: str): + self.jwt_token = jwt_token + + def __str__(self): + payload = jwt.decode(self.jwt_token, options={"verify_signature": False}) + agentic_blueprint_id = payload.get("xms_par_app_azp") + return f"Agentic blueprint id: {agentic_blueprint_id}" + + class MsalAuth(AccessTokenProviderBase): _client_credential_cache = None @@ -56,11 +69,16 @@ async def get_access_token( auth_result_payload = msal_auth_client.acquire_token_for_client( scopes=local_scopes ) + else: + auth_result_payload = None - # TODO: Handling token error / acquisition failed - return auth_result_payload["access_token"] + res = auth_result_payload.get("access_token") if auth_result_payload else None + if not res: + logger.error("Failed to acquire token for resource %s", auth_result_payload) + raise ValueError(f"Failed to acquire token. {str(auth_result_payload)}") + return res - async def aquire_token_on_behalf_of( + async def acquire_token_on_behalf_of( self, scopes: list[str], user_assertion: str ) -> str: """ @@ -186,3 +204,189 @@ def _resolve_scopes_list(self, instance_url: URI, scopes=None) -> list[str]: temp_list.append(scope_placeholder) logger.debug(f"Resolved scopes: {temp_list}") return temp_list + + # the call to MSAL is blocking, but in the future we want to create an asyncio task + # to avoid this + async def get_agentic_application_token( + self, agent_app_instance_id: str + ) -> Optional[str]: + """Gets the agentic application token for the given agent application instance ID. + + :param agent_app_instance_id: The agent application instance ID. + :type agent_app_instance_id: str + :return: The agentic application token, or None if not found. + :rtype: Optional[str] + """ + + if not agent_app_instance_id: + raise ValueError("Agent application instance Id must be provided.") + + logger.info( + "Attempting to get agentic application token from agent_app_instance_id %s", + agent_app_instance_id, + ) + msal_auth_client = self._create_client_application() + + if isinstance(msal_auth_client, ConfidentialClientApplication): + + # https://github.dev/AzureAD/microsoft-authentication-library-for-dotnet + auth_result_payload = msal_auth_client.acquire_token_for_client( + ["api://AzureAdTokenExchange/.default"], + data={"fmi_path": agent_app_instance_id}, + ) + + if auth_result_payload: + return auth_result_payload.get("access_token") + + return None + + async def get_agentic_instance_token( + self, agent_app_instance_id: str + ) -> tuple[str, str]: + """Gets the agentic instance token for the given agent application instance ID. + + :param agent_app_instance_id: The agent application instance ID. + :type agent_app_instance_id: str + :return: A tuple containing the agentic instance token and the agent application token. + :rtype: tuple[str, str] + """ + + if not agent_app_instance_id: + raise ValueError("Agent application instance Id must be provided.") + + logger.info( + "Attempting to get agentic instance token from agent_app_instance_id %s", + agent_app_instance_id, + ) + agent_token_result = await self.get_agentic_application_token( + agent_app_instance_id + ) + + if not agent_token_result: + logger.error( + "Failed to acquire agentic instance token or agent token for agent_app_instance_id %s", + agent_app_instance_id, + ) + raise Exception( + f"Failed to acquire agentic instance token or agent token for agent_app_instance_id {agent_app_instance_id}" + ) + + authority = ( + f"https://login.microsoftonline.com/{self._msal_configuration.TENANT_ID}" + ) + + instance_app = ConfidentialClientApplication( + client_id=agent_app_instance_id, + authority=authority, + client_credential={"client_assertion": agent_token_result}, + ) + + agentic_instance_token = instance_app.acquire_token_for_client( + ["api://AzureAdTokenExchange/.default"] + ) + + if not agentic_instance_token: + logger.error( + "Failed to acquire agentic instance token or agent token for agent_app_instance_id %s", + agent_app_instance_id, + ) + raise Exception( + f"Failed to acquire agentic instance token or agent token for agent_app_instance_id {agent_app_instance_id}" + ) + + # future scenario where we don't know the blueprint id upfront + + token = agentic_instance_token.get("access_token") + if not token: + logger.error( + "Failed to acquire agentic instance token, %s", agentic_instance_token + ) + raise ValueError(f"Failed to acquire token. {str(agentic_instance_token)}") + + logger.debug(_DeferredLogOfBlueprintId(token)) + + return agentic_instance_token["access_token"], agent_token_result + + async def get_agentic_user_token( + self, agent_app_instance_id: str, upn: str, scopes: list[str] + ) -> Optional[str]: + """Gets the agentic user token for the given agent application instance ID and user principal name and the scopes. + + :param agent_app_instance_id: The agent application instance ID. + :type agent_app_instance_id: str + :param upn: The user principal name. + :type upn: str + :param scopes: The scopes to request for the token. + :type scopes: list[str] + :return: The agentic user token, or None if not found. + :rtype: Optional[str] + """ + if not agent_app_instance_id or not upn: + raise ValueError( + "Agent application instance Id and user principal name must be provided." + ) + + logger.info( + "Attempting to get agentic user token from agent_app_instance_id %s and upn %s", + agent_app_instance_id, + upn, + ) + instance_token, agent_token = await self.get_agentic_instance_token( + agent_app_instance_id + ) + + if not instance_token or not agent_token: + logger.error( + "Failed to acquire instance token or agent token for agent_app_instance_id %s and upn %s", + agent_app_instance_id, + upn, + ) + raise Exception( + f"Failed to acquire instance token or agent token for agent_app_instance_id {agent_app_instance_id} and upn {upn}" + ) + + authority = ( + f"https://login.microsoftonline.com/{self._msal_configuration.TENANT_ID}" + ) + + instance_app = ConfidentialClientApplication( + client_id=agent_app_instance_id, + authority=authority, + client_credential={"client_assertion": agent_token}, + ) + + logger.info( + "Acquiring agentic user token for agent_app_instance_id %s and upn %s", + agent_app_instance_id, + upn, + ) + auth_result_payload = instance_app.acquire_token_for_client( + scopes, + data={ + "username": upn, + "user_federated_identity_credential": instance_token, + "grant_type": "user_fic", + }, + ) + + if not auth_result_payload: + logger.error( + "Failed to acquire agentic user token for agent_app_instance_id %s and upn %s, %s", + agent_app_instance_id, + upn, + auth_result_payload, + ) + return None + + access_token = auth_result_payload.get("access_token") + if not access_token: + logger.error( + "Failed to acquire agentic user token for agent_app_instance_id %s and upn %s, %s", + agent_app_instance_id, + upn, + auth_result_payload, + ) + return None + + logger.info("Acquired agentic user token response.") + return access_token diff --git a/libraries/microsoft-agents-authentication-msal/microsoft_agents/authentication/msal/msal_connection_manager.py b/libraries/microsoft-agents-authentication-msal/microsoft_agents/authentication/msal/msal_connection_manager.py index 597f0b1c..aea4163a 100644 --- a/libraries/microsoft-agents-authentication-msal/microsoft_agents/authentication/msal/msal_connection_manager.py +++ b/libraries/microsoft-agents-authentication-msal/microsoft_agents/authentication/msal/msal_connection_manager.py @@ -1,3 +1,4 @@ +import re from typing import Dict, List, Optional from microsoft_agents.hosting.core import ( AgentAuthConfiguration, @@ -10,15 +11,28 @@ class MsalConnectionManager(Connections): + _connections: Dict[str, MsalAuth] + _connections_map: List[Dict[str, str]] + _service_connection_configuration: AgentAuthConfiguration def __init__( self, - connections_configurations: Dict[str, AgentAuthConfiguration] = None, - connections_map: List[Dict[str, str]] = None, - **kwargs + connections_configurations: Optional[Dict[str, AgentAuthConfiguration]] = None, + connections_map: Optional[List[Dict[str, str]]] = None, + **kwargs, ): + """ + Initialize the MSAL connection manager. + + :arg connections_configurations: A dictionary of connection configurations. + :type connections_configurations: Dict[str, AgentAuthConfiguration] + :arg connections_map: A list of connection mappings. + :type connections_map: List[Dict[str, str]] + :raises ValueError: If no service connection configuration is provided. + """ + self._connections: Dict[str, MsalAuth] = {} - self._connections_map = connections_map or kwargs.get("CONNECTIONS_MAP", {}) + self._connections_map = connections_map or kwargs.get("CONNECTIONSMAP", {}) self._service_connection_configuration: AgentAuthConfiguration = None if connections_configurations: @@ -45,13 +59,20 @@ def __init__( def get_connection(self, connection_name: Optional[str]) -> AccessTokenProviderBase: """ Get the OAuth connection for the agent. + + :arg connection_name: The name of the connection. + :type connection_name: str + :return: The OAuth connection for the agent. + :rtype: AccessTokenProviderBase """ + # should never be None return self._connections.get(connection_name, None) def get_default_connection(self) -> AccessTokenProviderBase: """ Get the default OAuth connection for the agent. """ + # should never be None return self._connections.get("SERVICE_CONNECTION", None) def get_token_provider( @@ -59,11 +80,49 @@ def get_token_provider( ) -> AccessTokenProviderBase: """ Get the OAuth token provider for the agent. + + :arg claims_identity: The claims identity of the bot. + :type claims_identity: ClaimsIdentity + :arg service_url: The service URL of the bot. + :type service_url: str + :return: The OAuth token provider for the agent. + :rtype: AccessTokenProviderBase + :raises ValueError: If no connection is found for the given audience and service URL. """ + if not claims_identity or not service_url: + raise ValueError( + "Claims identity and Service URL are required to get the token provider." + ) + if not self._connections_map: return self.get_default_connection() - # TODO: Implement logic to select the appropriate connection based on the connection map + aud = claims_identity.get_app_id() or "" + for item in self._connections_map: + audience_match = True + item_aud = item.get("AUDIENCE", "") + if item_aud: + audience_match = item_aud.lower() == aud.lower() + + if audience_match: + item_service_url = item.get("SERVICEURL", "") + if item_service_url == "*" or item_service_url == "": + connection_name = item.get("CONNECTION") + connection = self.get_connection(connection_name) + if connection: + return connection + + else: + res = re.match(item_service_url, service_url, re.IGNORECASE) + if res: + connection_name = item.get("CONNECTION") + connection = self.get_connection(connection_name) + if connection: + return connection + + raise ValueError( + f"No connection found for audience '{aud}' and serviceUrl '{service_url}'." + ) def get_default_connection_configuration(self) -> AgentAuthConfiguration: """ diff --git a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/cloud_adapter.py b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/cloud_adapter.py index e80014bf..1ef106c3 100644 --- a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/cloud_adapter.py +++ b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/cloud_adapter.py @@ -82,7 +82,11 @@ async def process(self, request: Request, agent: Agent) -> Optional[Response]: raise HTTPUnsupportedMediaType() activity: Activity = Activity.model_validate(body) - claims_identity: ClaimsIdentity = request.get("claims_identity") + + # default to anonymous identity with no claims + claims_identity: ClaimsIdentity = request.get( + "claims_identity", ClaimsIdentity({}, False) + ) # A POST request must contain an Activity if ( diff --git a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/jwt_authorization_middleware.py b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/jwt_authorization_middleware.py index 5accb9f7..d28618cd 100644 --- a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/jwt_authorization_middleware.py +++ b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/jwt_authorization_middleware.py @@ -9,6 +9,7 @@ @middleware async def jwt_authorization_middleware(request: Request, handler): + auth_config: AgentAuthConfiguration = request.app["agent_configuration"] token_validator = JwtTokenValidator(auth_config) auth_header = request.headers.get("Authorization") diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/__init__.py index f5d07cef..50c990c8 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/__init__.py @@ -22,8 +22,8 @@ # App Auth from .app.oauth import ( Authorization, - AuthorizationHandlers, AuthHandler, + AgenticUserAuthorization, ) # App State @@ -42,16 +42,6 @@ from .authorization.jwt_token_validator import JwtTokenValidator from .authorization.auth_types import AuthTypes -# OAuth -from .oauth import ( - FlowState, - FlowStateTag, - FlowErrorTag, - FlowResponse, - FlowStorageClient, - OAuthFlow, -) - # Client API from .client.agent_conversation_reference import AgentConversationReference from .client.channel_factory_protocol import ChannelFactoryProtocol @@ -105,15 +95,11 @@ "Middleware", "RestChannelServiceClientFactory", "TurnContext", - "ActivityType", "AgentApplication", "ApplicationError", "ApplicationOptions", - "ConversationUpdateType", "InputFile", "InputFileDownloader", - "MessageReactionType", - "MessageUpdateType", "Query", "Route", "RouteHandler", @@ -124,9 +110,7 @@ "TurnState", "TempState", "Authorization", - "AuthorizationHandlers", "AuthHandler", - "SignInState", "AccessTokenProviderBase", "AuthenticationConstants", "AnonymousTokenProvider", @@ -134,7 +118,6 @@ "AgentAuthConfiguration", "ClaimsIdentity", "JwtTokenValidator", - "AuthTypes", "AgentConversationReference", "ChannelFactoryProtocol", "ChannelHostProtocol", @@ -162,10 +145,6 @@ "StoreItem", "Storage", "MemoryStorage", - "FlowState", - "FlowStateTag", - "FlowErrorTag", - "FlowResponse", - "FlowStorageClient", - "OAuthFlow", + "AgenticUserAuthorization", + "Authorization", ] diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/__init__.py new file mode 100644 index 00000000..c9b319e6 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/__init__.py @@ -0,0 +1,12 @@ +from ._flow_state import _FlowState, _FlowStateTag, _FlowErrorTag +from ._flow_storage_client import _FlowStorageClient +from ._oauth_flow import _OAuthFlow, _FlowResponse + +__all__ = [ + "_FlowState", + "_FlowStateTag", + "_FlowErrorTag", + "_FlowResponse", + "_FlowStorageClient", + "_OAuthFlow", +] diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/flow_state.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/_flow_state.py similarity index 76% rename from libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/flow_state.py rename to libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/_flow_state.py index efeb7cb2..50572947 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/flow_state.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/_flow_state.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations + from datetime import datetime from enum import Enum from typing import Optional @@ -12,7 +14,7 @@ from ..storage import StoreItem -class FlowStateTag(Enum): +class _FlowStateTag(Enum): """Represents the top-level state of an OAuthFlow For instance, a flow can arrive at an error, but its @@ -27,7 +29,7 @@ class FlowStateTag(Enum): COMPLETE = "complete" -class FlowErrorTag(Enum): +class _FlowErrorTag(Enum): """Represents the various error states that can occur during an OAuthFlow""" NONE = "none" @@ -36,11 +38,9 @@ class FlowErrorTag(Enum): OTHER = "other" -class FlowState(BaseModel, StoreItem): +class _FlowState(BaseModel, StoreItem): """Represents the state of an OAuthFlow""" - user_token: str = "" - channel_id: str = "" user_id: str = "" ms_app_id: str = "" @@ -50,14 +50,14 @@ class FlowState(BaseModel, StoreItem): expiration: float = 0 continuation_activity: Optional[Activity] = None attempts_remaining: int = 0 - tag: FlowStateTag = FlowStateTag.NOT_STARTED + tag: _FlowStateTag = _FlowStateTag.NOT_STARTED def store_item_to_json(self) -> dict: return self.model_dump(mode="json", exclude_unset=True, by_alias=True) @staticmethod - def from_json_to_store_item(json_data: dict) -> "FlowState": - return FlowState.model_validate(json_data) + def from_json_to_store_item(json_data: dict) -> _FlowState: + return _FlowState.model_validate(json_data) def is_expired(self) -> bool: return datetime.now().timestamp() >= self.expiration @@ -69,13 +69,13 @@ def is_active(self) -> bool: return ( not self.is_expired() and not self.reached_max_attempts() - and self.tag in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE] + and self.tag in [_FlowStateTag.BEGIN, _FlowStateTag.CONTINUE] ) def refresh(self): if ( self.tag - in [FlowStateTag.BEGIN, FlowStateTag.CONTINUE, FlowStateTag.COMPLETE] + in [_FlowStateTag.BEGIN, _FlowStateTag.CONTINUE, _FlowStateTag.COMPLETE] and self.is_expired() ): - self.tag = FlowStateTag.NOT_STARTED + self.tag = _FlowStateTag.NOT_STARTED diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/flow_storage_client.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/_flow_storage_client.py similarity index 82% rename from libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/flow_storage_client.py rename to libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/_flow_storage_client.py index 7ab03879..867b3aa6 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/flow_storage_client.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/_flow_storage_client.py @@ -4,15 +4,15 @@ from typing import Optional from ..storage import Storage -from .flow_state import FlowState +from ._flow_state import _FlowState -class DummyCache(Storage): +class _DummyCache(Storage): - async def read(self, keys: list[str], **kwargs) -> dict[str, FlowState]: + async def read(self, keys: list[str], **kwargs) -> dict[str, _FlowState]: return {} - async def write(self, changes: dict[str, FlowState]) -> None: + async def write(self, changes: dict[str, _FlowState]) -> None: pass async def delete(self, keys: list[str]) -> None: @@ -23,7 +23,7 @@ async def delete(self, keys: list[str]) -> None: # - CachedStorage class for two-tier storage # - Namespaced/PrefixedStorage class for namespacing keying # not generally thread or async safe (operations are not atomic) -class FlowStorageClient: +class _FlowStorageClient: """Wrapper around Storage that manages sign-in state specific to each user and channel. Uses the activity's channel_id and from.id to create a key prefix for storage operations. @@ -34,7 +34,7 @@ def __init__( channel_id: str, user_id: str, storage: Storage, - cache_class: type[Storage] = None, + cache_class: Optional[type[Storage]] = None, ): """ Args: @@ -53,7 +53,7 @@ def __init__( self._base_key = f"auth/{channel_id}/{user_id}/" self._storage = storage if cache_class is None: - cache_class = DummyCache + cache_class = _DummyCache self._cache = cache_class() @property @@ -65,21 +65,21 @@ def key(self, auth_handler_id: str) -> str: """Creates a storage key for a specific sign-in handler.""" return f"{self._base_key}{auth_handler_id}" - async def read(self, auth_handler_id: str) -> Optional[FlowState]: + async def read(self, auth_handler_id: str) -> Optional[_FlowState]: """Reads the flow state for a specific authentication handler.""" key: str = self.key(auth_handler_id) - data = await self._cache.read([key], target_cls=FlowState) + data = await self._cache.read([key], target_cls=_FlowState) if key not in data: - data = await self._storage.read([key], target_cls=FlowState) + data = await self._storage.read([key], target_cls=_FlowState) if key not in data: return None await self._cache.write({key: data[key]}) - return FlowState.model_validate(data.get(key)) + return _FlowState.model_validate(data.get(key)) - async def write(self, value: FlowState) -> None: + async def write(self, value: _FlowState) -> None: """Saves the flow state for a specific authentication handler.""" key: str = self.key(value.auth_handler_id) - cached_state = await self._cache.read([key], target_cls=FlowState) + cached_state = await self._cache.read([key], target_cls=_FlowState) if not cached_state or cached_state != value: await self._cache.write({key: value}) await self._storage.write({key: value}) @@ -87,7 +87,7 @@ async def write(self, value: FlowState) -> None: async def delete(self, auth_handler_id: str) -> None: """Deletes the flow state for a specific authentication handler.""" key: str = self.key(auth_handler_id) - cached_state = await self._cache.read([key], target_cls=FlowState) + cached_state = await self._cache.read([key], target_cls=_FlowState) if cached_state: await self._cache.delete([key]) await self._storage.delete([key]) diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/_oauth_flow.py similarity index 84% rename from libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/oauth_flow.py rename to libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/_oauth_flow.py index 3a12b890..a3a9c808 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/_oauth/_oauth_flow.py @@ -18,22 +18,22 @@ ) from ..connector.client import UserTokenClient -from .flow_state import FlowState, FlowStateTag, FlowErrorTag +from ._flow_state import _FlowState, _FlowStateTag, _FlowErrorTag logger = logging.getLogger(__name__) -class FlowResponse(BaseModel): +class _FlowResponse(BaseModel): """Represents the response for a flow operation.""" - flow_state: FlowState = FlowState() - flow_error_tag: FlowErrorTag = FlowErrorTag.NONE + flow_state: _FlowState = _FlowState() + flow_error_tag: _FlowErrorTag = _FlowErrorTag.NONE token_response: Optional[TokenResponse] = None sign_in_resource: Optional[SignInResource] = None continuation_activity: Optional[Activity] = None -class OAuthFlow: +class _OAuthFlow: """ Manages the OAuth flow. @@ -48,7 +48,7 @@ class OAuthFlow: """ def __init__( - self, flow_state: FlowState, user_token_client: UserTokenClient, **kwargs + self, flow_state: _FlowState, user_token_client: UserTokenClient, **kwargs ): """ Arguments: @@ -105,7 +105,7 @@ def __init__( ) @property - def flow_state(self) -> FlowState: + def flow_state(self) -> _FlowState: return self._flow_state.model_copy() async def get_user_token(self, magic_code: str = None) -> TokenResponse: @@ -136,11 +136,10 @@ async def get_user_token(self, magic_code: str = None) -> TokenResponse: ) if token_response: logger.info("User token obtained successfully: %s", token_response) - self._flow_state.user_token = token_response.token self._flow_state.expiration = ( datetime.now().timestamp() + self._default_flow_duration ) - self._flow_state.tag = FlowStateTag.COMPLETE + self._flow_state.tag = _FlowStateTag.COMPLETE return token_response @@ -160,20 +159,19 @@ async def sign_out(self) -> None: connection_name=self._abs_oauth_connection_name, channel_id=self._channel_id, ) - self._flow_state.user_token = "" - self._flow_state.tag = FlowStateTag.NOT_STARTED + self._flow_state.tag = _FlowStateTag.NOT_STARTED def _use_attempt(self) -> None: """Decrements the remaining attempts for the flow, checking for failure.""" self._flow_state.attempts_remaining -= 1 if self._flow_state.attempts_remaining <= 0: - self._flow_state.tag = FlowStateTag.FAILURE + self._flow_state.tag = _FlowStateTag.FAILURE logger.debug( "Using an attempt for the OAuth flow. Attempts remaining after use: %d", self._flow_state.attempts_remaining, ) - async def begin_flow(self, activity: Activity) -> FlowResponse: + async def begin_flow(self, activity: Activity) -> _FlowResponse: """Begins the OAuthFlow. Args: @@ -187,18 +185,17 @@ async def begin_flow(self, activity: Activity) -> FlowResponse: """ token_response = await self.get_user_token() if token_response: - return FlowResponse( + return _FlowResponse( flow_state=self._flow_state, token_response=token_response ) logger.debug("Starting new OAuth flow") - self._flow_state.tag = FlowStateTag.BEGIN + self._flow_state.tag = _FlowStateTag.BEGIN self._flow_state.expiration = ( datetime.now().timestamp() + self._default_flow_duration ) self._flow_state.attempts_remaining = self._max_attempts - self._flow_state.user_token = "" self._flow_state.continuation_activity = activity.model_copy() token_exchange_state = TokenExchangeState( @@ -216,24 +213,24 @@ async def begin_flow(self, activity: Activity) -> FlowResponse: logger.debug("Sign-in resource obtained successfully: %s", sign_in_resource) - return FlowResponse( + return _FlowResponse( flow_state=self._flow_state, sign_in_resource=sign_in_resource ) async def _continue_from_message( self, activity: Activity - ) -> tuple[TokenResponse, FlowErrorTag]: + ) -> tuple[TokenResponse, _FlowErrorTag]: """Handles the continuation of the flow from a message activity.""" magic_code: str = activity.text if magic_code and magic_code.isdigit() and len(magic_code) == 6: token_response: TokenResponse = await self.get_user_token(magic_code) if token_response: - return token_response, FlowErrorTag.NONE + return token_response, _FlowErrorTag.NONE else: - return token_response, FlowErrorTag.MAGIC_CODE_INCORRECT + return token_response, _FlowErrorTag.MAGIC_CODE_INCORRECT else: - return TokenResponse(), FlowErrorTag.MAGIC_FORMAT + return TokenResponse(), _FlowErrorTag.MAGIC_FORMAT async def _continue_from_invoke_verify_state( self, activity: Activity @@ -257,7 +254,7 @@ async def _continue_from_invoke_token_exchange( ) return token_response - async def continue_flow(self, activity: Activity) -> FlowResponse: + async def continue_flow(self, activity: Activity) -> _FlowResponse: """Continues the OAuth flow based on the incoming activity. Args: @@ -271,12 +268,12 @@ async def continue_flow(self, activity: Activity) -> FlowResponse: if not self._flow_state.is_active(): logger.debug("OAuth flow is not active, cannot continue") - self._flow_state.tag = FlowStateTag.FAILURE - return FlowResponse( + self._flow_state.tag = _FlowStateTag.FAILURE + return _FlowResponse( flow_state=self._flow_state.model_copy(), token_response=None ) - flow_error_tag = FlowErrorTag.NONE + flow_error_tag = _FlowErrorTag.NONE if activity.type == ActivityTypes.message: token_response, flow_error_tag = await self._continue_from_message(activity) elif ( @@ -292,32 +289,31 @@ async def continue_flow(self, activity: Activity) -> FlowResponse: else: raise ValueError(f"Unknown activity type {activity.type}") - if not token_response and flow_error_tag == FlowErrorTag.NONE: - flow_error_tag = FlowErrorTag.OTHER + if not token_response and flow_error_tag == _FlowErrorTag.NONE: + flow_error_tag = _FlowErrorTag.OTHER - if flow_error_tag != FlowErrorTag.NONE: + if flow_error_tag != _FlowErrorTag.NONE: logger.debug("Flow error occurred: %s", flow_error_tag) - self._flow_state.tag = FlowStateTag.CONTINUE + self._flow_state.tag = _FlowStateTag.CONTINUE self._use_attempt() else: - self._flow_state.tag = FlowStateTag.COMPLETE + self._flow_state.tag = _FlowStateTag.COMPLETE self._flow_state.expiration = ( datetime.now().timestamp() + self._default_flow_duration ) - self._flow_state.user_token = token_response.token logger.debug( "OAuth flow completed successfully, got TokenResponse: %s", token_response, ) - return FlowResponse( + return _FlowResponse( flow_state=self._flow_state.model_copy(), flow_error_tag=flow_error_tag, token_response=token_response, continuation_activity=self._flow_state.continuation_activity, ) - async def begin_or_continue_flow(self, activity: Activity) -> FlowResponse: + async def begin_or_continue_flow(self, activity: Activity) -> _FlowResponse: """Begins a new OAuth flow or continues an existing one based on the activity. Args: @@ -327,12 +323,6 @@ async def begin_or_continue_flow(self, activity: Activity) -> FlowResponse: A FlowResponse object containing the updated flow state and any token response. """ self._flow_state.refresh() - if self._flow_state.tag == FlowStateTag.COMPLETE: # robrandao: TODO -> test - logger.debug("OAuth flow has already been completed, nothing to do") - return FlowResponse( - flow_state=self._flow_state.model_copy(), - token_response=TokenResponse(token=self._flow_state.user_token), - ) if self._flow_state.is_active(): logger.debug("Active flow, continuing...") diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/__init__.py index 4089c3fb..0cf00fc4 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/__init__.py @@ -17,7 +17,7 @@ from .oauth import ( Authorization, AuthHandler, - AuthorizationHandlers, + AgenticUserAuthorization, ) # App State @@ -27,15 +27,11 @@ from .state.turn_state import TurnState __all__ = [ - "ActivityType", "AgentApplication", "ApplicationError", "ApplicationOptions", - "ConversationUpdateType", "InputFile", "InputFileDownloader", - "MessageReactionType", - "MessageUpdateType", "Query", "Route", "RouteHandler", @@ -49,5 +45,5 @@ "TempState", "Authorization", "AuthHandler", - "AuthorizationHandlers", + "AgenticUserAuthorization", ] diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/agent_application.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/agent_application.py index 01e5bb7b..4bf974a7 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/agent_application.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/agent_application.py @@ -23,36 +23,24 @@ cast, ) -from microsoft_agents.hosting.core.authorization import Connections - -from microsoft_agents.hosting.core import Agent, TurnContext from microsoft_agents.activity import ( Activity, ActivityTypes, - ActionTypes, ConversationUpdateTypes, MessageReactionTypes, MessageUpdateTypes, InvokeResponse, - TokenResponse, - OAuthCard, - Attachment, - CardAction, ) -from .. import CardFactory, MessageFactory +from ..turn_context import TurnContext +from ..agent import Agent +from ..authorization import Connections from .app_error import ApplicationError from .app_options import ApplicationOptions -# from .auth import AuthManager, OAuth, OAuthOptions from .route import Route, RouteHandler from .state import TurnState from ..channel_service_adapter import ChannelServiceAdapter -from ..oauth import ( - FlowResponse, - FlowState, - FlowStateTag, -) from .oauth import Authorization from .typing_indicator import TypingIndicator @@ -191,7 +179,7 @@ def adapter(self) -> ChannelServiceAdapter: return self._adapter @property - def auth(self): + def auth(self) -> Authorization: """ The application's authentication manager """ @@ -603,103 +591,6 @@ def turn_state_factory(self, func: Callable[[TurnContext], Awaitable[StateT]]): self._turn_state_factory = func return func - async def _handle_flow_response( - self, context: TurnContext, flow_response: FlowResponse - ) -> None: - """Handles CONTINUE and FAILURE flow responses, sending activities back.""" - flow_state: FlowState = flow_response.flow_state - - if flow_state.tag == FlowStateTag.BEGIN: - # Create the OAuth card - sign_in_resource = flow_response.sign_in_resource - o_card: Attachment = CardFactory.oauth_card( - OAuthCard( - text="Sign in", - connection_name=flow_state.connection, - buttons=[ - CardAction( - title="Sign in", - type=ActionTypes.signin, - value=sign_in_resource.sign_in_link, - channel_data=None, - ) - ], - token_exchange_resource=sign_in_resource.token_exchange_resource, - token_post_resource=sign_in_resource.token_post_resource, - ) - ) - # Send the card to the user - await context.send_activity(MessageFactory.attachment(o_card)) - elif flow_state.tag == FlowStateTag.FAILURE: - if flow_state.reached_max_attempts(): - await context.send_activity( - MessageFactory.text( - "Sign-in failed. Max retries reached. Please try again later." - ) - ) - elif flow_state.is_expired(): - await context.send_activity( - MessageFactory.text("Sign-in session expired. Please try again.") - ) - else: - logger.warning("Sign-in flow failed for unknown reasons.") - await context.send_activity("Sign-in failed. Please try again.") - - async def _on_turn_auth_intercept( - self, context: TurnContext, turn_state: TurnState - ) -> bool: - """Intercepts the turn to check for active authentication flows.""" - logger.debug( - "Checking for active sign-in flow for context: %s with activity type %s", - context.activity.id, - context.activity.type, - ) - prev_flow_state = await self._auth.get_active_flow_state(context) - if prev_flow_state: - logger.debug( - "Previous flow state: %s", - { - "user_id": prev_flow_state.user_id, - "connection": prev_flow_state.connection, - "channel_id": prev_flow_state.channel_id, - "auth_handler_id": prev_flow_state.auth_handler_id, - "tag": prev_flow_state.tag, - "expiration": prev_flow_state.expiration, - }, - ) - # proceed if there is an existing flow to continue - # new flows should be initiated in _on_activity - # this can be reorganized later... but it works for now - if ( - prev_flow_state - and ( - prev_flow_state.tag == FlowStateTag.NOT_STARTED - or prev_flow_state.is_active() - ) - and context.activity.type in [ActivityTypes.message, ActivityTypes.invoke] - ): - - logger.debug("Sign-in flow is active for context: %s", context.activity.id) - - flow_response: FlowResponse = await self._auth.begin_or_continue_flow( - context, turn_state, prev_flow_state.auth_handler_id - ) - - await self._handle_flow_response(context, flow_response) - - new_flow_state: FlowState = flow_response.flow_state - token_response: TokenResponse = flow_response.token_response - saved_activity: Activity = new_flow_state.continuation_activity.model_copy() - - if token_response: - new_context = copy(context) - new_context.activity = saved_activity - logger.info("Resending continuation activity %s", saved_activity.text) - await self.on_turn(new_context) - await turn_state.save(context) - return True # early return from _on_turn - return False # continue _on_turn - async def on_turn(self, context: TurnContext): logger.debug( f"AgentApplication.on_turn(): Processing turn for context: {context.activity.id}" @@ -715,9 +606,26 @@ async def _on_turn(self, context: TurnContext): logger.debug("Initializing turn state") turn_state = await self._initialize_state(context) - - if self._auth and await self._on_turn_auth_intercept(context, turn_state): - return + if ( + context.activity.type == ActivityTypes.message + or context.activity.type == ActivityTypes.invoke + ): + + ( + auth_intercepts, + continuation_activity, + ) = await self._auth._on_turn_auth_intercept(context, turn_state) + if auth_intercepts: + if continuation_activity: + new_context = copy(context) + new_context.activity = continuation_activity + logger.info( + "Resending continuation activity %s", + continuation_activity.text, + ) + await self.on_turn(new_context) + await turn_state.save(context) + return logger.debug("Running before turn middleware") if not await self._run_before_turn_middleware(context, turn_state): @@ -834,26 +742,14 @@ async def _on_activity(self, context: TurnContext, state: StateT): if not route.auth_handlers: await route.handler(context, state) else: - sign_in_complete = False + sign_in_complete = True for auth_handler_id in route.auth_handlers: - logger.debug( - "Beginning or continuing flow for auth handler %s", - auth_handler_id, - ) - flow_response: FlowResponse = ( - await self._auth.begin_or_continue_flow( + if not ( + await self._auth._start_or_continue_sign_in( context, state, auth_handler_id ) - ) - await self._handle_flow_response(context, flow_response) - logger.debug( - "Flow response flow_state.tag: %s", - flow_response.flow_state.tag, - ) - sign_in_complete = ( - flow_response.flow_state.tag == FlowStateTag.COMPLETE - ) - if not sign_in_complete: + ).sign_in_complete(): + sign_in_complete = False break if sign_in_complete: diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/__init__.py index 7c962a43..7fe3948d 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/__init__.py @@ -1,8 +1,19 @@ from .authorization import Authorization -from .auth_handler import AuthHandler, AuthorizationHandlers +from .auth_handler import AuthHandler +from ._sign_in_state import _SignInState +from ._sign_in_response import _SignInResponse +from ._handlers import ( + _UserAuthorization, + AgenticUserAuthorization, + _AuthorizationHandler, +) __all__ = [ "Authorization", "AuthHandler", - "AuthorizationHandlers", + "_AuthorizationHandler", + "_SignInState", + "_SignInResponse", + "_UserAuthorization", + "AgenticUserAuthorization", ] diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_handlers/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_handlers/__init__.py new file mode 100644 index 00000000..05cf6dba --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_handlers/__init__.py @@ -0,0 +1,9 @@ +from .agentic_user_authorization import AgenticUserAuthorization +from ._user_authorization import _UserAuthorization +from ._authorization_handler import _AuthorizationHandler + +__all__ = [ + "AgenticUserAuthorization", + "_UserAuthorization", + "_AuthorizationHandler", +] diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_handlers/_authorization_handler.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_handlers/_authorization_handler.py new file mode 100644 index 00000000..eba18b5d --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_handlers/_authorization_handler.py @@ -0,0 +1,104 @@ +from abc import ABC +from typing import Optional +import logging + +from microsoft_agents.activity import TokenResponse + +from ....turn_context import TurnContext +from ....storage import Storage +from ....authorization import Connections +from ..auth_handler import AuthHandler +from .._sign_in_response import _SignInResponse + +logger = logging.getLogger(__name__) + + +class _AuthorizationHandler(ABC): + """Base class for different authorization strategies.""" + + _storage: Storage + _connection_manager: Connections + _handler: AuthHandler + + def __init__( + self, + storage: Storage, + connection_manager: Connections, + auth_handler: Optional[AuthHandler] = None, + *, + auth_handler_id: Optional[str] = None, + auth_handler_settings: Optional[dict] = None, + **kwargs, + ) -> None: + """ + Creates a new instance of Authorization. + + :param storage: The storage system to use for state management. + :type storage: Storage + :param connection_manager: The connection manager for OAuth providers. + :type connection_manager: Connections + :param auth_handlers: Configuration for OAuth providers. + :type auth_handlers: dict[str, AuthHandler], optional + :raises ValueError: When storage is None or no auth handlers provided. + """ + if not storage: + raise ValueError("Storage is required for Authorization") + if not auth_handler and not auth_handler_settings: + raise ValueError( + "At least one of auth_handler or auth_handler_settings is required." + ) + + self._storage = storage + self._connection_manager = connection_manager + + if auth_handler: + self._handler = auth_handler + else: + self._handler = AuthHandler._from_settings(auth_handler_settings) + + self._id = auth_handler_id or self._handler.name + if not self._id: + raise ValueError( + "Auth handler must have an ID. Could not be deduced from settings or constructor args." + ) + + async def _sign_in( + self, context: TurnContext, scopes: Optional[list[str]] = None + ) -> _SignInResponse: + """Initiate or continue the sign-in process for the user with the given auth handler. + + :param context: The turn context for the current turn of conversation. + :type context: TurnContext + :param scopes: Optional list of scopes to request during sign-in. If None, default scopes will be used. + :type scopes: Optional[list[str]], optional + :return: A SignInResponse indicating the result of the sign-in attempt. + :rtype: SignInResponse + """ + raise NotImplementedError() + + async def get_refreshed_token( + self, + context: TurnContext, + exchange_connection: Optional[str] = None, + exchange_scopes: Optional[list[str]] = None, + ) -> TokenResponse: + """Attempts to get a refreshed token for the user with the given scopes + + :param context: The turn context for the current turn of conversation. + :type context: TurnContext + :param exchange_connection: Optional name of the connection to use for token exchange. If None, default connection will be used. + :type exchange_connection: Optional[str], optional + :param exchange_scopes: Optional list of scopes to request during token exchange. If None, default scopes will be used. + :type exchange_scopes: Optional[list[str]], optional + """ + raise NotImplementedError() + + async def _sign_out(self, context: TurnContext) -> None: + """Attempts to sign out the user from the specified auth handler or all handlers if none specified. + + :param context: The turn context for the current turn of conversation. + :type context: TurnContext + :param auth_handler_id: The ID of the auth handler to sign out from. If None, sign out from all handlers. + :type auth_handler_id: Optional[str] + """ + raise NotImplementedError() diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_handlers/_user_authorization.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_handlers/_user_authorization.py new file mode 100644 index 00000000..1083d240 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_handlers/_user_authorization.py @@ -0,0 +1,261 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from __future__ import annotations +import logging +import jwt +from typing import Optional + +from microsoft_agents.activity import ( + Attachment, + ActionTypes, + CardAction, + OAuthCard, + TokenResponse, +) + +from microsoft_agents.hosting.core.card_factory import CardFactory +from microsoft_agents.hosting.core.message_factory import MessageFactory +from microsoft_agents.hosting.core.connector.client import UserTokenClient +from microsoft_agents.hosting.core.turn_context import TurnContext +from microsoft_agents.hosting.core._oauth import ( + _OAuthFlow, + _FlowResponse, + _FlowState, + _FlowStorageClient, + _FlowStateTag, +) +from .._sign_in_response import _SignInResponse +from ._authorization_handler import _AuthorizationHandler + +logger = logging.getLogger(__name__) + + +class _UserAuthorization(_AuthorizationHandler): + """ + Class responsible for managing authorization and OAuth flows. + Handles multiple OAuth providers and manages the complete authentication lifecycle. + """ + + async def _load_flow( + self, context: TurnContext + ) -> tuple[_OAuthFlow, _FlowStorageClient]: + """Loads the OAuth flow for a specific auth handler. + + A new flow is created in Storage if none exists for the channel, user, and handler + combination. + + :param context: The context object for the current turn. + :type context: TurnContext + :param auth_handler_id: The ID of the auth handler to use. + :type auth_handler_id: str + :return: A tuple containing the OAuthFlow and FlowStorageClient created from the + context and the specified auth handler. + :rtype: tuple[OAuthFlow, FlowStorageClient] + """ + user_token_client: UserTokenClient = context.turn_state.get( + context.adapter.USER_TOKEN_CLIENT_KEY + ) + + if ( + not context.activity.channel_id + or not context.activity.from_property + or not context.activity.from_property.id + ): + raise ValueError("Channel ID and User ID are required") + + channel_id = context.activity.channel_id + user_id = context.activity.from_property.id + + ms_app_id = context.turn_state.get(context.adapter.AGENT_IDENTITY_KEY).claims[ + "aud" + ] + + # try to load existing state + flow_storage_client = _FlowStorageClient(channel_id, user_id, self._storage) + logger.info("Loading OAuth flow state from storage") + flow_state: _FlowState = await flow_storage_client.read(self._id) + if not flow_state: + logger.info("No existing flow state found, creating new flow state") + flow_state = _FlowState( + channel_id=channel_id, + user_id=user_id, + auth_handler_id=self._id, + connection=self._handler.abs_oauth_connection_name, + ms_app_id=ms_app_id, + ) + # await flow_storage_client.write(flow_state) + + flow = _OAuthFlow(flow_state, user_token_client) + return flow, flow_storage_client + + async def _handle_obo( + self, + context: TurnContext, + input_token_response: TokenResponse, + exchange_connection: Optional[str] = None, + exchange_scopes: Optional[list[str]] = None, + ) -> TokenResponse: + """ + Exchanges a token for another token with different scopes. + + :param context: The context object for the current turn. + :type context: TurnContext + :param scopes: The scopes to request for the new token. + :type scopes: list[str] + :param auth_handler_id: Optional ID of the auth handler to use, defaults to first + :type auth_handler_id: str + :return: The token response from the OAuth provider from the exchange. + If the cached token is not exchangeable, returns the cached token. + :rtype: TokenResponse + """ + if not input_token_response: + return input_token_response + + token = input_token_response.token + + connection_name = exchange_connection or self._handler.obo_connection_name + exchange_scopes = exchange_scopes or self._handler.scopes + + if not connection_name or not exchange_scopes: + return input_token_response + + if not input_token_response.is_exchangeable(): + return input_token_response + + token_provider = self._connection_manager.get_connection(connection_name) + if not token_provider: + raise ValueError(f"Connection '{connection_name}' not found") + + token = await token_provider.acquire_token_on_behalf_of( + scopes=exchange_scopes, + user_assertion=input_token_response.token, + ) + return TokenResponse(token=token) if token else TokenResponse() + + async def _sign_out( + self, + context: TurnContext, + ) -> None: + """ + _Signs out the current user. + This method clears the user's token and resets the OAuth state. + + :param context: The context object for the current turn. + :param auth_handler_id: Optional ID of the auth handler to use for sign out. If None, + signs out from all the handlers. + """ + flow, flow_storage_client = await self._load_flow(context) + logger.info("Signing out from handler: %s", self._id) + await flow.sign_out() + await flow_storage_client.delete(self._id) + + async def _handle_flow_response( + self, context: TurnContext, flow_response: _FlowResponse + ) -> None: + """Handles CONTINUE and FAILURE flow responses, sending activities back.""" + flow_state: _FlowState = flow_response.flow_state + + if flow_state.tag == _FlowStateTag.BEGIN: + # Create the OAuth card + sign_in_resource = flow_response.sign_in_resource + assert sign_in_resource + o_card: Attachment = CardFactory.oauth_card( + OAuthCard( + text="Sign in", + connection_name=flow_state.connection, + buttons=[ + CardAction( + title="Sign in", + type=ActionTypes.signin, + value=sign_in_resource.sign_in_link, + channel_data=None, + ) + ], + token_exchange_resource=sign_in_resource.token_exchange_resource, + token_post_resource=sign_in_resource.token_post_resource, + ) + ) + # Send the card to the user + await context.send_activity(MessageFactory.attachment(o_card)) + elif flow_state.tag == _FlowStateTag.FAILURE: + if flow_state.reached_max_attempts(): + await context.send_activity( + MessageFactory.text( + "Sign-in failed. Max retries reached. Please try again later." + ) + ) + elif flow_state.is_expired(): + await context.send_activity( + MessageFactory.text("Sign-in session expired. Please try again.") + ) + else: + logger.warning("Sign-in flow failed for unknown reasons.") + await context.send_activity("Sign-in failed. Please try again.") + + async def _sign_in( + self, + context: TurnContext, + exchange_connection: Optional[str] = None, + exchange_scopes: Optional[list[str]] = None, + ) -> _SignInResponse: + """Begins or continues an OAuth flow. + + Handles the flow response, sending the OAuth card to the context. + + :param context: The context object for the current turn. + :type context: TurnContext + :param auth_handler_id: The ID of the auth handler to use. + :type auth_handler_id: str + :return: The _SignInResponse containing the token response and flow state tag. + :rtype: _SignInResponse + """ + flow, flow_storage_client = await self._load_flow(context) + flow_response: _FlowResponse = await flow.begin_or_continue_flow( + context.activity + ) + + logger.info("Saving OAuth flow state to storage") + await flow_storage_client.write(flow_response.flow_state) + await self._handle_flow_response(context, flow_response) + + if flow_response.token_response: + # attempt exchange if needed + # if not needed, returns the same token + token_response = await self._handle_obo( + context, + flow_response.token_response, + exchange_connection, + exchange_scopes, + ) + + return _SignInResponse( + token_response=token_response, + tag=_FlowStateTag.COMPLETE if token_response else _FlowStateTag.FAILURE, + ) + + return _SignInResponse(tag=flow_response.flow_state.tag) + + async def get_refreshed_token( + self, + context: TurnContext, + exchange_connection: Optional[str] = None, + exchange_scopes: Optional[list[str]] = None, + ) -> TokenResponse: + """Attempts to get a refreshed token for the user with the given scopes + + :param context: The turn context for the current turn of conversation. + :type context: TurnContext + :param exchange_connection: Optional name of the connection to use for token exchange. If None, default connection will be used. + :type exchange_connection: Optional[str], optional + :param exchange_scopes: Optional list of scopes to request during token exchange. If None, default scopes will be used. + :type exchange_scopes: Optional[list[str]], optional + """ + flow, _ = await self._load_flow(context) + input_token_response = await flow.get_user_token() + return await self._handle_obo( + context, + input_token_response, + exchange_connection, + exchange_scopes, + ) diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_handlers/agentic_user_authorization.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_handlers/agentic_user_authorization.py new file mode 100644 index 00000000..133d8145 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_handlers/agentic_user_authorization.py @@ -0,0 +1,181 @@ +import logging + +from typing import Optional + +from microsoft_agents.activity import TokenResponse + +from ....turn_context import TurnContext +from ...._oauth import _FlowStateTag +from .._sign_in_response import _SignInResponse +from ._authorization_handler import _AuthorizationHandler +from ....storage import Storage +from ....authorization import Connections +from ..auth_handler import AuthHandler + +logger = logging.getLogger(__name__) + + +class AgenticUserAuthorization(_AuthorizationHandler): + """Class responsible for managing agentic authorization""" + + def __init__( + self, + storage: Storage, + connection_manager: Connections, + auth_handler: Optional[AuthHandler] = None, + *, + auth_handler_id: Optional[str] = None, + auth_handler_settings: Optional[dict] = None, + **kwargs, + ) -> None: + """ + Creates a new instance of Authorization. + + :param storage: The storage system to use for state management. + :type storage: Storage + :param connection_manager: The connection manager for OAuth providers. + :type connection_manager: Connections + :param auth_handlers: Configuration for OAuth providers. + :type auth_handlers: dict[str, AuthHandler], optional + :raises ValueError: When storage is None or no auth handlers provided. + """ + super().__init__( + storage, + connection_manager, + auth_handler, + auth_handler_id=auth_handler_id, + auth_handler_settings=auth_handler_settings, + **kwargs, + ) + self._alt_blueprint_name = ( + auth_handler._alt_blueprint_name if auth_handler else None + ) + + async def get_agentic_instance_token(self, context: TurnContext) -> TokenResponse: + """Gets the agentic instance token for the current agent instance. + + :param context: The context object for the current turn. + :type context: TurnContext + :return: The agentic instance token, or None if not an agentic request. + :rtype: Optional[str] + """ + + if not context.activity.is_agentic_request(): + return TokenResponse() + + assert context.identity + connection = self._connection_manager.get_token_provider( + context.identity, "agentic" + ) + agentic_instance_id = context.activity.get_agentic_instance_id() + assert agentic_instance_id + instance_token, _ = await connection.get_agentic_instance_token( + agentic_instance_id + ) + return ( + TokenResponse(token=instance_token) if instance_token else TokenResponse() + ) + + async def get_agentic_user_token( + self, context: TurnContext, scopes: list[str] + ) -> TokenResponse: + """Gets the agentic user token for the current agent instance and user. + + :param context: The context object for the current turn. + :type context: TurnContext + :param scopes: The scopes to request for the token. + :type scopes: list[str] + :return: The agentic user token, or None if not an agentic request or no user. + :rtype: Optional[str] + """ + logger.info("Retrieving agentic user token for scopes: %s", scopes) + + if ( + not context.activity.is_agentic_request() + or not context.activity.get_agentic_user() + ): + return TokenResponse() + + assert context.identity + if self._alt_blueprint_name: + logger.debug( + "Using alternative blueprint name for agentic user token retrieval: %s", + self._alt_blueprint_name, + ) + connection = self._connection_manager.get_connection( + self._alt_blueprint_name + ) + else: + logger.debug( + "Using connection manager for agentic user token retrieval with handler id: %s", + self._id, + ) + connection = self._connection_manager.get_token_provider( + context.identity, "agentic" + ) + upn = context.activity.get_agentic_user() + agentic_instance_id = context.activity.get_agentic_instance_id() + if not upn or not agentic_instance_id: + logger.error( + "Unable to retrieve agentic user token: missing UPN or agentic instance ID. UPN: %s, Agentic Instance ID: %s", + upn, + agentic_instance_id, + ) + raise ValueError( + f"Unable to retrieve agentic user token: missing UPN or agentic instance ID. UPN: {upn}, Agentic Instance ID: {agentic_instance_id}" + ) + + token = await connection.get_agentic_user_token( + agentic_instance_id, upn, scopes + ) + return TokenResponse(token=token) if token else TokenResponse() + + async def _sign_in( + self, + context: TurnContext, + exchange_connection: Optional[str] = None, + exchange_scopes: Optional[list[str]] = None, + ) -> _SignInResponse: + """Retrieves the agentic user token if available. + + :param context: The context object for the current turn. + :type context: TurnContext + :param connection_name: The name of the connection to use for sign-in. + :type connection_name: str + :param scopes: The scopes to request for the token. + :type scopes: Optional[list[str]] + :return: A _SignInResponse containing the token response and flow state tag. + :rtype: _SignInResponse + """ + token_response = await self.get_refreshed_token( + context, exchange_connection, exchange_scopes + ) + if token_response: + return _SignInResponse( + token_response=token_response, tag=_FlowStateTag.COMPLETE + ) + return _SignInResponse(tag=_FlowStateTag.FAILURE) + + async def get_refreshed_token( + self, + context: TurnContext, + exchange_connection: Optional[str] = None, + exchange_scopes: Optional[list[str]] = None, + ) -> TokenResponse: + """Attempts to get a refreshed token for the user with the given scopes + + :param context: The turn context for the current turn of conversation. + :type context: TurnContext + :param exchange_connection: Optional name of the connection to use for token exchange. If None, default connection will be used. + :type exchange_connection: Optional[str], optional + :param exchange_scopes: Optional list of scopes to request during token exchange. If None, default scopes will be used. + :type exchange_scopes: Optional[list[str]], optional + """ + if not exchange_scopes: + exchange_scopes = self._handler.scopes or [] + return await self.get_agentic_user_token(context, exchange_scopes) + + async def sign_out( + self, context: TurnContext, auth_handler_id: Optional[str] = None + ) -> None: + """Nothing to do for agentic sign out.""" diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_sign_in_response.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_sign_in_response.py new file mode 100644 index 00000000..4c2968da --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_sign_in_response.py @@ -0,0 +1,24 @@ +from typing import Optional + +from microsoft_agents.activity import TokenResponse + +from ..._oauth import _FlowStateTag + + +class _SignInResponse: + """Response for a sign-in attempt, including the token response and flow state tag.""" + + token_response: TokenResponse + tag: _FlowStateTag + + def __init__( + self, + token_response: Optional[TokenResponse] = None, + tag: _FlowStateTag = _FlowStateTag.FAILURE, + ) -> None: + self.token_response = token_response or TokenResponse() + self.tag = tag + + def sign_in_complete(self) -> bool: + """Return True if the sign-in flow is complete (either successful or no attempt needed).""" + return self.tag in [_FlowStateTag.COMPLETE, _FlowStateTag.NOT_STARTED] diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_sign_in_state.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_sign_in_state.py new file mode 100644 index 00000000..9ade2c80 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/_sign_in_state.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import Optional + +from microsoft_agents.activity import Activity + +from ...storage._type_aliases import JSON +from ...storage import StoreItem + + +class _SignInState(StoreItem): + """Store item for sign-in state, including tokens and continuation activity. + + Used to cache tokens and keep track of activities during single and + multi-turn sign-in flows. + """ + + def __init__( + self, + active_handler_id: str, + continuation_activity: Optional[Activity] = None, + ) -> None: + self.active_handler_id = active_handler_id + self.continuation_activity = continuation_activity + + def store_item_to_json(self) -> JSON: + return { + "active_handler_id": self.active_handler_id, + "continuation_activity": self.continuation_activity, + } + + @staticmethod + def from_json_to_store_item(json_data: JSON) -> _SignInState: + return _SignInState( + json_data["active_handler_id"], json_data.get("continuation_activity") + ) diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/auth_handler.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/auth_handler.py index bce68789..8e298107 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/auth_handler.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/auth_handler.py @@ -2,47 +2,96 @@ # Licensed under the MIT License. import logging -from typing import Dict +from typing import Optional logger = logging.getLogger(__name__) +# name due to compat. +# see AuthorizationHandler for a class that does work. class AuthHandler: """ Interface defining an authorization handler for OAuth flows. """ + name: str + title: str + text: str + abs_oauth_connection_name: str + obo_connection_name: str + auth_type: str + scopes: list[str] + def __init__( self, - name: str = None, - title: str = None, - text: str = None, - abs_oauth_connection_name: str = None, - obo_connection_name: str = None, + name: str = "", + title: str = "", + text: str = "", + abs_oauth_connection_name: str = "", + obo_connection_name: str = "", + auth_type: str = "", + scopes: Optional[list[str]] = None, **kwargs, ): """ Initializes a new instance of AuthHandler. - Args: - name: The name of the OAuth connection. - auto: Whether to automatically start the OAuth flow. - title: Title for the OAuth card. - text: Text for the OAuth button. + :param name: The name of the handler. This is how it is accessed programatically + in this library. + :type name: str + :param title: Title for the OAuth card. + :type title: str + :param text: Text for the OAuth button. + :type text: str + :param abs_oauth_connection_name: The name of the Azure Bot Service OAuth connection. + :type abs_oauth_connection_name: str + :param obo_connection_name: The name of the On-Behalf-Of connection. + :type obo_connection_name: str + :param auth_type: The authorization variant used. This is likely to change in the future + to accept a class that implements AuthorizationVariant. + :type auth_type: str """ - self.name = name or kwargs.get("NAME") - self.title = title or kwargs.get("TITLE") - self.text = text or kwargs.get("TEXT") + self.name = name or kwargs.get("NAME", "") + self.title = title or kwargs.get("TITLE", "") + self.text = text or kwargs.get("TEXT", "") self.abs_oauth_connection_name = abs_oauth_connection_name or kwargs.get( - "AZUREBOTOAUTHCONNECTIONNAME" + "AZUREBOTOAUTHCONNECTIONNAME", "" ) self.obo_connection_name = obo_connection_name or kwargs.get( - "OBOCONNECTIONNAME" - ) - logger.debug( - f"AuthHandler initialized: name={self.name}, title={self.title}, text={self.text} abs_connection_name={self.abs_oauth_connection_name} obo_connection_name={self.obo_connection_name}" + "OBOCONNECTIONNAME", "" ) + self.auth_type = auth_type or kwargs.get("TYPE", "UserAuthorization") + self.auth_type = self.auth_type.lower() + if scopes: + self.scopes = list(scopes) + else: + self.scopes = AuthHandler._format_scopes(kwargs.get("SCOPES", "")) + self._alt_blueprint_name = kwargs.get("ALT_BLUEPRINT_NAME", None) + @staticmethod + def _format_scopes(scopes: str) -> list[str]: + lst = scopes.strip().split(" ") + return [s for s in lst if s] + + @staticmethod + def _from_settings(settings: dict): + """ + Creates an AuthHandler instance from a settings dictionary. -# # Type alias for authorization handlers dictionary -AuthorizationHandlers = Dict[str, AuthHandler] + :param settings: The settings dictionary containing configuration for the AuthHandler. + :type settings: dict + :return: An instance of AuthHandler configured with the provided settings. + :rtype: AuthHandler + """ + if not settings: + raise ValueError("Settings dictionary is required to create AuthHandler") + + return AuthHandler( + name=settings.get("NAME", ""), + title=settings.get("TITLE", ""), + text=settings.get("TEXT", ""), + abs_oauth_connection_name=settings.get("AZUREBOTOAUTHCONNECTIONNAME", ""), + obo_connection_name=settings.get("OBOCONNECTIONNAME", ""), + auth_type=settings.get("TYPE", ""), + scopes=AuthHandler._format_scopes(settings.get("SCOPES", "")), + ) diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/authorization.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/authorization.py index 8ef635f0..e105ccf4 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/authorization.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/oauth/authorization.py @@ -1,53 +1,61 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -from __future__ import annotations +from datetime import datetime import logging +from typing import TypeVar, Optional, Callable, Awaitable, Generic, cast import jwt -from typing import Dict, Optional, Callable, Awaitable, AsyncIterator -from collections.abc import Iterable -from contextlib import asynccontextmanager -from microsoft_agents.hosting.core.authorization import ( - Connections, - AccessTokenProviderBase, -) -from microsoft_agents.hosting.core.storage import Storage, MemoryStorage -from microsoft_agents.activity import TokenResponse -from microsoft_agents.hosting.core.connector.client import UserTokenClient +from microsoft_agents.activity import Activity, TokenResponse from ...turn_context import TurnContext -from ...oauth import OAuthFlow, FlowResponse, FlowState, FlowStateTag, FlowStorageClient -from ..state.turn_state import TurnState +from ...storage import Storage +from ...authorization import Connections +from ..._oauth import _FlowStateTag +from ..state import TurnState from .auth_handler import AuthHandler +from ._sign_in_state import _SignInState +from ._sign_in_response import _SignInResponse +from ._handlers import ( + AgenticUserAuthorization, + _UserAuthorization, + _AuthorizationHandler, +) logger = logging.getLogger(__name__) +AUTHORIZATION_TYPE_MAP = { + "userauthorization": _UserAuthorization, + "agenticuserauthorization": AgenticUserAuthorization, +} + class Authorization: - """ - Class responsible for managing authorization and OAuth flows. - Handles multiple OAuth providers and manages the complete authentication lifecycle. - """ + """Class responsible for managing authorization flows.""" + + _storage: Storage + _connection_manager: Connections + _handlers: dict[str, _AuthorizationHandler] def __init__( self, storage: Storage, connection_manager: Connections, - auth_handlers: dict[str, AuthHandler] = None, - auto_signin: bool = None, + auth_handlers: Optional[dict[str, AuthHandler]] = None, + auto_signin: bool = False, use_cache: bool = False, **kwargs, ): """ Creates a new instance of Authorization. - Args: - storage: The storage system to use for state management. - auth_handlers: Configuration for OAuth providers. + Handlers defined in the configuration (passed in via kwargs) will be used + only if auth_handlers is empty or None. - Raises: - ValueError: If storage is None or no auth handlers are provided. + :param storage: The storage system to use for state management. + :type storage: Storage + :param connection_manager: The connection manager for OAuth providers. + :type connection_manager: Connections + :param auth_handlers: Configuration for OAuth providers. + :type auth_handlers: dict[str, AuthHandler], optional + :raises ValueError: When storage is None or no auth handlers provided. """ if not storage: raise ValueError("Storage is required for Authorization") @@ -55,324 +63,305 @@ def __init__( self._storage = storage self._connection_manager = connection_manager - auth_configuration: Dict = kwargs.get("AGENTAPPLICATION", {}).get( - "USERAUTHORIZATION", {} - ) - - handlers_config: Dict[str, Dict] = auth_configuration.get("HANDLERS") - if not auth_handlers and handlers_config: - auth_handlers = { - handler_name: AuthHandler( - name=handler_name, **config.get("SETTINGS", {}) - ) - for handler_name, config in handlers_config.items() - } - - self._auth_handlers = auth_handlers or {} self._sign_in_success_handler: Optional[ Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] - ] = lambda *args: None + ] = None self._sign_in_failure_handler: Optional[ Callable[[TurnContext, TurnState, Optional[str]], Awaitable[None]] - ] = lambda *args: None + ] = None - def _ids_from_context(self, context: TurnContext) -> tuple[str, str]: - """Checks and returns IDs necessary to load a new or existing flow. + self._handlers = {} - Raises a ValueError if channel ID or user ID are missing. - """ - if ( - not context.activity.channel_id - or not context.activity.from_property - or not context.activity.from_property.id - ): - raise ValueError("Channel ID and User ID are required") - - return context.activity.channel_id, context.activity.from_property.id - - async def _load_flow( - self, context: TurnContext, auth_handler_id: str = "" - ) -> tuple[OAuthFlow, FlowStorageClient]: - """Loads the OAuth flow for a specific auth handler. - - Args: - context: The context object for the current turn. - auth_handler_id: The ID of the auth handler to use. - - Returns: - The OAuthFlow returned corresponds to the flow associated with the - chosen handler, and the channel and user info found in the context. - The FlowStorageClient corresponds to the same channel and user info. - """ - user_token_client: UserTokenClient = context.turn_state.get( - context.adapter.USER_TOKEN_CLIENT_KEY - ) - - # resolve handler id - auth_handler: AuthHandler = self.resolve_handler(auth_handler_id) - auth_handler_id = auth_handler.name - - channel_id, user_id = self._ids_from_context(context) - - ms_app_id = context.turn_state.get(context.adapter.AGENT_IDENTITY_KEY).claims[ - "aud" - ] - - # try to load existing state - flow_storage_client = FlowStorageClient(channel_id, user_id, self._storage) - logger.info("Loading OAuth flow state from storage") - flow_state: FlowState = await flow_storage_client.read(auth_handler_id) - - if not flow_state: - logger.info("No existing flow state found, creating new flow state") - flow_state = FlowState( - channel_id=channel_id, - user_id=user_id, - auth_handler_id=auth_handler_id, - connection=auth_handler.abs_oauth_connection_name, - ms_app_id=ms_app_id, + if not auth_handlers: + # get from config + auth_configuration: dict = kwargs.get("AGENTAPPLICATION", {}).get( + "USERAUTHORIZATION", {} ) - await flow_storage_client.write(flow_state) - - flow = OAuthFlow(flow_state, user_token_client) - return flow, flow_storage_client - - @asynccontextmanager - async def open_flow( - self, context: TurnContext, auth_handler_id: str = "" - ) -> AsyncIterator[OAuthFlow]: - """Loads an OAuth flow and saves changes the changes to storage if any are made. - - Args: - context: The context object for the current turn. - auth_handler_id: ID of the auth handler to use. - If none provided, uses the first handler. - - Yields: - OAuthFlow: - The OAuthFlow instance loaded from storage or newly created - if not yet present in storage. + handlers_config: dict[str, dict] = auth_configuration.get("HANDLERS") + if not auth_handlers and handlers_config: + auth_handlers = { + handler_name: AuthHandler( + name=handler_name, **config.get("SETTINGS", {}) + ) + for handler_name, config in handlers_config.items() + } + + self._handler_settings = auth_handlers + + # operations default to the first handler if none specified + if self._handler_settings: + self._default_handler_id = next(iter(self._handler_settings.items()))[0] + self._init_handlers() + + def _init_handlers(self) -> None: + """Initialize authorization variants based on the provided auth handlers. + + This method maps the auth types to their corresponding authorization variants, and + it initializes an instance of each variant that is referenced. + + :param auth_handlers: A dictionary of auth handler configurations. + :type auth_handlers: dict[str, AuthHandler] """ - if not context: - logger.error("No context provided to open_flow") - raise ValueError("context is required") + for name, auth_handler in self._handler_settings.items(): + auth_type = auth_handler.auth_type + if auth_type not in AUTHORIZATION_TYPE_MAP: + raise ValueError(f"Auth type {auth_type} not recognized.") + + self._handlers[name] = AUTHORIZATION_TYPE_MAP[auth_type]( + storage=self._storage, + connection_manager=self._connection_manager, + auth_handler=auth_handler, + ) - flow, flow_storage_client = await self._load_flow(context, auth_handler_id) - yield flow - logger.info("Saving OAuth flow state to storage") - await flow_storage_client.write(flow.flow_state) + @staticmethod + def _sign_in_state_key(context: TurnContext) -> str: + """Generate a unique storage key for the sign-in state based on the context. - async def get_token( - self, context: TurnContext, auth_handler_id: str - ) -> TokenResponse: + This is the key used to store and retrieve the sign-in state from storage, and + can be used to inspect or manipulate the state directly if needed. + + :param context: The turn context for the current turn of conversation. + :type context: TurnContext + :return: A unique (across other values of channel_id and user_id) key for the sign-in state. + :rtype: str """ - Gets the token for a specific auth handler. + return f"auth:_SignInState:{context.activity.channel_id}:{context.activity.from_property.id}" - Args: - context: The context object for the current turn. - auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. + async def _load_sign_in_state(self, context: TurnContext) -> Optional[_SignInState]: + """Load the sign-in state from storage for the given context.""" + key = self._sign_in_state_key(context) + return (await self._storage.read([key], target_cls=_SignInState)).get(key) - Returns: - The token response from the OAuth provider. + async def _save_sign_in_state( + self, context: TurnContext, state: _SignInState + ) -> None: + """Save the sign-in state to storage for the given context.""" + key = self._sign_in_state_key(context) + await self._storage.write({key: state}) + + async def _delete_sign_in_state(self, context: TurnContext) -> None: + """Delete the sign-in state from storage for the given context.""" + key = self._sign_in_state_key(context) + await self._storage.delete([key]) + + @staticmethod + def _cache_key(context: TurnContext, handler_id: str) -> str: + return f"{Authorization._sign_in_state_key(context)}:{handler_id}:token" + + @staticmethod + def _get_cached_token( + context: TurnContext, handler_id: str + ) -> Optional[TokenResponse]: + key = Authorization._cache_key(context, handler_id) + return cast(Optional[TokenResponse], context.turn_state.get(key)) + + @staticmethod + def _cache_token( + context: TurnContext, handler_id: str, token_response: TokenResponse + ) -> None: + key = Authorization._cache_key(context, handler_id) + context.turn_state[key] = token_response + + @staticmethod + def _delete_cached_token(context: TurnContext, handler_id: str) -> None: + key = Authorization._cache_key(context, handler_id) + if key in context.turn_state: + del context.turn_state[key] + + def _resolve_handler(self, handler_id: str) -> _AuthorizationHandler: + """Resolve the auth handler by its ID. + + :param handler_id: The ID of the auth handler to resolve. + :type handler_id: str + :return: The corresponding AuthorizationHandler instance. + :rtype: AuthorizationHandler + :raises ValueError: If the handler ID is not recognized or not configured. """ - logger.info("Getting token for auth handler: %s", auth_handler_id) - async with self.open_flow(context, auth_handler_id) as flow: - return await flow.get_user_token() + if handler_id not in self._handlers: + raise ValueError( + f"Auth handler {handler_id} not recognized or not configured." + ) + return self._handlers[handler_id] - async def exchange_token( + async def _start_or_continue_sign_in( self, context: TurnContext, - scopes: list[str], + state: TurnState, auth_handler_id: Optional[str] = None, - ) -> TokenResponse: + ) -> _SignInResponse: + """Start or continue the sign-in process for the user with the given auth handler. + + _SignInResponse output is based on the result of the variant used by the handler. + Storage is updated as needed with _SignInState data for caching purposes. + + :param context: The turn context for the current turn of conversation. + :type context: TurnContext + :param state: The turn state for the current turn of conversation. + :type state: TurnState + :param auth_handler_id: The ID of the auth handler to use for sign-in. If None, the first handler will be used. + :type auth_handler_id: str + :return: A _SignInResponse indicating the result of the sign-in attempt. + :rtype: _SignInResponse """ - Exchanges a token for another token with different scopes. - - Args: - context: The context object for the current turn. - scopes: The scopes to request for the new token. - auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. - Returns: - The token response from the OAuth provider. - """ - logger.info("Exchanging token for scopes: %s", scopes) - async with self.open_flow(context, auth_handler_id) as flow: - token_response = await flow.get_user_token() + auth_handler_id = auth_handler_id or self._default_handler_id - if token_response and self._is_exchangeable(token_response.token): - logger.debug("Token is exchangeable, performing OBO flow") - return await self._handle_obo(token_response.token, scopes, auth_handler_id) + # check cached sign in state + sign_in_state = await self._load_sign_in_state(context) + if not sign_in_state: + # no existing sign-in state, create a new one + sign_in_state = _SignInState(active_handler_id=auth_handler_id) - return TokenResponse() + auth_handler_id = sign_in_state.active_handler_id - def _is_exchangeable(self, token: str) -> bool: - """ - Checks if a token is exchangeable (has api:// audience). + handler = self._resolve_handler(auth_handler_id) - Args: - token: The token to check. + # attempt sign-in continuation (or beginning) + sign_in_response = await handler._sign_in(context) - Returns: - True if the token is exchangeable, False otherwise. - """ - try: - # Decode without verification to check the audience - payload = jwt.decode(token, options={"verify_signature": False}) - aud = payload.get("aud") - return isinstance(aud, str) and aud.startswith("api://") - except Exception: - logger.error("Failed to decode token to check audience") - return False - - async def _handle_obo( - self, token: str, scopes: list[str], handler_id: str = None - ) -> TokenResponse: - """ - Handles On-Behalf-Of token exchange. - - Args: - context: The context object for the current turn. - token: The original token. - scopes: The scopes to request. - - Returns: - The new token response. - - """ - auth_handler = self.resolve_handler(handler_id) - token_provider: AccessTokenProviderBase = ( - self._connection_manager.get_connection(auth_handler.obo_connection_name) - ) - - logger.info("Attempting to exchange token on behalf of user") - new_token = await token_provider.aquire_token_on_behalf_of( - scopes=scopes, - user_assertion=token, - ) - return TokenResponse( - token=new_token, - scopes=scopes, # Expiration can be set based on the token provider's response - ) - - async def get_active_flow_state(self, context: TurnContext) -> Optional[FlowState]: - """Gets the first active flow state for the current context.""" - logger.debug("Getting active flow state") - channel_id, user_id = self._ids_from_context(context) - flow_storage_client = FlowStorageClient(channel_id, user_id, self._storage) - for auth_handler_id in self._auth_handlers.keys(): - flow_state = await flow_storage_client.read(auth_handler_id) - if flow_state and flow_state.is_active(): - return flow_state - return None - - async def begin_or_continue_flow( - self, - context: TurnContext, - turn_state: TurnState, - auth_handler_id: str = "", - ) -> FlowResponse: - """Begins or continues an OAuth flow. - - Args: - context: The context object for the current turn. - turn_state: The state object for the current turn. - auth_handler_id: Optional ID of the auth handler to use, defaults to first handler. - - Returns: - The token response from the OAuth provider. - - """ - if not auth_handler_id: - auth_handler_id = self.resolve_handler().name - - logger.debug("Beginning or continuing OAuth flow") - async with self.open_flow(context, auth_handler_id) as flow: - prev_tag = flow.flow_state.tag - flow_response: FlowResponse = await flow.begin_or_continue_flow( - context.activity + if sign_in_response.tag == _FlowStateTag.COMPLETE: + if self._sign_in_success_handler: + await self._sign_in_success_handler(context, state, auth_handler_id) + await self._delete_sign_in_state(context) + Authorization._cache_token( + context, auth_handler_id, sign_in_response.token_response ) - flow_state: FlowState = flow_response.flow_state - - if ( - flow_state.tag == FlowStateTag.COMPLETE - and prev_tag != FlowStateTag.COMPLETE - ): - logger.debug("Calling Authorization sign in success handler") - self._sign_in_success_handler( - context, turn_state, flow_state.auth_handler_id - ) - elif flow_state.tag == FlowStateTag.FAILURE: - logger.debug("Calling Authorization sign in failure handler") - self._sign_in_failure_handler( - context, - turn_state, - flow_state.auth_handler_id, - flow_response.flow_error_tag, - ) + elif sign_in_response.tag == _FlowStateTag.FAILURE: + if self._sign_in_failure_handler: + await self._sign_in_failure_handler(context, state, auth_handler_id) + await self._delete_sign_in_state(context) - return flow_response + elif sign_in_response.tag in [_FlowStateTag.BEGIN, _FlowStateTag.CONTINUE]: + # store continuation activity and wait for next turn + sign_in_state.continuation_activity = context.activity + await self._save_sign_in_state(context, sign_in_state) - def resolve_handler(self, auth_handler_id: Optional[str] = None) -> AuthHandler: - """Resolves the auth handler to use based on the provided ID. + return sign_in_response - Args: - auth_handler_id: Optional ID of the auth handler to resolve, defaults to first handler. + async def sign_out( + self, context: TurnContext, auth_handler_id: Optional[str] = None + ) -> None: + """Attempts to sign out the user from a specified auth handler or the default handler. - Returns: - The resolved auth handler. + :param context: The turn context for the current turn of conversation. + :type context: TurnContext + :param auth_handler_id: The ID of the auth handler to sign out from. If None, sign out from all handlers. + :type auth_handler_id: Optional[str] + :return: None """ - if auth_handler_id: - if auth_handler_id not in self._auth_handlers: - logger.error("Auth handler '%s' not found", auth_handler_id) - raise ValueError(f"Auth handler '{auth_handler_id}' not found") - return self._auth_handlers[auth_handler_id] + auth_handler_id = auth_handler_id or self._default_handler_id + handler = self._resolve_handler(auth_handler_id) + Authorization._delete_cached_token(context, auth_handler_id) + await self._delete_sign_in_state(context) + await handler._sign_out(context) + + async def _on_turn_auth_intercept( + self, context: TurnContext, state: TurnState + ) -> tuple[bool, Optional[Activity]]: + """Intercepts the turn to check for active authentication flows. + + Returns true if the rest of the turn should be skipped because auth did not finish. + Returns false if the turn should continue processing as normal. + If auth completes and a new turn should be started, returns the continuation activity + from the cached _SignInState. + + :param context: The context object for the current turn. + :type context: TurnContext + :param state: The turn state for the current turn. + :type state: TurnState + :return: A tuple indicating whether the turn should be skipped and the continuation activity if applicable. + :rtype: tuple[bool, Optional[Activity]] + """ + sign_in_state = await self._load_sign_in_state(context) - # Return the first handler if no ID specified - return next(iter(self._auth_handlers.values())) + if sign_in_state: + auth_handler_id = sign_in_state.active_handler_id + if auth_handler_id: + sign_in_response = await self._start_or_continue_sign_in( + context, state, auth_handler_id + ) + if sign_in_response.tag == _FlowStateTag.COMPLETE: + assert sign_in_state.continuation_activity is not None + continuation_activity = ( + sign_in_state.continuation_activity.model_copy() + ) + # flow complete, start new turn with continuation activity + return True, continuation_activity + # auth flow still in progress, the turn should be skipped + return True, None + # no active auth flow, continue processing + return False, None - async def _sign_out( - self, - context: TurnContext, - auth_handler_ids: Iterable[str], - ) -> None: - """Signs out from the specified auth handlers. + async def get_token( + self, context: TurnContext, auth_handler_id: Optional[str] = None + ) -> TokenResponse: + """Gets the token for a specific auth handler or the default handler. - Args: - context: The context object for the current turn. - auth_handler_ids: Iterable of auth handler IDs to sign out from. + The token is taken from cache, so this does not initiate nor continue a sign-in flow. - Deletes the associated flow states from storage. + :param context: The context object for the current turn. + :type context: TurnContext + :param auth_handler_id: The ID of the auth handler to get the token for. + :type auth_handler_id: str + :return: The token response from the OAuth provider. + :rtype: TokenResponse """ - for auth_handler_id in auth_handler_ids: - flow, flow_storage_client = await self._load_flow(context, auth_handler_id) - # ensure that the id is valid - self.resolve_handler(auth_handler_id) - logger.info("Signing out from handler: %s", auth_handler_id) - await flow.sign_out() - await flow_storage_client.delete(auth_handler_id) + return await self.exchange_token(context, auth_handler_id=auth_handler_id) - async def sign_out( + async def exchange_token( self, context: TurnContext, + scopes: Optional[list[str]] = None, auth_handler_id: Optional[str] = None, - ) -> None: + exchange_connection: Optional[str] = None, + ) -> TokenResponse: + """Exchanges or refreshes the token for a specific auth handler or the default handler. + + :param context: The context object for the current turn. + :type context: TurnContext + :param scopes: The scopes to request during the token exchange or refresh. Defaults + to the list given in the AuthHandler configuration if None. + :type scopes: Optional[list[str]] + :param auth_handler_id: The ID of the auth handler to exchange or refresh the token for. + If None, the default handler will be used. + :type auth_handler_id: Optional[str] + :param exchange_connection: The name of the connection to use for token exchange. If None, + the connection defined in the AuthHandler configuration will be used. + :type exchange_connection: Optional[str] + :return: The token response from the OAuth provider. + :rtype: TokenResponse + :raises ValueError: If the specified auth handler ID is not recognized or not configured. """ - Signs out the current user. - This method clears the user's token and resets the OAuth state. - Args: - context: The context object for the current turn. - auth_handler_id: Optional ID of the auth handler to use for sign out. If None, - signs out from all the handlers. + auth_handler_id = auth_handler_id or self._default_handler_id + if auth_handler_id not in self._handlers: + raise ValueError( + f"Auth handler {auth_handler_id} not recognized or not configured." + ) - Deletes the associated flow state(s) from storage. - """ - if auth_handler_id: - await self._sign_out(context, [auth_handler_id]) - else: - await self._sign_out(context, self._auth_handlers.keys()) + cached_token = Authorization._get_cached_token(context, auth_handler_id) + + if cached_token: + + handler = self._resolve_handler(auth_handler_id) + + # TODO: for later -> parity with .NET + # token_res = sign_in_state.tokens[auth_handler_id] + # if not context.activity.is_agentic_request(): + # if token_res and not token_res.is_exchangeable(): + # token = token_res.token + # if token.expiration is not None: + # diff = token.expiration - datetime.now().timestamp() + # if diff > 0: + # return token_res.token + + res = await handler.get_refreshed_token( + context, exchange_connection, scopes + ) + if res: + return res + return TokenResponse() def on_sign_in_success( self, @@ -381,8 +370,7 @@ def on_sign_in_success( """ Sets a handler to be called when sign-in is successfully completed. - Args: - handler: The handler function to call on successful sign-in. + :param handler: The handler function to call on successful sign-in. """ self._sign_in_success_handler = handler @@ -392,7 +380,7 @@ def on_sign_in_failure( ) -> None: """ Sets a handler to be called when sign-in fails. - Args: - handler: The handler function to call on sign-in failure. + + :param handler: The handler function to call on sign-in failure. """ self._sign_in_failure_handler = handler diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/access_token_provider_base.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/access_token_provider_base.py index 3c413e61..e69647cd 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/access_token_provider_base.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/access_token_provider_base.py @@ -1,4 +1,4 @@ -from typing import Protocol +from typing import Protocol, Optional from abc import abstractmethod @@ -17,7 +17,7 @@ async def get_access_token( """ pass - async def aquire_token_on_behalf_of( + async def acquire_token_on_behalf_of( self, scopes: list[str], user_assertion: str ) -> str: """ @@ -28,3 +28,18 @@ async def aquire_token_on_behalf_of( :return: The access token as a string. """ raise NotImplementedError() + + async def get_agentic_application_token( + self, agent_app_instance_id: str + ) -> Optional[str]: + raise NotImplementedError() + + async def get_agentic_instance_token( + self, agent_app_instance_id: str + ) -> tuple[str, str]: + raise NotImplementedError() + + async def get_agentic_user_token( + self, agent_app_instance_id: str, upn: str, scopes: list[str] + ) -> Optional[str]: + raise NotImplementedError() diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/agent_auth_configuration.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/agent_auth_configuration.py index 763197e6..a6fee937 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/agent_auth_configuration.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/agent_auth_configuration.py @@ -6,6 +6,17 @@ class AgentAuthConfiguration: """ Configuration for Agent authentication. + + TENANT_ID: The tenant ID for the Azure AD. + CLIENT_ID: The client ID for the Azure AD application. + AUTH_TYPE: The type of authentication to use (microsoft_agents.hosting.core.authorization.auth_types.AuthTypes). + CLIENT_SECRET: The client secret for the Azure AD application (if using client secret authentication). + CERT_PEM_FILE: The path to the PEM file for certificate authentication (if using certificate authentication). + CERT_KEY_FILE: The path to the key file for certificate authentication (if using certificate authentication). + CONNECTION_NAME: The name of the connection + SCOPES: The scopes to request + AUTHORITY: The authority URL for the Azure AD (if different from the default).f + ALT_BLUEPRINT_ID: An optional alternative blueprint ID used when constructing a connector client. """ TENANT_ID: Optional[str] @@ -17,6 +28,7 @@ class AgentAuthConfiguration: CONNECTION_NAME: Optional[str] SCOPES: Optional[list[str]] AUTHORITY: Optional[str] + ALT_BLUEPRINT_ID: Optional[str] def __init__( self, @@ -31,6 +43,7 @@ def __init__( scopes: Optional[list[str]] = None, **kwargs: Optional[dict[str, str]], ): + self.AUTH_TYPE = auth_type or kwargs.get("AUTHTYPE", AuthTypes.client_secret) self.CLIENT_ID = client_id or kwargs.get("CLIENTID", None) self.AUTHORITY = authority or kwargs.get("AUTHORITY", None) @@ -40,6 +53,7 @@ def __init__( self.CERT_KEY_FILE = cert_key_file or kwargs.get("CERTKEYFILE", None) self.CONNECTION_NAME = connection_name or kwargs.get("CONNECTIONNAME", None) self.SCOPES = scopes or kwargs.get("SCOPES", None) + self.ALT_BLUEPRINT_ID = kwargs.get("ALT_BLUEPRINT_NAME", None) @property def ISSUERS(self) -> list[str]: diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/anonymous_token_provider.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/anonymous_token_provider.py index 318566a3..6ed36fcf 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/anonymous_token_provider.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/anonymous_token_provider.py @@ -1,3 +1,5 @@ +from typing import Optional + from .access_token_provider_base import AccessTokenProviderBase @@ -11,3 +13,23 @@ async def get_access_token( self, resource_url: str, scopes: list[str], force_refresh: bool = False ) -> str: return "" + + async def acquire_token_on_behalf_of( + self, scopes: list[str], user_assertion: str + ) -> str: + return "" + + async def get_agentic_application_token( + self, agent_app_instance_id: str + ) -> Optional[str]: + return "" + + async def get_agentic_instance_token( + self, agent_app_instance_id: str + ) -> tuple[str, str]: + return "", "" + + async def get_agentic_user_token( + self, agent_app_instance_id: str, upn: str, scopes: list[str] + ) -> Optional[str]: + return "" diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/authentication_constants.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/authentication_constants.py index c370ea72..296a8df2 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/authentication_constants.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/authentication_constants.py @@ -100,3 +100,11 @@ class AuthenticationConstants(ABC): # Tenant Id claim name. As used in Microsoft AAD tokens. TENANT_ID_CLAIM = "tid" + + APX_LOCAL_SCOPE = "c16e153d-5d2b-4c21-b7f4-b05ee5d516f1/.default" + APX_DEV_SCOPE = "0d94caae-b412-4943-8a68-83135ad6d35f/.default" + APX_PRODUCTION_SCOPE = "5a807f24-c9de-44ee-a3a7-329e88a00ffc/.default" + APX_GCC_SCOPE = "c9475445-9789-4fef-9ec5-cde4a9bcd446/.default" + APX_GCCH_SCOPE = "6f669b9e-7701-4e2b-b624-82c9207fde26/.default" + APX_DOD_SCOPE = "0a069c81-8c7c-4712-886b-9c542d673ffb/.default" + APX_GALLATIN_SCOPE = "bd004c8e-5acf-4c48-8570-4e7d46b2f63b/.default" diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/claims_identity.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/claims_identity.py index a8a92ebb..af30b409 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/claims_identity.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/claims_identity.py @@ -10,7 +10,7 @@ def __init__( self, claims: dict[str, str], is_authenticated: bool, - authentication_type: str = None, + authentication_type: Optional[str] = None, ): self.claims = claims self.is_authenticated = is_authenticated diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/jwt_token_validator.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/jwt_token_validator.py index 26bfc7ee..714199b5 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/jwt_token_validator.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/jwt_token_validator.py @@ -45,7 +45,6 @@ def _get_public_key_or_secret(self, token: str) -> PyJWK: if unverified_payload.get("iss") == "https://api.botframework.com" else f"https://login.microsoftonline.com/{self.configuration.TENANT_ID}/discovery/v2.0/keys" ) - jwks_client = PyJWKClient(jwksUri) key = jwks_client.get_signing_key(header["kid"]) diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/channel_service_adapter.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/channel_service_adapter.py index 6c325c17..0fcec2cc 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/channel_service_adapter.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/channel_service_adapter.py @@ -7,7 +7,7 @@ from abc import ABC from copy import Error from http import HTTPStatus -from typing import Awaitable, Callable, cast +from typing import Awaitable, Callable, cast, Optional from uuid import uuid4 from microsoft_agents.activity import ( @@ -213,12 +213,28 @@ async def create_conversation( # pylint: disable=arguments-differ claims_identity = self.create_claims_identity(agent_app_id) claims_identity.claims[AuthenticationConstants.SERVICE_URL_CLAIM] = service_url + # Create a UserTokenClient instance for the application to use. (For example, in the OAuthPrompt.) + user_token_client: UserTokenClient = ( + await self._channel_service_client_factory.create_user_token_client( + claims_identity + ) + ) + + # Create a turn context and run the pipeline. + context = self._create_turn_context( + claims_identity, + None, + user_token_client, + callback, + ) + # Create the connector client to use for outbound requests. connector_client: ConnectorClient = ( await self._channel_service_client_factory.create_connector_client( - claims_identity, service_url, audience + context, claims_identity, service_url, audience ) ) + context.turn_state[self._AGENT_CONNECTOR_CLIENT_KEY] = connector_client # Make the actual create conversation call using the connector. create_conversation_result = ( @@ -232,22 +248,7 @@ async def create_conversation( # pylint: disable=arguments-differ create_conversation_result, channel_id, service_url, conversation_parameters ) - # Create a UserTokenClient instance for the application to use. (For example, in the OAuthPrompt.) - user_token_client: UserTokenClient = ( - await self._channel_service_client_factory.create_user_token_client( - claims_identity - ) - ) - - # Create a turn context and run the pipeline. - context = self._create_turn_context( - create_activity, - claims_identity, - None, - connector_client, - user_token_client, - callback, - ) + context.activity = create_activity # Run the pipeline await self.run_pipeline(context, callback) @@ -262,12 +263,6 @@ async def process_proactive( audience: str, callback: Callable[[TurnContext], Awaitable], ): - # Create the connector client to use for outbound requests. - connector_client: ConnectorClient = ( - await self._channel_service_client_factory.create_connector_client( - claims_identity, continuation_activity.service_url, audience - ) - ) # Create a UserTokenClient instance for the application to use. (For example, in the OAuthPrompt.) user_token_client: UserTokenClient = ( @@ -278,14 +273,21 @@ async def process_proactive( # Create a turn context and run the pipeline. context = self._create_turn_context( - continuation_activity, claims_identity, audience, - connector_client, user_token_client, callback, + activity=continuation_activity, ) + # Create the connector client to use for outbound requests. + connector_client: ConnectorClient = ( + await self._channel_service_client_factory.create_connector_client( + context, claims_identity, continuation_activity.service_url, audience + ) + ) + context.turn_state[self._AGENT_CONNECTOR_CLIENT_KEY] = connector_client + # Run the pipeline await self.run_pipeline(context, callback) @@ -336,17 +338,6 @@ async def process_activity( ): use_anonymous_auth_callback = True - # Create the connector client to use for outbound requests. - connector_client: ConnectorClient = ( - await self._channel_service_client_factory.create_connector_client( - claims_identity, - activity.service_url, - outgoing_audience, - scopes, - use_anonymous_auth_callback, - ) - ) - # Create a UserTokenClient instance for the OAuth flow. user_token_client: UserTokenClient = ( await self._channel_service_client_factory.create_user_token_client( @@ -356,14 +347,26 @@ async def process_activity( # Create a turn context and run the pipeline. context = self._create_turn_context( - activity, claims_identity, outgoing_audience, - connector_client, user_token_client, callback, + activity=activity, ) + # Create the connector client to use for outbound requests. + connector_client: ConnectorClient = ( + await self._channel_service_client_factory.create_connector_client( + context, + claims_identity, + activity.service_url, + outgoing_audience, + scopes, + use_anonymous_auth_callback, + ) + ) + context.turn_state[self._AGENT_CONNECTOR_CLIENT_KEY] = connector_client + await self.run_pipeline(context, callback) await connector_client.close() @@ -420,17 +423,15 @@ def _create_create_activity( def _create_turn_context( self, - activity: Activity, claims_identity: ClaimsIdentity, oauth_scope: str, - connector_client: ConnectorClientBase, user_token_client: UserTokenClientBase, callback: Callable[[TurnContext], Awaitable], + activity: Optional[Activity] = None, ) -> TurnContext: - context = TurnContext(self, activity) + context = TurnContext(self, activity, claims_identity) context.turn_state[self.AGENT_IDENTITY_KEY] = claims_identity - context.turn_state[self._AGENT_CONNECTOR_CLIENT_KEY] = connector_client context.turn_state[self.USER_TOKEN_CLIENT_KEY] = user_token_client context.turn_state[self.AGENT_CALLBACK_HANDLER_KEY] = callback context.turn_state[self.CHANNEL_SERVICE_FACTORY_KEY] = ( diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/__init__.py deleted file mode 100644 index 79858343..00000000 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from .flow_state import FlowState, FlowStateTag, FlowErrorTag -from .flow_storage_client import FlowStorageClient -from .oauth_flow import OAuthFlow, FlowResponse - -__all__ = [ - "FlowState", - "FlowStateTag", - "FlowErrorTag", - "FlowResponse", - "FlowStorageClient", - "OAuthFlow", -] diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/rest_channel_service_client_factory.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/rest_channel_service_client_factory.py index af280654..7c444d89 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/rest_channel_service_client_factory.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/rest_channel_service_client_factory.py @@ -1,5 +1,8 @@ +import re from typing import Optional +import logging +from microsoft_agents.activity import RoleTypes from microsoft_agents.hosting.core.authorization import ( AuthenticationConstants, AnonymousTokenProvider, @@ -12,6 +15,9 @@ from microsoft_agents.hosting.core.connector.teams import TeamsConnectorClient from .channel_service_client_factory_base import ChannelServiceClientFactoryBase +from .turn_context import TurnContext + +logger = logging.getLogger(__name__) class RestChannelServiceClientFactory(ChannelServiceClientFactoryBase): @@ -29,6 +35,7 @@ def __init__( async def create_connector_client( self, + context: TurnContext, claims_identity: ClaimsIdentity, service_url: str, audience: str, @@ -44,15 +51,63 @@ async def create_connector_client( "RestChannelServiceClientFactory.create_connector_client: audience can't be None or Empty" ) - token_provider: AccessTokenProviderBase = ( - self._connection_manager.get_token_provider(claims_identity, service_url) - if not use_anonymous - else self._ANONYMOUS_TOKEN_PROVIDER - ) + if context.activity.is_agentic_request(): + logger.info( + "Creating connector client for agentic request to service_url: %s", + service_url, + ) - token = await token_provider.get_access_token( - audience, scopes or [f"{audience}/.default"] - ) + if not context.identity: + raise ValueError("context.identity is required for agentic activities") + + connection = self._connection_manager.get_token_provider( + context.identity, service_url + ) + + # TODO: clean up linter + if connection._msal_configuration.ALT_BLUEPRINT_ID: + logger.debug( + "Using alternative blueprint ID for agentic token retrieval: %s", + connection._msal_configuration.ALT_BLUEPRINT_ID, + ) + connection = self._connection_manager.get_connection( + connection._msal_configuration.ALT_BLUEPRINT_ID + ) + + agent_instance_id = context.activity.get_agentic_instance_id() + if not agent_instance_id: + raise ValueError( + "Agent instance ID is required for agentic identity role" + ) + + if context.activity.recipient.role == RoleTypes.agentic_identity: + token, _ = await connection.get_agentic_instance_token( + agent_instance_id + ) + else: + agentic_user = context.activity.get_agentic_user() + if not agentic_user: + raise ValueError("Agentic user is required for agentic user role") + token = await connection.get_agentic_user_token( + agent_instance_id, + agentic_user, + [AuthenticationConstants.APX_PRODUCTION_SCOPE], + ) + + if not token: + raise ValueError("Failed to obtain token for agentic activity") + else: + token_provider: AccessTokenProviderBase = ( + self._connection_manager.get_token_provider( + claims_identity, service_url + ) + if not use_anonymous + else self._ANONYMOUS_TOKEN_PROVIDER + ) + + token = await token_provider.get_access_token( + audience, scopes or [f"{audience}/.default"] + ) return TeamsConnectorClient( endpoint=service_url, @@ -62,12 +117,11 @@ async def create_connector_client( async def create_user_token_client( self, claims_identity: ClaimsIdentity, use_anonymous: bool = False ) -> UserTokenClient: - token_provider = ( - self._connection_manager.get_token_provider( - claims_identity, self._token_service_endpoint - ) - if not use_anonymous - else self._ANONYMOUS_TOKEN_PROVIDER + if use_anonymous: + return UserTokenClient(endpoint=self._token_service_endpoint, token="") + + token_provider = self._connection_manager.get_token_provider( + claims_identity, self._token_service_endpoint ) token = await token_provider.get_access_token( diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/turn_context.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/turn_context.py index 70e022a4..f39e8428 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/turn_context.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/turn_context.py @@ -3,6 +3,7 @@ from __future__ import annotations import re +from typing import Optional from copy import copy, deepcopy from collections.abc import Callable @@ -17,13 +18,19 @@ ResourceResponse, DeliveryModes, ) +from microsoft_agents.hosting.core.authorization.claims_identity import ClaimsIdentity class TurnContext(TurnContextProtocol): # Same constant as in the BF Adapter, duplicating here to avoid circular dependency _INVOKE_RESPONSE_KEY = "TurnContext.InvokeResponse" - def __init__(self, adapter_or_context, request: Activity = None): + def __init__( + self, + adapter_or_context, + request: Activity = None, + identity: ClaimsIdentity = None, + ): """ Creates a new TurnContext instance. :param adapter_or_context: @@ -31,6 +38,7 @@ def __init__(self, adapter_or_context, request: Activity = None): """ if isinstance(adapter_or_context, TurnContext): adapter_or_context.copy_to(self) + self._identity = adapter_or_context.identity else: self.adapter = adapter_or_context self._activity = request @@ -46,6 +54,7 @@ def __init__(self, adapter_or_context, request: Activity = None): ["TurnContext", ConversationReference, Callable], None ] = [] self._responded: bool = False + self._identity = identity if self.adapter is None: raise TypeError("TurnContext must be instantiated with an adapter.") @@ -143,6 +152,10 @@ def streaming_response(self): self._streaming_response = None return self._streaming_response + @property + def identity(self) -> Optional[ClaimsIdentity]: + return self._identity + def get(self, key: str) -> object: if not key or not isinstance(key, str): raise TypeError('"key" must be a valid string.') diff --git a/tests/_common/__init__.py b/tests/_common/__init__.py index bb8ba4f3..bc07d46d 100644 --- a/tests/_common/__init__.py +++ b/tests/_common/__init__.py @@ -1,5 +1,7 @@ from .approx_equal import approx_eq +from .create_env_var_dict import create_env_var_dict __all__ = [ "approx_eq", + "create_env_var_dict", ] diff --git a/tests/_common/create_env_var_dict.py b/tests/_common/create_env_var_dict.py new file mode 100644 index 00000000..3e924bd5 --- /dev/null +++ b/tests/_common/create_env_var_dict.py @@ -0,0 +1,10 @@ +def create_env_var_dict(env_raw: str) -> dict[str, str]: + """Create a dictionary from a string that represents a .env config file.""" + lines = env_raw.strip().split("\n") + env = {} + for line in lines: + if not line.strip(): + continue + key, value = line.split("=", 1) + env[key.strip()] = value.strip() + return env diff --git a/tests/_common/data/__init__.py b/tests/_common/data/__init__.py index 11754a85..0d43733e 100644 --- a/tests/_common/data/__init__.py +++ b/tests/_common/data/__init__.py @@ -5,6 +5,8 @@ ) from .test_storage_data import TEST_STORAGE_DATA from .test_flow_data import TEST_FLOW_DATA +from .configs import TEST_ENV_DICT, TEST_ENV +from .configs import TEST_AGENTIC_ENV_DICT, TEST_AGENTIC_ENV __all__ = [ "TEST_DEFAULTS", @@ -12,4 +14,8 @@ "TEST_STORAGE_DATA", "TEST_FLOW_DATA", "create_test_auth_handler", + "TEST_ENV_DICT", + "TEST_ENV", + "TEST_AGENTIC_ENV_DICT", + "TEST_AGENTIC_ENV", ] diff --git a/tests/_common/data/configs/__init__.py b/tests/_common/data/configs/__init__.py new file mode 100644 index 00000000..450fb5c6 --- /dev/null +++ b/tests/_common/data/configs/__init__.py @@ -0,0 +1,4 @@ +from .test_auth_config import TEST_ENV_DICT, TEST_ENV +from .test_agentic_auth_config import TEST_AGENTIC_ENV_DICT, TEST_AGENTIC_ENV + +__all__ = ["TEST_ENV_DICT", "TEST_ENV", "TEST_AGENTIC_ENV_DICT", "TEST_AGENTIC_ENV"] diff --git a/tests/_common/data/configs/test_agentic_auth_config.py b/tests/_common/data/configs/test_agentic_auth_config.py new file mode 100644 index 00000000..9de0f8bd --- /dev/null +++ b/tests/_common/data/configs/test_agentic_auth_config.py @@ -0,0 +1,54 @@ +from microsoft_agents.activity import load_configuration_from_env + +from ...create_env_var_dict import create_env_var_dict +from ..test_defaults import TEST_DEFAULTS + +DEFAULTS = TEST_DEFAULTS() + +_TEST_AGENTIC_ENV_RAW = """ +CONNECTIONS__SERVICE_CONNECTION__SETTINGS__TENANTID=service-tenant-id +CONNECTIONS__SERVICE_CONNECTION__SETTINGS__CLIENTID=service-client-id +CONNECTIONS__SERVICE_CONNECTION__SETTINGS__CLIENTSECRET=service-client-secret + +CONNECTIONS__AGENTIC__SETTINGS__TENANTID=service-tenant-id +CONNECTIONS__AGENTIC__SETTINGS__CLIENTID=service-client-id +CONNECTIONS__AGENTIC__SETTINGS__CLIENTSECRET=service-client-secret + +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__AZUREBOTOAUTHCONNECTIONNAME={abs_oauth_connection_name} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__OBOCONNECTIONNAME={obo_connection_name} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__TITLE={auth_handler_title} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__TEXT={auth_handler_text} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__TYPE=UserAuthorization +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__SCOPES=scope1 scope2 + +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{agentic_auth_handler_id}__SETTINGS__AZUREBOTOAUTHCONNECTIONNAME={agentic_abs_oauth_connection_name} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{agentic_auth_handler_id}__SETTINGS__OBOCONNECTIONNAME={agentic_obo_connection_name} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{agentic_auth_handler_id}__SETTINGS__TITLE={agentic_auth_handler_title} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{agentic_auth_handler_id}__SETTINGS__TEXT={agentic_auth_handler_text} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{agentic_auth_handler_id}__SETTINGS__TYPE=AgenticUserAuthorization +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{agentic_auth_handler_id}__SETTINGS__SCOPES=user.Read Mail.Read + +CONNECTIONSMAP__0__CONNECTION=SERVICE_CONNECTION +CONNECTIONSMAP__0__SERVICEURL=* +CONNECTIONSMAP__1__CONNECTION=AGENTIC +CONNECTIONSMAP__1__SERVICEURL=agentic +""".format( + abs_oauth_connection_name=DEFAULTS.abs_oauth_connection_name, + obo_connection_name=DEFAULTS.obo_connection_name, + auth_handler_id=DEFAULTS.auth_handler_id, + auth_handler_title=DEFAULTS.auth_handler_title, + auth_handler_text=DEFAULTS.auth_handler_text, + agentic_abs_oauth_connection_name=DEFAULTS.agentic_abs_oauth_connection_name, + agentic_obo_connection_name=DEFAULTS.agentic_obo_connection_name, + agentic_auth_handler_id=DEFAULTS.agentic_auth_handler_id, + agentic_auth_handler_title=DEFAULTS.agentic_auth_handler_title, + agentic_auth_handler_text=DEFAULTS.agentic_auth_handler_text, +) + + +def TEST_AGENTIC_ENV(): + return create_env_var_dict(_TEST_AGENTIC_ENV_RAW) + + +def TEST_AGENTIC_ENV_DICT(): + return load_configuration_from_env(TEST_AGENTIC_ENV()) diff --git a/tests/_common/data/configs/test_auth_config.py b/tests/_common/data/configs/test_auth_config.py new file mode 100644 index 00000000..67152bad --- /dev/null +++ b/tests/_common/data/configs/test_auth_config.py @@ -0,0 +1,28 @@ +from microsoft_agents.activity import load_configuration_from_env + +from ...create_env_var_dict import create_env_var_dict +from ..test_defaults import TEST_DEFAULTS + +DEFAULTS = TEST_DEFAULTS() + +_TEST_ENV_RAW = """ +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__AZUREBOTOAUTHCONNECTIONNAME={abs_oauth_connection_name} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__OBOCONNECTIONNAME={obo_connection_name} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__TITLE={auth_handler_title} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__TEXT={auth_handler_text} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__TYPE=UserAuthorization +""".format( + abs_oauth_connection_name=DEFAULTS.abs_oauth_connection_name, + obo_connection_name=DEFAULTS.obo_connection_name, + auth_handler_id=DEFAULTS.auth_handler_id, + auth_handler_title=DEFAULTS.auth_handler_title, + auth_handler_text=DEFAULTS.auth_handler_text, +) + + +def TEST_ENV(): + return create_env_var_dict(_TEST_ENV_RAW) + + +def TEST_ENV_DICT(): + return load_configuration_from_env(TEST_ENV()) diff --git a/tests/_common/data/test_defaults.py b/tests/_common/data/test_defaults.py index 7422d253..58164e12 100644 --- a/tests/_common/data/test_defaults.py +++ b/tests/_common/data/test_defaults.py @@ -14,7 +14,31 @@ def __init__(self): self.user_id = "__user_id" self.bot_url = "https://botframework.com" self.ms_app_id = "__ms_app_id" - self.abs_oauth_connection_name = "__connection_name" - self.missing_abs_oauth_connection_name = "__missing_connection_name" + + # Auth Handler Settings + self.abs_oauth_connection_name = "connection_name" + self.obo_connection_name = "SERVICE_CONNECTION" + self.auth_handler_id = "auth_handler_id" + self.auth_handler_title = "auth_handler_title" + self.auth_handler_text = "auth_handler_text" + + # Connections Settings + self.connections_default_tenant_id = "service-tenant-id" + self.connections_default_client_id = "service-client-id" + self.connections_default_client_secret = "service-client-secret" + self.connections_agentic_tenant_id = "agentic-tenant-id" + self.connections_agentic_client_id = "agentic-client-id" + self.connections_agentic_client_secret = "agentic-client-secret" + + self.agentic_abs_oauth_connection_name = "agentic_connection_name" + self.agentic_obo_connection_name = "SERVICE_CONNECTION" + self.agentic_auth_handler_id = "agentic_auth_handler_id" + self.agentic_auth_handler_title = "agentic_auth_handler_title" + self.agentic_auth_handler_text = "agentic_auth_handler_text" + + self.agentic_instance_id = "agentic_instance_id" + self.agentic_user_id = "agentic_user_id" + + self.missing_abs_oauth_connection_name = "missing_connection_name" self.auth_handlers = [AuthHandler()] diff --git a/tests/_common/data/test_flow_data.py b/tests/_common/data/test_flow_data.py index 6cb0c7c3..ac41c0f8 100644 --- a/tests/_common/data/test_flow_data.py +++ b/tests/_common/data/test_flow_data.py @@ -1,6 +1,6 @@ from datetime import datetime -from microsoft_agents.hosting.core.oauth.flow_state import FlowState, FlowStateTag +from microsoft_agents.hosting.core._oauth import _FlowState, _FlowStateTag from tests._common.storage import MockStoreItem from tests._common.data.test_defaults import TEST_DEFAULTS @@ -18,69 +18,62 @@ class TEST_FLOW_DATA: def __init__(self): - self.not_started = FlowState( + self.not_started = _FlowState( **DEF_FLOW_ARGS, - tag=FlowStateTag.NOT_STARTED, + tag=_FlowStateTag.NOT_STARTED, attempts_remaining=1, - user_token="____", expiration=datetime.now().timestamp() + 1000000, ) - self.started = FlowState( + self.started = _FlowState( **DEF_FLOW_ARGS, - tag=FlowStateTag.BEGIN, + tag=_FlowStateTag.BEGIN, attempts_remaining=1, - user_token="____", expiration=datetime.now().timestamp() + 1000000, ) - self.started_one_retry = FlowState( + self.started_one_retry = _FlowState( **DEF_FLOW_ARGS, - tag=FlowStateTag.BEGIN, + tag=_FlowStateTag.BEGIN, attempts_remaining=2, - user_token="____", expiration=datetime.now().timestamp() + 1000000, ) - self.active = FlowState( + self.active = _FlowState( **DEF_FLOW_ARGS, - tag=FlowStateTag.CONTINUE, + tag=_FlowStateTag.CONTINUE, attempts_remaining=2, - user_token="__token", expiration=datetime.now().timestamp() + 1000000, ) - self.active_one_retry = FlowState( + self.active_one_retry = _FlowState( **DEF_FLOW_ARGS, - tag=FlowStateTag.CONTINUE, + tag=_FlowStateTag.CONTINUE, attempts_remaining=1, - user_token="__token", expiration=datetime.now().timestamp() + 1000000, ) - self.active_exp = FlowState( + self.active_exp = _FlowState( **DEF_FLOW_ARGS, - tag=FlowStateTag.CONTINUE, + tag=_FlowStateTag.CONTINUE, attempts_remaining=2, - user_token="__token", expiration=datetime.now().timestamp(), ) - self.completed = FlowState( + self.completed = _FlowState( **DEF_FLOW_ARGS, - tag=FlowStateTag.COMPLETE, + tag=_FlowStateTag.COMPLETE, attempts_remaining=2, - user_token="test_token", expiration=datetime.now().timestamp() + 1000000, ) - self.fail_by_attempts = FlowState( + self.fail_by_attempts = _FlowState( **DEF_FLOW_ARGS, - tag=FlowStateTag.FAILURE, + tag=_FlowStateTag.FAILURE, attempts_remaining=0, expiration=datetime.now().timestamp() + 1000000, ) - self.fail_by_exp = FlowState( + self.fail_by_exp = _FlowState( **DEF_FLOW_ARGS, - tag=FlowStateTag.FAILURE, + tag=_FlowStateTag.FAILURE, attempts_remaining=2, expiration=0, ) diff --git a/tests/_common/data/test_storage_data.py b/tests/_common/data/test_storage_data.py index 35b6a8d1..31cb0030 100644 --- a/tests/_common/data/test_storage_data.py +++ b/tests/_common/data/test_storage_data.py @@ -1,8 +1,9 @@ +from microsoft_agents.hosting.core._oauth import _FlowState + from tests._common.storage import MockStoreItem from .test_flow_data import ( TEST_FLOW_DATA, - FlowState, update_flow_state_handler, flow_key, ) @@ -39,7 +40,7 @@ def __init__(self): def get_init_data(self): data = self.dict.copy() for key, value in data.items(): - data[key] = value.model_copy() if isinstance(value, FlowState) else value + data[key] = value.model_copy() if isinstance(value, _FlowState) else value return data diff --git a/tests/_common/fixtures/flow_state_fixtures.py b/tests/_common/fixtures/flow_state_fixtures.py index 4ce502d8..345235be 100644 --- a/tests/_common/fixtures/flow_state_fixtures.py +++ b/tests/_common/fixtures/flow_state_fixtures.py @@ -1,6 +1,6 @@ import pytest -from microsoft_agents.hosting.core import FlowStateTag +from microsoft_agents.hosting.core._oauth import _FlowStateTag from tests._common.data import TEST_FLOW_DATA @@ -24,7 +24,7 @@ def inactive_flow_state(self, request): params=[ flow_state for flow_state in FLOW_STATES.inactive_flows() - if flow_state.tag != FlowStateTag.COMPLETE + if flow_state.tag != _FlowStateTag.COMPLETE ] ) def inactive_flow_state_not_completed(self, request): @@ -38,7 +38,7 @@ def active_flow_state(self, request): params=[ flow_state for flow_state in FLOW_STATES.inactive_flows() - if flow_state.tag != FlowStateTag.COMPLETE + if flow_state.tag != _FlowStateTag.COMPLETE ] ) def sample_inactive_flow_state_not_completed(self, request): diff --git a/tests/_common/mock_utils.py b/tests/_common/mock_utils.py new file mode 100644 index 00000000..fec35582 --- /dev/null +++ b/tests/_common/mock_utils.py @@ -0,0 +1,17 @@ +def mock_instance(mocker, cls, methods={}, default_mock_type=None, **kwargs): + """Create a mock instance of a class with specified methods mocked.""" + if not default_mock_type: + default_mock_type = mocker.AsyncMock + instance = mocker.Mock(spec=cls, **kwargs) + for method_name, return_value in methods.items(): + if not isinstance(return_value, mocker.Mock) and not isinstance( + return_value, mocker.AsyncMock + ): + return_value = default_mock_type(return_value=return_value) + setattr(instance, method_name, return_value) + return instance + + +def mock_class(mocker, cls, instance): + """Replace a class with a mock instance.""" + mocker.patch.object(cls, new=instance) diff --git a/tests/_common/testing_objects/__init__.py b/tests/_common/testing_objects/__init__.py index 7e36b7e2..875165dc 100644 --- a/tests/_common/testing_objects/__init__.py +++ b/tests/_common/testing_objects/__init__.py @@ -6,6 +6,10 @@ mock_class_OAuthFlow, mock_UserTokenClient, mock_class_UserTokenClient, + mock_class_UserAuthorization, + mock_class_AgenticUserAuthorization, + mock_class_Authorization, + agentic_mock_class_MsalAuth, ) from .testing_authorization import TestingAuthorization @@ -26,4 +30,8 @@ "TestingTokenProvider", "TestingUserTokenClient", "TestingAdapter", + "mock_class_UserAuthorization", + "mock_class_AgenticUserAuthorization", + "mock_class_Authorization", + "agentic_mock_class_MsalAuth", ] diff --git a/tests/_common/testing_objects/adapters/testing_adapter.py b/tests/_common/testing_objects/adapters/testing_adapter.py index f753f78b..38021bc4 100644 --- a/tests/_common/testing_objects/adapters/testing_adapter.py +++ b/tests/_common/testing_objects/adapters/testing_adapter.py @@ -497,7 +497,7 @@ def create_turn_context( turn_context = TurnContext(self, activity) turn_context.services["UserTokenClient"] = self._user_token_client - turn_context.identity = identity or self.claims_identity + turn_context._identity = identity or self.claims_identity return turn_context diff --git a/tests/_common/testing_objects/http/testing_channel_service_client_factory.py b/tests/_common/testing_objects/http/testing_channel_service_client_factory.py deleted file mode 100644 index 8dd54e97..00000000 --- a/tests/_common/testing_objects/http/testing_channel_service_client_factory.py +++ /dev/null @@ -1,86 +0,0 @@ -from typing import Optional - -from microsoft_agents.hosting.core.authorization import ( - AuthenticationConstants, - AnonymousTokenProvider, - ClaimsIdentity, - Connections, -) -from microsoft_agents.hosting.core.authorization import AccessTokenProviderBase -from microsoft_agents.hosting.core.connector import ConnectorClientBase -from microsoft_agents.hosting.core.connector.client import UserTokenClient -from microsoft_agents.hosting.core.connector.teams import TeamsConnectorClient - -from .channel_service_client_factory_base import ChannelServiceClientFactoryBase -from .testing_connector_client import TestingConnectorClient - - -class TestingRestChannelServiceClientFactory(ChannelServiceClientFactoryBase): - _ANONYMOUS_TOKEN_PROVIDER = AnonymousTokenProvider() - - def __init__( - self, - mocker, - connection_manager: Connections, - token_service_endpoint=AuthenticationConstants.AGENTS_SDK_OAUTH_URL, - token_service_audience=AuthenticationConstants.AGENTS_SDK_SCOPE, - connector_client_class: type[BaseConnectorClient] = TestingConnectorClient, - user_token_client_class: type[BaseUserTokenClient] = TestingUserTokenClient, - ) -> None: - self._mocker = mocker - self._connection_manager = connection_manager - self._token_service_endpoint = token_service_endpoint - self._token_service_audience = token_service_audience - self._connector_client_class = connector_client_class - self._user_token_client_class = user_token_client_class - - async def create_connector_client( - self, - claims_identity: ClaimsIdentity, - service_url: str, - audience: str, - scopes: Optional[list[str]] = None, - use_anonymous: bool = False, - ) -> ConnectorClientBase: - if not service_url: - raise TypeError( - "RestChannelServiceClientFactory.create_connector_client: service_url can't be None or Empty" - ) - if not audience: - raise TypeError( - "RestChannelServiceClientFactory.create_connector_client: audience can't be None or Empty" - ) - - token_provider: AccessTokenProviderBase = ( - self._connection_manager.get_token_provider(claims_identity, service_url) - if not use_anonymous - else self._ANONYMOUS_TOKEN_PROVIDER - ) - - token = await token_provider.get_access_token( - audience, scopes or [f"{audience}/.default"] - ) - - return self._connector_client_class( - endpoint=service_url, - token=token, - ) - - async def create_user_token_client( - self, claims_identity: ClaimsIdentity, use_anonymous: bool = False - ) -> UserTokenClient: - token_provider = ( - self._connection_manager.get_token_provider( - claims_identity, self._token_service_endpoint - ) - if not use_anonymous - else self._ANONYMOUS_TOKEN_PROVIDER - ) - - token = await token_provider.get_access_token( - self._token_service_audience, [f"{self._token_service_audience}/.default"] - ) - return self._user_token_client_class( - endpoint=self._token_service_endpoint, - token=token, - ) diff --git a/tests/_common/testing_objects/http/testing_client_session.py b/tests/_common/testing_objects/http/testing_client_session.py deleted file mode 100644 index a00a89de..00000000 --- a/tests/_common/testing_objects/http/testing_client_session.py +++ /dev/null @@ -1,2 +0,0 @@ -class TestingClientSessionBase: - pass diff --git a/tests/_common/testing_objects/http/testing_connector_client.py b/tests/_common/testing_objects/http/testing_connector_client.py deleted file mode 100644 index fd797814..00000000 --- a/tests/_common/testing_objects/http/testing_connector_client.py +++ /dev/null @@ -1,40 +0,0 @@ -from microsft_agents.hosting.core import ( - AgentAuthConfiguration, - AccessTokenProviderBase, - TeamsConnectorClient, -) - -from tests._common.testing_objects.http.testing_client_session import ( - TestingClientSession, -) - - -class TestingConnectorClient(TeamsConnectorClient): - """Teams Connector Client for interacting with Teams-specific APIs.""" - - @classmethod - async def create_client_with_auth_async( - cls, - base_url: str, - auth_config: AgentAuthConfiguration, - auth_provider: AccessTokenProviderBase, - scope: str, - ) -> "TeamsConnectorClient": - """ - Creates a new instance of TeamsConnectorClient with authentication. - - :param base_url: The base URL for the API. - :param auth_config: The authentication configuration. - :param auth_provider: The authentication provider. - :param scope: The scope for the authentication token. - :return: A new instance of TeamsConnectorClient. - """ - session = TestingClientSession( - base_url=base_url, headers={"Accept": "application/json"} - ) - - token = await auth_provider.get_access_token(auth_config, scope) - if len(token) > 1: - session.headers.update({"Authorization": f"Bearer {token}"}) - - return cls(session) diff --git a/tests/_common/testing_objects/mocks/__init__.py b/tests/_common/testing_objects/mocks/__init__.py index 786a79c8..a6f7c85d 100644 --- a/tests/_common/testing_objects/mocks/__init__.py +++ b/tests/_common/testing_objects/mocks/__init__.py @@ -1,10 +1,20 @@ -from .mock_msal_auth import MockMsalAuth +from .mock_msal_auth import MockMsalAuth, agentic_mock_class_MsalAuth from .mock_oauth_flow import mock_OAuthFlow, mock_class_OAuthFlow from .mock_user_token_client import mock_UserTokenClient, mock_class_UserTokenClient +from .mock_authorization import ( + mock_class_UserAuthorization, + mock_class_AgenticUserAuthorization, + mock_class_Authorization, +) __all__ = [ "MockMsalAuth", "mock_OAuthFlow", "mock_class_OAuthFlow", "mock_UserTokenClient", + "mock_class_UserTokenClient", + "mock_class_UserAuthorization", + "mock_class_AgenticUserAuthorization", + "mock_class_Authorization", + "agentic_mock_class_MsalAuth", ] diff --git a/tests/_common/testing_objects/mocks/mock_authorization.py b/tests/_common/testing_objects/mocks/mock_authorization.py new file mode 100644 index 00000000..4e1afdee --- /dev/null +++ b/tests/_common/testing_objects/mocks/mock_authorization.py @@ -0,0 +1,50 @@ +from microsoft_agents.activity import TokenResponse + +from microsoft_agents.hosting.core import Authorization +from microsoft_agents.hosting.core.app.oauth import ( + _UserAuthorization, + AgenticUserAuthorization, + _SignInResponse, +) + + +def mock_class_UserAuthorization( + mocker, sign_in_return=None, get_refreshed_token_return=None +): + if sign_in_return is None: + sign_in_return = _SignInResponse() + if get_refreshed_token_return is None: + get_refreshed_token_return = TokenResponse() + mocker.patch.object(_UserAuthorization, "_sign_in", return_value=sign_in_return) + mocker.patch.object(_UserAuthorization, "_sign_out") + mocker.patch.object( + _UserAuthorization, + "get_refreshed_token", + return_value=get_refreshed_token_return, + ) + + +def mock_class_AgenticUserAuthorization( + mocker, sign_in_return=None, get_refreshed_token_return=None +): + if sign_in_return is None: + sign_in_return = _SignInResponse() + if get_refreshed_token_return is None: + get_refreshed_token_return = TokenResponse() + mocker.patch.object( + AgenticUserAuthorization, "_sign_in", return_value=sign_in_return + ) + mocker.patch.object(AgenticUserAuthorization, "_sign_out") + mocker.patch.object( + AgenticUserAuthorization, + "get_refreshed_token", + return_value=get_refreshed_token_return, + ) + + +def mock_class_Authorization(mocker, start_or_continue_sign_in_return=False): + mocker.patch.object( + Authorization, + "_start_or_continue_sign_in", + return_value=start_or_continue_sign_in_return, + ) diff --git a/tests/_common/testing_objects/mocks/mock_msal_auth.py b/tests/_common/testing_objects/mocks/mock_msal_auth.py index 44a94025..f9a046b7 100644 --- a/tests/_common/testing_objects/mocks/mock_msal_auth.py +++ b/tests/_common/testing_objects/mocks/mock_msal_auth.py @@ -2,17 +2,23 @@ from microsoft_agents.hosting.core.authorization import AgentAuthConfiguration +# used by MsalAuth tests class MockMsalAuth(MsalAuth): """ Mock object for MsalAuth """ - def __init__(self, mocker, client_type): + def __init__( + self, + mocker, + client_type, + acquire_token_for_client_return={"access_token": "token"}, + ): super().__init__(AgentAuthConfiguration()) mock_client = mocker.Mock(spec=client_type) mock_client.acquire_token_for_client = mocker.Mock( - return_value={"access_token": "token"} + return_value=acquire_token_for_client_return ) mock_client.acquire_token_on_behalf_of = mocker.Mock( return_value={"access_token": "token"} @@ -20,3 +26,24 @@ def __init__(self, mocker, client_type): self.mock_client = mock_client self._create_client_application = mocker.Mock(return_value=self.mock_client) + + +def agentic_mock_class_MsalAuth( + mocker, + get_agentic_application_token_return=None, + get_agentic_instance_token_return=None, + get_agentic_user_token_return=None, +): + mocker.patch.object( + MsalAuth, + "get_agentic_application_token", + return_value=get_agentic_application_token_return, + ) + mocker.patch.object( + MsalAuth, + "get_agentic_instance_token", + return_value=get_agentic_instance_token_return, + ) + mocker.patch.object( + MsalAuth, "get_agentic_user_token", return_value=get_agentic_user_token_return + ) diff --git a/tests/_common/testing_objects/mocks/mock_oauth_flow.py b/tests/_common/testing_objects/mocks/mock_oauth_flow.py index 82e78328..53a066d3 100644 --- a/tests/_common/testing_objects/mocks/mock_oauth_flow.py +++ b/tests/_common/testing_objects/mocks/mock_oauth_flow.py @@ -1,5 +1,5 @@ from microsoft_agents.activity import TokenResponse -from microsoft_agents.hosting.core import OAuthFlow +from microsoft_agents.hosting.core._oauth import _OAuthFlow from tests._common.data import TEST_DEFAULTS @@ -17,13 +17,15 @@ def mock_OAuthFlow( # mock_oauth_flow_class.sign_out = mocker.AsyncMock() if isinstance(get_user_token_return, str): get_user_token_return = TokenResponse(token=get_user_token_return) - mocker.patch.object(OAuthFlow, "get_user_token", return_value=get_user_token_return) - mocker.patch.object(OAuthFlow, "sign_out") mocker.patch.object( - OAuthFlow, "begin_or_continue_flow", return_value=begin_or_continue_flow_return + _OAuthFlow, "get_user_token", return_value=get_user_token_return ) - mocker.patch.object(OAuthFlow, "begin_flow", return_value=begin_flow_return) - mocker.patch.object(OAuthFlow, "continue_flow", return_value=continue_flow_return) + mocker.patch.object(_OAuthFlow, "sign_out") + mocker.patch.object( + _OAuthFlow, "begin_or_continue_flow", return_value=begin_or_continue_flow_return + ) + mocker.patch.object(_OAuthFlow, "begin_flow", return_value=begin_flow_return) + mocker.patch.object(_OAuthFlow, "continue_flow", return_value=continue_flow_return) def mock_class_OAuthFlow( @@ -34,7 +36,7 @@ def mock_class_OAuthFlow( continue_flow_return=None, ): mocker.patch( - "microsoft_agents.hosting.core.OAuthFlow", + "microsoft_agents.hosting.core._oauth._OAuthFlow", new=mock_OAuthFlow( mocker, get_user_token_return=get_user_token_return, diff --git a/tests/_common/testing_objects/testing_token_provider.py b/tests/_common/testing_objects/testing_token_provider.py index 28baffc9..66dcf002 100644 --- a/tests/_common/testing_objects/testing_token_provider.py +++ b/tests/_common/testing_objects/testing_token_provider.py @@ -38,7 +38,7 @@ async def get_access_token( """ return f"{self.name}-token" - async def aquire_token_on_behalf_of( + async def acquire_token_on_behalf_of( self, scopes: list[str], user_assertion: str ) -> str: """ diff --git a/tests/_integration/test_quickstart.py b/tests/_integration/test_quickstart.py index 32fac1bf..bbb3f997 100644 --- a/tests/_integration/test_quickstart.py +++ b/tests/_integration/test_quickstart.py @@ -1,37 +1,37 @@ -import pytest +# import pytest -from tests._integration.common.testing_environment import ( - TestingEnvironment, - MockTestingEnvironment, -) -from tests._integration.scenarios.quickstart import main +# from tests._integration.common.testing_environment import ( +# TestingEnvironment, +# MockTestingEnvironment, +# ) +# from tests._integration.scenarios.quickstart import main -class _TestQuickstart: - @pytest.fixture - def testenv(self, mocker) -> TestingEnvironment: - raise NotImplementedError() +# class _TestQuickstart: +# @pytest.fixture +# def testenv(self, mocker) -> TestingEnvironment: +# raise NotImplementedError() - # @pytest.fixture - # def client(self, testenv) -> TestClient: - # return TestClient(testenv.adapter) +# # @pytest.fixture +# # def client(self, testenv) -> TestClient: +# # return TestClient(testenv.adapter) - @pytest.mark.asyncio - async def test_quickstart(self, testenv): - main(testenv) - # testenv.adapter.send_activity("Hello World") +# @pytest.mark.asyncio +# async def test_quickstart(self, testenv): +# main(testenv) +# # testenv.adapter.send_activity("Hello World") -# class TestQuickstartMultipleEnvs(_TestQuickstart): +# # class TestQuickstartMultipleEnvs(_TestQuickstart): -# @pytest.fixture( -# params=[MockTestingEnvironment, SampleEnvironment], -# ) -# def testenv(self, mocker, request) -> TestingEnvironment: -# return request.param(mocker) +# # @pytest.fixture( +# # params=[MockTestingEnvironment, SampleEnvironment], +# # ) +# # def testenv(self, mocker, request) -> TestingEnvironment: +# # return request.param(mocker) -class TestQuickstartMockEnv(_TestQuickstart): - @pytest.fixture - def testenv(self, mocker) -> TestingEnvironment: - return MockTestingEnvironment(mocker) +# class TestQuickstartMockEnv(_TestQuickstart): +# @pytest.fixture +# def testenv(self, mocker) -> TestingEnvironment: +# return MockTestingEnvironment(mocker) diff --git a/tests/activity/test_activity.py b/tests/activity/test_activity.py index 179b40df..40d695fe 100644 --- a/tests/activity/test_activity.py +++ b/tests/activity/test_activity.py @@ -16,10 +16,14 @@ AIEntity, Place, Thing, + RoleTypes, ) from tests.activity._common.my_channel_data import MyChannelData from tests.activity._common.testing_activity import create_test_activity +from tests._common.data import TEST_DEFAULTS + +DEFAULTS = TEST_DEFAULTS() def helper_validate_recipient_and_from( @@ -368,3 +372,74 @@ def test_get_mentions(self): Mention(text="Hello"), Entity(type="mention", text="Another mention"), ] + + +class TestActivityAgenticOps: + + @pytest.fixture(params=[RoleTypes.user, RoleTypes.skill, RoleTypes.agent]) + def non_agentic_role(self, request): + return request.param + + @pytest.fixture(params=[RoleTypes.agentic_user, RoleTypes.agentic_identity]) + def agentic_role(self, request): + return request.param + + @pytest.mark.parametrize( + "role, expected", + [ + [RoleTypes.user, False], + [RoleTypes.agent, False], + [RoleTypes.skill, False], + [RoleTypes.agentic_user, True], + [RoleTypes.agentic_identity, True], + ], + ) + def test_is_agentic_request(self, role, expected): + activity = Activity( + type="message", recipient=ChannelAccount(id="bot", name="bot", role=role) + ) + assert activity.is_agentic_request() == expected + + def test_get_agentic_instance_id_is_agentic(self, mocker, agentic_role): + activity = Activity( + type="message", + recipient=ChannelAccount( + id="some_id", + agentic_app_id=DEFAULTS.agentic_instance_id, + role=agentic_role, + ), + ) + assert activity.get_agentic_instance_id() == DEFAULTS.agentic_instance_id + + def test_get_agentic_instance_id_not_agentic(self, non_agentic_role): + activity = Activity( + type="message", + recipient=ChannelAccount( + id="some_id", + agentic_app_id=DEFAULTS.agentic_instance_id, + role=non_agentic_role, + ), + ) + assert activity.get_agentic_instance_id() is None + + def test_get_agentic_user_is_agentic(self, agentic_role): + activity = Activity( + type="message", + recipient=ChannelAccount( + id=DEFAULTS.agentic_user_id, + agentic_app_id=DEFAULTS.agentic_instance_id, + role=agentic_role, + ), + ) + assert activity.get_agentic_user() == DEFAULTS.agentic_user_id + + def test_get_agentic_user_not_agentic(self, non_agentic_role): + activity = Activity( + type="message", + recipient=ChannelAccount( + id=DEFAULTS.agentic_user_id, + agentic_app_id=DEFAULTS.agentic_instance_id, + role=non_agentic_role, + ), + ) + assert activity.get_agentic_user() is None diff --git a/tests/activity/test_load_configuration.py b/tests/activity/test_load_configuration.py new file mode 100644 index 00000000..eb9dccac --- /dev/null +++ b/tests/activity/test_load_configuration.py @@ -0,0 +1,103 @@ +from microsoft_agents.activity import load_configuration_from_env + +from tests._common import create_env_var_dict +from tests._common.data import TEST_DEFAULTS + +DEFAULTS = TEST_DEFAULTS() + +ENV_DICT = { + "CONNECTIONS": { + "SERVICE_CONNECTION": { + "SETTINGS": { + "TENANTID": DEFAULTS.connections_default_tenant_id, + "CLIENTID": DEFAULTS.connections_default_client_id, + "CLIENTSECRET": DEFAULTS.connections_default_client_secret, + } + }, + "AGENTIC": { + "SETTINGS": { + "TENANTID": DEFAULTS.connections_agentic_tenant_id, + "CLIENTID": DEFAULTS.connections_agentic_client_id, + "CLIENTSECRET": DEFAULTS.connections_agentic_client_secret, + } + }, + }, + "AGENTAPPLICATION": { + "USERAUTHORIZATION": { + "HANDLERS": { + DEFAULTS.auth_handler_id: { + "SETTINGS": { + "AZUREBOTOAUTHCONNECTIONNAME": DEFAULTS.abs_oauth_connection_name, + "OBOCONNECTIONNAME": DEFAULTS.obo_connection_name, + "TITLE": DEFAULTS.auth_handler_title, + "TEXT": DEFAULTS.auth_handler_text, + "TYPE": "UserAuthorization", + } + }, + DEFAULTS.agentic_auth_handler_id: { + "SETTINGS": { + "AZUREBOTOAUTHCONNECTIONNAME": DEFAULTS.agentic_abs_oauth_connection_name, + "OBOCONNECTIONNAME": DEFAULTS.agentic_obo_connection_name, + "TITLE": DEFAULTS.agentic_auth_handler_title, + "TEXT": DEFAULTS.agentic_auth_handler_text, + "TYPE": "AgenticAuthorization", + } + }, + } + }, + }, + "CONNECTIONSMAP": [ + {"CONNECTION": "SERVICE_CONNECTION", "SERVICEURL": "*"}, + {"CONNECTION": "AGENTIC", "SERVICEURL": "agentic"}, + ], +} + +ENV_RAW = """ +CONNECTIONS__SERVICE_CONNECTION__SETTINGS__TENANTID={connections_default_tenant_id} +CONNECTIONS__SERVICE_CONNECTION__SETTINGS__CLIENTID={connections_default_client_id} +CONNECTIONS__SERVICE_CONNECTION__SETTINGS__CLIENTSECRET={connections_default_client_secret} + +CONNECTIONS__AGENTIC__SETTINGS__TENANTID={connections_agentic_tenant_id} +CONNECTIONS__AGENTIC__SETTINGS__CLIENTID={connections_agentic_client_id} +CONNECTIONS__AGENTIC__SETTINGS__CLIENTSECRET={connections_agentic_client_secret} + +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__AZUREBOTOAUTHCONNECTIONNAME={abs_oauth_connection_name} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__OBOCONNECTIONNAME={obo_connection_name} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__TITLE={auth_handler_title} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__TEXT={auth_handler_text} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{auth_handler_id}__SETTINGS__TYPE=UserAuthorization + +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{agentic_auth_handler_id}__SETTINGS__AZUREBOTOAUTHCONNECTIONNAME={agentic_abs_oauth_connection_name} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{agentic_auth_handler_id}__SETTINGS__OBOCONNECTIONNAME={agentic_obo_connection_name} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{agentic_auth_handler_id}__SETTINGS__TITLE={agentic_auth_handler_title} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{agentic_auth_handler_id}__SETTINGS__TEXT={agentic_auth_handler_text} +AGENTAPPLICATION__USERAUTHORIZATION__HANDLERS__{agentic_auth_handler_id}__SETTINGS__TYPE=AgenticAuthorization + +CONNECTIONSMAP__0__CONNECTION=SERVICE_CONNECTION +CONNECTIONSMAP__0__SERVICEURL=* +CONNECTIONSMAP__1__CONNECTION=AGENTIC +CONNECTIONSMAP__1__SERVICEURL=agentic +""".format( + connections_default_tenant_id=DEFAULTS.connections_default_tenant_id, + connections_default_client_id=DEFAULTS.connections_default_client_id, + connections_default_client_secret=DEFAULTS.connections_default_client_secret, + connections_agentic_tenant_id=DEFAULTS.connections_agentic_tenant_id, + connections_agentic_client_id=DEFAULTS.connections_agentic_client_id, + connections_agentic_client_secret=DEFAULTS.connections_agentic_client_secret, + abs_oauth_connection_name=DEFAULTS.abs_oauth_connection_name, + obo_connection_name=DEFAULTS.obo_connection_name, + auth_handler_id=DEFAULTS.auth_handler_id, + auth_handler_title=DEFAULTS.auth_handler_title, + auth_handler_text=DEFAULTS.auth_handler_text, + agentic_abs_oauth_connection_name=DEFAULTS.agentic_abs_oauth_connection_name, + agentic_obo_connection_name=DEFAULTS.agentic_obo_connection_name, + agentic_auth_handler_id=DEFAULTS.agentic_auth_handler_id, + agentic_auth_handler_title=DEFAULTS.agentic_auth_handler_title, + agentic_auth_handler_text=DEFAULTS.agentic_auth_handler_text, +) + + +def test_load_configuration_from_env(): + input_dict = create_env_var_dict(ENV_RAW) + config = load_configuration_from_env(input_dict) + assert config == ENV_DICT diff --git a/tests/authentication_msal/_data.py b/tests/authentication_msal/_data.py new file mode 100644 index 00000000..92e5ae35 --- /dev/null +++ b/tests/authentication_msal/_data.py @@ -0,0 +1,75 @@ +ENV_CONFIG = { + "CONNECTIONS": { + "SERVICE_CONNECTION": { + "SETTINGS": { + "TENANTID": "test-tenant-id-SERVICE_CONNECTION", + "CLIENTID": "test-client-id-SERVICE_CONNECTION", + "CLIENTSECRET": "test-client-secret-SERVICE_CONNECTION", + } + }, + "AGENTIC": { + "SETTINGS": { + "TENANTID": "test-tenant-id-AGENTIC", + "CLIENTID": "test-client-id-AGENTIC", + "CLIENTSECRET": "test-client-secret-AGENTIC", + } + }, + "MISC": { + "SETTINGS": { + "TENANTID": "test-tenant-id-MISC", + "CLIENTID": "test-client-id-MISC", + "CLIENTSECRET": "test-client-secret-MISC", + } + }, + }, + "AGENTAPPLICATION": { + "USERAUTHORIZATION": { + "HANDLERS": { + "graph": { + "SETTINGS": { + "AZUREBOTOAUTHCONNECTIONNAME": "graph", + "OBOCONNECTIONNAME": "MISC", + "SCOPES": ["User.Read"], + "TITLE": "Sign in with Microsoft", + "TEXT": "Sign in with your Microsoft account", + "TYPE": "UserAuthorization", + } + }, + "github": { + "SETTINGS": { + "AZUREBOTOAUTHCONNECTIONNAME": "github", + "OBOCONNECTIONNAME": "SERVICE_CONNECTION", + "TYPE": "UserAuthorization", + } + }, + "agentic": { + "SETTINGS": { + "AZUREBOTOAUTHCONNECTIONNAME": "AGENTIC", + "OBOCONNECTIONNAME": "MISC", + "SCOPES": ["https://graph.microsoft.com/.default"], + "TITLE": "Sign in with Agentic", + "TEXT": "Sign in with your Agentic account", + "TYPE": "AgenticUserAuthorization", + } + }, + } + } + }, + "CONNECTIONSMAP": [ + { + "CONNECTION": "AGENTIC", + "SERVICEURL": "agentic", + }, + {"CONNECTION": "MISC", "AUDIENCE": "api://misc", "SERVICEURL": "*"}, + { + "CONNECTION": "MISC", + "AUDIENCE": "api://misc_other", + }, + { + "CONNECTION": "SERVICE_CONNECTION", + "AUDIENCE": "api://service", + "SERVICEURL": "https://service*", + }, + {"CONNECTION": "MISC", "SERVICEURL": "https://microsoft.com/*"}, + ], +} diff --git a/tests/authentication_msal/test_msal_auth.py b/tests/authentication_msal/test_msal_auth.py index 21576a81..7198d190 100644 --- a/tests/authentication_msal/test_msal_auth.py +++ b/tests/authentication_msal/test_msal_auth.py @@ -1,6 +1,7 @@ import pytest from msal import ManagedIdentityClient, ConfidentialClientApplication +from microsoft_agents.authentication.msal import MsalAuth from microsoft_agents.hosting.core import Connections from tests._common.testing_objects import MockMsalAuth @@ -36,11 +37,11 @@ async def test_get_access_token_confidential(self, mocker): ) @pytest.mark.asyncio - async def test_aquire_token_on_behalf_of_managed_identity(self, mocker): + async def test_acquire_token_on_behalf_of_managed_identity(self, mocker): mock_auth = MockMsalAuth(mocker, ManagedIdentityClient) try: - await mock_auth.aquire_token_on_behalf_of( + await mock_auth.acquire_token_on_behalf_of( scopes=["test-scope"], user_assertion="test-assertion" ) except NotImplementedError: @@ -49,13 +50,13 @@ async def test_aquire_token_on_behalf_of_managed_identity(self, mocker): assert False @pytest.mark.asyncio - async def test_aquire_token_on_behalf_of_confidential(self, mocker): + async def test_acquire_token_on_behalf_of_confidential(self, mocker): mock_auth = MockMsalAuth(mocker, ConfidentialClientApplication) mock_auth._create_client_application = mocker.Mock( return_value=mock_auth.mock_client ) - token = await mock_auth.aquire_token_on_behalf_of( + token = await mock_auth.acquire_token_on_behalf_of( scopes=["test-scope"], user_assertion="test-assertion" ) @@ -63,3 +64,45 @@ async def test_aquire_token_on_behalf_of_confidential(self, mocker): mock_auth.mock_client.acquire_token_on_behalf_of.assert_called_with( scopes=["test-scope"], user_assertion="test-assertion" ) + + +# class TestMsalAuthAgentic: + +# @pytest.mark.asyncio +# async def test_get_agentic_user_token_data_flow(self, mocker): +# agent_app_instance_id = "test-agent-app-id" +# app_token = "app-token" +# instance_token = "instance-token" +# agent_user_token = "agent-token" +# upn = "test-upn" +# scopes = ["user.read"] + +# mocker.patch.object(MsalAuth, "get_agentic_instance_token", return_value=[instance_token, app_token]) + +# mock_auth = MockMsalAuth(mocker, ConfidentialClientApplication) +# mocker.patch.object(ConfidentialClientApplication, "__new__", return_value=mocker.Mock(spec=ConfidentialClientApplication)) + +# result = await mock_auth.get_agentic_user_token(agent_app_instance_id, upn, scopes) +# mock_auth.get_agentic_instance_token.assert_called_once_with(agent_app_instance_id) + +# assert result == agent_user_token + +# @pytest.mark.asyncio +# async def test_get_agentic_user_token_failure(self, mocker): +# agent_app_instance_id = "test-agent-app-id" +# app_token = "app-token" +# instance_token = "instance-token" +# agent_user_token = "agent-token" +# upn = "test-upn" +# scopes = ["user.read"] + +# mocker.patch.object(MsalAuth, "get_agentic_instance_token", return_value=[instance_token, app_token]) + +# mock_auth = MockMsalAuth(mocker, ConfidentialClientApplication, acquire_token_for_client_return=None) +# mocker.patch.object(ConfidentialClientApplication, "__new__", return_value=mocker.Mock(spec=ConfidentialClientApplication)) + +# result = await mock_auth.get_agentic_user_token(agent_app_instance_id, upn, scopes) + +# mock_auth.get_agentic_instance_token.assert_called_once_with(agent_app_instance_id) + +# assert result is None diff --git a/tests/authentication_msal/test_msal_connection_manager.py b/tests/authentication_msal/test_msal_connection_manager.py index 723f291a..56e9d980 100644 --- a/tests/authentication_msal/test_msal_connection_manager.py +++ b/tests/authentication_msal/test_msal_connection_manager.py @@ -1,15 +1,27 @@ +import pytest + +from copy import deepcopy + from os import environ from microsoft_agents.activity import load_configuration_from_env -from microsoft_agents.hosting.core import AuthTypes +from microsoft_agents.hosting.core import AuthTypes, ClaimsIdentity from microsoft_agents.authentication.msal import MsalConnectionManager +from tests._common.create_env_var_dict import create_env_var_dict + +from ._data import ENV_CONFIG + class TestMsalConnectionManager: """ Test suite for the Msal Connection Manager """ - def test_msal_connection_manager(self): + @pytest.fixture + def config(self): + return deepcopy(ENV_CONFIG) + + def test_init_from_config(self): mock_environ = { **environ, "CONNECTIONS__SERVICE_CONNECTION__SETTINGS__TENANTID": "test-tenant-id-SERVICE_CONNECTION", @@ -33,3 +45,92 @@ def test_msal_connection_manager(self): f"https://sts.windows.net/test-tenant-id-{key}/", f"https://login.microsoftonline.com/test-tenant-id-{key}/v2.0", ] + + # TODO -> test other init paths + + @pytest.mark.parametrize( + "claims_identity, service_url", + [ + [None, ""], + [None, None], + [None, "agentic"], + [ClaimsIdentity(claims={}, is_authenticated=False), None], + [ClaimsIdentity(claims={}, is_authenticated=False), ""], + [ClaimsIdentity(claims={}, is_authenticated=False), "https://example.com"], + [ClaimsIdentity(claims={"aud": "api://misc"}, is_authenticated=False), ""], + ], + ) + def test_get_token_provider_errors(self, claims_identity, service_url): + connection_manager = MsalConnectionManager(**ENV_CONFIG) + with pytest.raises(ValueError): + connection_manager.get_token_provider(claims_identity, service_url) + + def test_get_token_provider_no_map(self, config): + del config["CONNECTIONSMAP"] + connection_manager = MsalConnectionManager(**config) + claims_identity = ClaimsIdentity( + claims={"aud": "api://misc"}, is_authenticated=True + ) + token_provider = connection_manager.get_token_provider( + claims_identity, "https://example.com" + ) + assert token_provider == connection_manager.get_default_connection() + + def test_get_token_provider_aud_match(self, config): + connection_manager = MsalConnectionManager(**config) + claims_identity = ClaimsIdentity( + claims={"aud": "api://misc"}, is_authenticated=True + ) + token_provider = connection_manager.get_token_provider( + claims_identity, "https://example.com" + ) + assert token_provider == connection_manager.get_connection("MISC") + + def test_get_token_provider_aud_and_service_url_match(self, config): + connection_manager = MsalConnectionManager(**config) + claims_identity = ClaimsIdentity( + claims={"aud": "api://service"}, is_authenticated=True + ) + token_provider = connection_manager.get_token_provider( + claims_identity, "https://service.com/api" + ) + assert token_provider == connection_manager.get_connection("SERVICE_CONNECTION") + + def test_get_token_provider_service_url_wildcard_star(self, config): + connection_manager = MsalConnectionManager(**config) + claims_identity = ClaimsIdentity( + claims={"aud": "api://misc"}, is_authenticated=False + ) + token_provider = connection_manager.get_token_provider( + claims_identity, "https://service.com/api" + ) + assert token_provider == connection_manager.get_connection("MISC") + + def test_get_token_provider_service_url_wildcard_empty(self, config): + connection_manager = MsalConnectionManager(**config) + claims_identity = ClaimsIdentity( + claims={"aud": "api://misc_other"}, is_authenticated=False + ) + token_provider = connection_manager.get_token_provider( + claims_identity, "https://service.com/api" + ) + assert token_provider == connection_manager.get_connection("MISC") + + @pytest.mark.parametrize( + "service_url, expected_connection", + [ + ["agentic", "AGENTIC"], + ["https://microsoft.com/api", "MISC"], + ["https://microsoft.com/some-url", "MISC"], + ["https://microsoft.com/", "MISC"], + ], + ) + def test_get_token_provider_service_url_match( + self, config, service_url, expected_connection + ): + connection_manager = MsalConnectionManager(**config) + claims_identity = ClaimsIdentity(claims={}, is_authenticated=False) + token_provider = connection_manager.get_token_provider( + claims_identity, service_url + ) + assert token_provider == connection_manager.get_connection(expected_connection) diff --git a/tests/hosting_core/_common/flow_state_eq.py b/tests/hosting_core/_common/flow_state_eq.py index fea6585c..3fbf152b 100644 --- a/tests/hosting_core/_common/flow_state_eq.py +++ b/tests/hosting_core/_common/flow_state_eq.py @@ -1,13 +1,13 @@ from typing import Optional -from microsoft_agents.hosting.core import FlowState +from microsoft_agents.hosting.core._oauth import _FlowState from tests._common import approx_eq # 100 ms tolerance def flow_state_eq( - fs1: Optional[FlowState], fs2: Optional[FlowState], tol: float = 0.1 + fs1: Optional[_FlowState], fs2: Optional[_FlowState], tol: float = 0.1 ) -> bool: if fs1 is None and fs2 is None: diff --git a/tests/_common/testing_objects/http/__init__.py b/tests/hosting_core/_oauth/__init__.py similarity index 100% rename from tests/_common/testing_objects/http/__init__.py rename to tests/hosting_core/_oauth/__init__.py diff --git a/tests/hosting_core/oauth/test_flow_state.py b/tests/hosting_core/_oauth/test_flow_state.py similarity index 73% rename from tests/hosting_core/oauth/test_flow_state.py rename to tests/hosting_core/_oauth/test_flow_state.py index 9e8b7266..a96468dd 100644 --- a/tests/hosting_core/oauth/test_flow_state.py +++ b/tests/hosting_core/_oauth/test_flow_state.py @@ -1,6 +1,6 @@ -from datetime import datetime import pytest -from microsoft_agents.hosting.core.oauth.flow_state import FlowState, FlowStateTag +from datetime import datetime +from microsoft_agents.hosting.core._oauth._flow_state import _FlowState, _FlowStateTag class TestFlowState: @@ -8,40 +8,40 @@ class TestFlowState: "original_flow_state, refresh_to_not_started", [ ( - FlowState( - tag=FlowStateTag.CONTINUE, + _FlowState( + tag=_FlowStateTag.CONTINUE, attempts_remaining=0, expiration=datetime.now().timestamp(), ), True, ), ( - FlowState( - tag=FlowStateTag.BEGIN, + _FlowState( + tag=_FlowStateTag.BEGIN, attempts_remaining=1, expiration=datetime.now().timestamp(), ), True, ), ( - FlowState( - tag=FlowStateTag.COMPLETE, + _FlowState( + tag=_FlowStateTag.COMPLETE, attempts_remaining=0, expiration=datetime.now().timestamp() - 100, ), True, ), ( - FlowState( - tag=FlowStateTag.CONTINUE, + _FlowState( + tag=_FlowStateTag.CONTINUE, attempts_remaining=1, expiration=datetime.now().timestamp() + 1000, ), False, ), ( - FlowState( - tag=FlowStateTag.FAILURE, + _FlowState( + tag=_FlowStateTag.FAILURE, attempts_remaining=-1, expiration=datetime.now().timestamp(), ), @@ -54,47 +54,47 @@ def test_refresh(self, original_flow_state, refresh_to_not_started): new_flow_state.refresh() expected_flow_state = original_flow_state.model_copy() if refresh_to_not_started: - expected_flow_state.tag = FlowStateTag.NOT_STARTED + expected_flow_state.tag = _FlowStateTag.NOT_STARTED assert new_flow_state == expected_flow_state @pytest.mark.parametrize( "flow_state, expected", [ ( - FlowState( - tag=FlowStateTag.CONTINUE, + _FlowState( + tag=_FlowStateTag.CONTINUE, attempts_remaining=0, expiration=datetime.now().timestamp(), ), True, ), ( - FlowState( - tag=FlowStateTag.BEGIN, + _FlowState( + tag=_FlowStateTag.BEGIN, attempts_remaining=1, expiration=datetime.now().timestamp(), ), True, ), ( - FlowState( - tag=FlowStateTag.COMPLETE, + _FlowState( + tag=_FlowStateTag.COMPLETE, attempts_remaining=0, expiration=datetime.now().timestamp() - 100, ), True, ), ( - FlowState( - tag=FlowStateTag.CONTINUE, + _FlowState( + tag=_FlowStateTag.CONTINUE, attempts_remaining=1, expiration=datetime.now().timestamp() + 1000, ), False, ), ( - FlowState( - tag=FlowStateTag.FAILURE, + _FlowState( + tag=_FlowStateTag.FAILURE, attempts_remaining=-1, expiration=datetime.now().timestamp() + 1000, ), @@ -109,40 +109,40 @@ def test_is_expired(self, flow_state, expected): "flow_state, expected", [ ( - FlowState( - tag=FlowStateTag.CONTINUE, + _FlowState( + tag=_FlowStateTag.CONTINUE, attempts_remaining=0, expiration=datetime.now().timestamp(), ), True, ), ( - FlowState( - tag=FlowStateTag.BEGIN, + _FlowState( + tag=_FlowStateTag.BEGIN, attempts_remaining=1, expiration=datetime.now().timestamp(), ), False, ), ( - FlowState( - tag=FlowStateTag.COMPLETE, + _FlowState( + tag=_FlowStateTag.COMPLETE, attempts_remaining=0, expiration=datetime.now().timestamp() - 100, ), True, ), ( - FlowState( - tag=FlowStateTag.CONTINUE, + _FlowState( + tag=_FlowStateTag.CONTINUE, attempts_remaining=1, expiration=datetime.now().timestamp() - 100, ), False, ), ( - FlowState( - tag=FlowStateTag.FAILURE, + _FlowState( + tag=_FlowStateTag.FAILURE, attempts_remaining=-1, expiration=datetime.now().timestamp(), ), @@ -157,72 +157,72 @@ def test_reached_max_attempts(self, flow_state, expected): "flow_state, expected", [ ( - FlowState( - tag=FlowStateTag.CONTINUE, + _FlowState( + tag=_FlowStateTag.CONTINUE, attempts_remaining=0, expiration=datetime.now().timestamp(), ), False, ), ( - FlowState( - tag=FlowStateTag.BEGIN, + _FlowState( + tag=_FlowStateTag.BEGIN, attempts_remaining=1, expiration=datetime.now().timestamp(), ), False, ), ( - FlowState( - tag=FlowStateTag.COMPLETE, + _FlowState( + tag=_FlowStateTag.COMPLETE, attempts_remaining=0, expiration=datetime.now().timestamp() - 100, ), False, ), ( - FlowState( - tag=FlowStateTag.FAILURE, + _FlowState( + tag=_FlowStateTag.FAILURE, attempts_remaining=1, expiration=datetime.now().timestamp() - 100, ), False, ), ( - FlowState( - tag=FlowStateTag.CONTINUE, + _FlowState( + tag=_FlowStateTag.CONTINUE, attempts_remaining=2, expiration=datetime.now().timestamp() + 1000, ), True, ), ( - FlowState( - tag=FlowStateTag.BEGIN, + _FlowState( + tag=_FlowStateTag.BEGIN, attempts_remaining=0, expiration=datetime.now().timestamp() + 1000, ), False, ), ( - FlowState( - tag=FlowStateTag.COMPLETE, + _FlowState( + tag=_FlowStateTag.COMPLETE, attempts_remaining=-1, expiration=datetime.now().timestamp() + 1000, ), False, ), ( - FlowState( - tag=FlowStateTag.FAILURE, + _FlowState( + tag=_FlowStateTag.FAILURE, attempts_remaining=1, expiration=datetime.now().timestamp() + 1000, ), False, ), ( - FlowState( - tag=FlowStateTag.CONTINUE, + _FlowState( + tag=_FlowStateTag.CONTINUE, attempts_remaining=1, expiration=datetime.now().timestamp() + 1000, ), diff --git a/tests/hosting_core/oauth/test_flow_storage_client.py b/tests/hosting_core/_oauth/test_flow_storage_client.py similarity index 79% rename from tests/hosting_core/oauth/test_flow_storage_client.py rename to tests/hosting_core/_oauth/test_flow_storage_client.py index efad76b0..c1710de1 100644 --- a/tests/hosting_core/oauth/test_flow_storage_client.py +++ b/tests/hosting_core/_oauth/test_flow_storage_client.py @@ -1,7 +1,7 @@ import pytest from microsoft_agents.hosting.core.storage import MemoryStorage -from microsoft_agents.hosting.core.oauth import FlowState, FlowStorageClient +from microsoft_agents.hosting.core._oauth import _FlowState, _FlowStorageClient from tests._common.storage.utils import MockStoreItem from tests._common.data import TEST_DEFAULTS @@ -16,7 +16,7 @@ def storage(self): @pytest.fixture def client(self, storage): - return FlowStorageClient(DEFAULTS.channel_id, DEFAULTS.user_id, storage) + return _FlowStorageClient(DEFAULTS.channel_id, DEFAULTS.user_id, storage) @pytest.mark.asyncio @pytest.mark.parametrize( @@ -28,18 +28,18 @@ def client(self, storage): ], ) async def test_init_base_key(self, mocker, channel_id, user_id): - client = FlowStorageClient(channel_id, user_id, mocker.Mock()) + client = _FlowStorageClient(channel_id, user_id, mocker.Mock()) assert client.base_key == f"auth/{channel_id}/{user_id}/" @pytest.mark.asyncio async def test_init_fails_without_user_id(self, storage): with pytest.raises(ValueError): - FlowStorageClient(DEFAULTS.channel_id, "", storage) + _FlowStorageClient(DEFAULTS.channel_id, "", storage) @pytest.mark.asyncio async def test_init_fails_without_channel_id(self, storage): with pytest.raises(ValueError): - FlowStorageClient("", DEFAULTS.user_id, storage) + _FlowStorageClient("", DEFAULTS.user_id, storage) @pytest.mark.parametrize( "auth_handler_id, expected", @@ -56,23 +56,23 @@ def test_key(self, client, auth_handler_id, expected): async def test_read(self, mocker, auth_handler_id): storage = mocker.AsyncMock() key = f"auth/{DEFAULTS.channel_id}/{DEFAULTS.user_id}/{auth_handler_id}" - storage.read.return_value = {key: FlowState()} - client = FlowStorageClient(DEFAULTS.channel_id, DEFAULTS.user_id, storage) + storage.read.return_value = {key: _FlowState()} + client = _FlowStorageClient(DEFAULTS.channel_id, DEFAULTS.user_id, storage) res = await client.read(auth_handler_id) assert res is storage.read.return_value[key] storage.read.assert_called_once_with( - [client.key(auth_handler_id)], target_cls=FlowState + [client.key(auth_handler_id)], target_cls=_FlowState ) @pytest.mark.asyncio async def test_read_missing(self, mocker): storage = mocker.AsyncMock() storage.read.return_value = {} - client = FlowStorageClient("__channel_id", "__user_id", storage) + client = _FlowStorageClient("__channel_id", "__user_id", storage) res = await client.read("non_existent_handler") assert res is None storage.read.assert_called_once_with( - [client.key("non_existent_handler")], target_cls=FlowState + [client.key("non_existent_handler")], target_cls=_FlowState ) @pytest.mark.asyncio @@ -80,8 +80,8 @@ async def test_read_missing(self, mocker): async def test_write(self, mocker, auth_handler_id): storage = mocker.AsyncMock() storage.write.return_value = None - client = FlowStorageClient(DEFAULTS.channel_id, DEFAULTS.user_id, storage) - flow_state = mocker.Mock(spec=FlowState) + client = _FlowStorageClient(DEFAULTS.channel_id, DEFAULTS.user_id, storage) + flow_state = mocker.Mock(spec=_FlowState) flow_state.auth_handler_id = auth_handler_id await client.write(flow_state) storage.write.assert_called_once_with({client.key(auth_handler_id): flow_state}) @@ -91,15 +91,15 @@ async def test_write(self, mocker, auth_handler_id): async def test_delete(self, mocker, auth_handler_id): storage = mocker.AsyncMock() storage.delete.return_value = None - client = FlowStorageClient(DEFAULTS.channel_id, DEFAULTS.user_id, storage) + client = _FlowStorageClient(DEFAULTS.channel_id, DEFAULTS.user_id, storage) await client.delete(auth_handler_id) storage.delete.assert_called_once_with([client.key(auth_handler_id)]) @pytest.mark.asyncio async def test_integration_with_memory_storage(self): - flow_state_alpha = FlowState(auth_handler_id="handler") - flow_state_beta = FlowState(auth_handler_id="auth_handler", user_token="token") + flow_state_alpha = _FlowState(auth_handler_id="handler") + flow_state_beta = _FlowState(auth_handler_id="auth_handler") storage = MemoryStorage( { @@ -130,10 +130,10 @@ async def delete_both(*args, **kwargs): await storage.delete(*args, **kwargs) await baseline.delete(*args, **kwargs) - client = FlowStorageClient(DEFAULTS.channel_id, DEFAULTS.user_id, storage) + client = _FlowStorageClient(DEFAULTS.channel_id, DEFAULTS.user_id, storage) - new_flow_state_alpha = FlowState(auth_handler_id="handler") - flow_state_chi = FlowState(auth_handler_id="chi") + new_flow_state_alpha = _FlowState(auth_handler_id="handler") + flow_state_chi = _FlowState(auth_handler_id="chi") await client.write(new_flow_state_alpha) await client.write(flow_state_chi) @@ -164,14 +164,15 @@ async def delete_both(*args, **kwargs): await read_check( [f"auth/{DEFAULTS.channel_id}/{DEFAULTS.user_id}/handler"], - target_cls=FlowState, + target_cls=_FlowState, ) await read_check( [f"auth/{DEFAULTS.channel_id}/{DEFAULTS.user_id}/auth_handler"], - target_cls=FlowState, + target_cls=_FlowState, ) await read_check( - [f"auth/{DEFAULTS.channel_id}/{DEFAULTS.user_id}/chi"], target_cls=FlowState + [f"auth/{DEFAULTS.channel_id}/{DEFAULTS.user_id}/chi"], + target_cls=_FlowState, ) await read_check(["other_data"], target_cls=MockStoreItem) await read_check(["some_data"], target_cls=MockStoreItem) diff --git a/tests/hosting_core/oauth/test_oauth_flow.py b/tests/hosting_core/_oauth/test_oauth_flow.py similarity index 82% rename from tests/hosting_core/oauth/test_oauth_flow.py rename to tests/hosting_core/_oauth/test_oauth_flow.py index 62b75b53..129540be 100644 --- a/tests/hosting_core/oauth/test_oauth_flow.py +++ b/tests/hosting_core/_oauth/test_oauth_flow.py @@ -9,11 +9,11 @@ TokenExchangeState, ConversationReference, ) -from microsoft_agents.hosting.core.oauth import ( - OAuthFlow, - FlowErrorTag, - FlowStateTag, - FlowResponse, +from microsoft_agents.hosting.core._oauth import ( + _OAuthFlow, + _FlowErrorTag, + _FlowStateTag, + _FlowResponse, ) from tests._common.data import TEST_DEFAULTS, TEST_FLOW_DATA @@ -65,13 +65,13 @@ def activity(self, mocker): @pytest.fixture def flow(self, flow_state, user_token_client): - return OAuthFlow(flow_state, user_token_client) + return _OAuthFlow(flow_state, user_token_client) class TestOAuthFlow(TestUtils): def test_init_no_user_token_client(self, flow_state): with pytest.raises(ValueError): - OAuthFlow(flow_state, None) + _OAuthFlow(flow_state, None) @pytest.mark.parametrize( "missing_value", ["connection", "ms_app_id", "channel_id", "user_id"] @@ -81,13 +81,13 @@ def test_init_errors(self, missing_value, user_token_client): flow_state = started_flow_state flow_state.__setattr__(missing_value, None) with pytest.raises(ValueError): - OAuthFlow(flow_state, user_token_client) + _OAuthFlow(flow_state, user_token_client) flow_state.__setattr__(missing_value, "") with pytest.raises(ValueError): - OAuthFlow(flow_state, user_token_client) + _OAuthFlow(flow_state, user_token_client) def test_init_with_state(self, flow_state, user_token_client): - flow = OAuthFlow(flow_state, user_token_client) + flow = _OAuthFlow(flow_state, user_token_client) assert flow.flow_state == flow_state def test_flow_state_prop_copy(self, flow): @@ -99,10 +99,9 @@ def test_flow_state_prop_copy(self, flow): @pytest.mark.asyncio async def test_get_user_token_success(self, flow_state, user_token_client): # setup - flow = OAuthFlow(flow_state, user_token_client) + flow = _OAuthFlow(flow_state, user_token_client) expected_final_flow_state = flow_state - expected_final_flow_state.user_token = DEFAULTS.token - expected_final_flow_state.tag = FlowStateTag.COMPLETE + expected_final_flow_state.tag = _FlowStateTag.COMPLETE # test token_response = await flow.get_user_token() @@ -125,7 +124,7 @@ async def test_get_user_token_failure(self, mocker, flow_state): user_token_client = self.UserTokenClient( mocker, get_token_return=TokenResponse() ) - flow = OAuthFlow(flow_state, user_token_client) + flow = _OAuthFlow(flow_state, user_token_client) expected_final_flow_state = flow.flow_state # test @@ -144,10 +143,9 @@ async def test_get_user_token_failure(self, mocker, flow_state): @pytest.mark.asyncio async def test_sign_out(self, flow_state, user_token_client): # setup - flow = OAuthFlow(flow_state, user_token_client) + flow = _OAuthFlow(flow_state, user_token_client) expected_flow_state = flow_state - expected_flow_state.user_token = "" - expected_flow_state.tag = FlowStateTag.NOT_STARTED + expected_flow_state.tag = _FlowStateTag.NOT_STARTED # test await flow.sign_out() @@ -166,10 +164,9 @@ async def test_begin_flow_easy_case(self, mocker, flow_state, activity): user_token_client = self.UserTokenClient( mocker, get_token_return=TokenResponse(token=DEFAULTS.token) ) - flow = OAuthFlow(flow_state, user_token_client) + flow = _OAuthFlow(flow_state, user_token_client) expected_flow_state = flow_state - expected_flow_state.user_token = DEFAULTS.token - expected_flow_state.tag = FlowStateTag.COMPLETE + expected_flow_state.tag = _FlowStateTag.COMPLETE # test response = await flow.begin_flow(activity) @@ -181,7 +178,7 @@ async def test_begin_flow_easy_case(self, mocker, flow_state, activity): assert response.flow_state == out_flow_state assert response.sign_in_resource is None # No sign-in resource in this case - assert response.flow_error_tag == FlowErrorTag.NONE + assert response.flow_error_tag == _FlowErrorTag.NONE assert response.token_response assert response.token_response.token == DEFAULTS.token user_token_client.user_token.get_token.assert_called_once_with( @@ -207,10 +204,9 @@ async def test_begin_flow_long_case(self, mocker, flow_state, activity): ) # setup - flow = OAuthFlow(flow_state, user_token_client) + flow = _OAuthFlow(flow_state, user_token_client) expected_flow_state = flow_state - expected_flow_state.user_token = "" - expected_flow_state.tag = FlowStateTag.BEGIN + expected_flow_state.tag = _FlowStateTag.BEGIN expected_flow_state.attempts_remaining = 3 expected_flow_state.continuation_activity = activity @@ -225,9 +221,9 @@ async def test_begin_flow_long_case(self, mocker, flow_state, activity): assert out_flow_state == response.flow_state assert out_flow_state == expected_flow_state - # verify FlowResponse + # verify _FlowResponse assert response.sign_in_resource == dummy_sign_in_resource - assert response.flow_error_tag == FlowErrorTag.NONE + assert response.flow_error_tag == _FlowErrorTag.NONE assert not response.token_response # robrandao: TODO more assertions on sign_in_resource @@ -236,9 +232,9 @@ async def test_continue_flow_not_active( self, inactive_flow_state, user_token_client, activity ): # setup - flow = OAuthFlow(inactive_flow_state, user_token_client) + flow = _OAuthFlow(inactive_flow_state, user_token_client) expected_flow_state = inactive_flow_state - expected_flow_state.tag = FlowStateTag.FAILURE + expected_flow_state.tag = _FlowStateTag.FAILURE # test flow_response = await flow.continue_flow(activity) @@ -253,12 +249,12 @@ async def helper_continue_flow_failure( self, active_flow_state, user_token_client, activity, flow_error_tag ): # setup - flow = OAuthFlow(active_flow_state, user_token_client) + flow = _OAuthFlow(active_flow_state, user_token_client) expected_flow_state = active_flow_state expected_flow_state.tag = ( - FlowStateTag.CONTINUE + _FlowStateTag.CONTINUE if active_flow_state.attempts_remaining > 1 - else FlowStateTag.FAILURE + else _FlowStateTag.FAILURE ) expected_flow_state.attempts_remaining = ( active_flow_state.attempts_remaining - 1 @@ -278,10 +274,9 @@ async def helper_continue_flow_success( self, active_flow_state, user_token_client, activity, expected_token ): # setup - flow = OAuthFlow(active_flow_state, user_token_client) + flow = _OAuthFlow(active_flow_state, user_token_client) expected_flow_state = active_flow_state - expected_flow_state.tag = FlowStateTag.COMPLETE - expected_flow_state.user_token = DEFAULTS.token + expected_flow_state.tag = _FlowStateTag.COMPLETE expected_flow_state.attempts_remaining = active_flow_state.attempts_remaining # test @@ -295,7 +290,7 @@ async def helper_continue_flow_success( assert flow_response.flow_state == out_flow_state assert expected_flow_state == out_flow_state assert flow_response.token_response == TokenResponse(token=expected_token) - assert flow_response.flow_error_tag == FlowErrorTag.NONE + assert flow_response.flow_error_tag == _FlowErrorTag.NONE @pytest.mark.asyncio @pytest.mark.parametrize("magic_code", ["magic", "123", "", "1239453"]) @@ -308,7 +303,7 @@ async def test_continue_flow_active_message_magic_format_error( active_flow_state, user_token_client, activity, - FlowErrorTag.MAGIC_FORMAT, + _FlowErrorTag.MAGIC_FORMAT, ) user_token_client.user_token.get_token.assert_not_called() @@ -325,7 +320,7 @@ async def test_continue_flow_active_message_magic_code_error( active_flow_state, user_token_client, activity, - FlowErrorTag.MAGIC_CODE_INCORRECT, + _FlowErrorTag.MAGIC_CODE_INCORRECT, ) user_token_client.user_token.get_token.assert_called_once_with( user_id=active_flow_state.user_id, @@ -371,7 +366,7 @@ async def test_continue_flow_active_sign_in_verify_state_error( value={"state": "magic_code"}, ) await self.helper_continue_flow_failure( - active_flow_state, user_token_client, activity, FlowErrorTag.OTHER + active_flow_state, user_token_client, activity, _FlowErrorTag.OTHER ) user_token_client.user_token.get_token.assert_called_once_with( user_id=active_flow_state.user_id, @@ -423,7 +418,7 @@ async def test_continue_flow_active_sign_in_token_exchange_error( value=token_exchange_request, ) await self.helper_continue_flow_failure( - active_flow_state, user_token_client, activity, FlowErrorTag.OTHER + active_flow_state, user_token_client, activity, _FlowErrorTag.OTHER ) user_token_client.user_token.exchange_token.assert_called_once_with( user_id=active_flow_state.user_id, @@ -467,7 +462,7 @@ async def test_continue_flow_invalid_invoke_name( activity = self.Activity( mocker, type=ActivityTypes.invoke, name="other", value={} ) - flow = OAuthFlow(active_flow_state, user_token_client) + flow = _OAuthFlow(active_flow_state, user_token_client) await flow.continue_flow(activity) @pytest.mark.asyncio @@ -478,7 +473,7 @@ async def test_continue_flow_invalid_activity_type( activity = self.Activity( mocker, type=ActivityTypes.command, name="other", value={} ) - flow = OAuthFlow(active_flow_state, user_token_client) + flow = _OAuthFlow(active_flow_state, user_token_client) await flow.continue_flow(activity) @pytest.mark.asyncio @@ -489,62 +484,62 @@ async def test_begin_or_continue_flow_not_started_flow( ): # setup not_started_flow_state = FLOW_DATA.not_started.model_copy() - expected_response = FlowResponse( + expected_response = _FlowResponse( flow_state=not_started_flow_state, - token_response=TokenResponse(token=not_started_flow_state.user_token), + token_response=TokenResponse(), ) - mocker.patch.object(OAuthFlow, "begin_flow", return_value=expected_response) + mocker.patch.object(_OAuthFlow, "begin_flow", return_value=expected_response) - flow = OAuthFlow(not_started_flow_state, mocker.Mock()) + flow = _OAuthFlow(not_started_flow_state, mocker.Mock()) # test actual_response = await flow.begin_or_continue_flow(activity) # verify assert actual_response is expected_response - OAuthFlow.begin_flow.assert_called_once_with(activity) + _OAuthFlow.begin_flow.assert_called_once_with(activity) @pytest.mark.asyncio async def test_begin_or_continue_flow_inactive_flow( self, mocker, inactive_flow_state_not_completed, activity ): # mock - expected_response = FlowResponse( + expected_response = _FlowResponse( flow_state=inactive_flow_state_not_completed, token_response=TokenResponse(), ) - mocker.patch.object(OAuthFlow, "begin_flow", return_value=expected_response) + mocker.patch.object(_OAuthFlow, "begin_flow", return_value=expected_response) # setup - flow = OAuthFlow(inactive_flow_state_not_completed, mocker.Mock()) + flow = _OAuthFlow(inactive_flow_state_not_completed, mocker.Mock()) # test actual_response = await flow.begin_or_continue_flow(activity) # verify assert actual_response is expected_response - OAuthFlow.begin_flow.assert_called_once_with(activity) + _OAuthFlow.begin_flow.assert_called_once_with(activity) @pytest.mark.asyncio async def test_begin_or_continue_flow_active_flow( self, mocker, active_flow_state, activity, user_token_client ): # mock - expected_response = FlowResponse( + expected_response = _FlowResponse( flow_state=active_flow_state, - token_response=TokenResponse(token=active_flow_state.user_token), + token_response=TokenResponse(token=DEFAULTS.token), ) - mocker.patch.object(OAuthFlow, "continue_flow", return_value=expected_response) + mocker.patch.object(_OAuthFlow, "continue_flow", return_value=expected_response) # setup - flow = OAuthFlow(active_flow_state, user_token_client) + flow = _OAuthFlow(active_flow_state, user_token_client) # test actual_response = await flow.begin_or_continue_flow(activity) # verify assert actual_response is expected_response - OAuthFlow.continue_flow.assert_called_once_with(activity) + _OAuthFlow.continue_flow.assert_called_once_with(activity) @pytest.mark.asyncio async def test_begin_or_continue_flow_stale_flow_state( @@ -554,37 +549,37 @@ async def test_begin_or_continue_flow_stale_flow_state( ): # mock expired_flow_state = FLOW_DATA.active_exp.model_copy() - expected_response = FlowResponse() - mocker.patch.object(OAuthFlow, "begin_flow", return_value=expected_response) + expected_response = _FlowResponse() + mocker.patch.object(_OAuthFlow, "begin_flow", return_value=expected_response) # setup - flow = OAuthFlow(expired_flow_state, mocker.Mock()) + flow = _OAuthFlow(expired_flow_state, mocker.Mock()) # test actual_response = await flow.begin_or_continue_flow(activity) # verify assert actual_response is expected_response - OAuthFlow.begin_flow.assert_called_once_with(activity) + _OAuthFlow.begin_flow.assert_called_once_with(activity) @pytest.mark.asyncio async def test_begin_or_continue_flow_completed_flow_state(self, mocker, activity): completed_flow_state = FLOW_DATA.completed.model_copy() # mock - mocker.patch.object(OAuthFlow, "begin_flow", return_value=None) - mocker.patch.object(OAuthFlow, "continue_flow", return_value=None) - - # setup - expected_response = FlowResponse( + expected_response = _FlowResponse( flow_state=completed_flow_state, - token_response=TokenResponse(token=completed_flow_state.user_token), + token_response=TokenResponse(token="some-token"), ) - flow = OAuthFlow(completed_flow_state, mocker.Mock()) + mocker.patch.object(_OAuthFlow, "begin_flow", return_value=expected_response) + mocker.patch.object(_OAuthFlow, "continue_flow", return_value=None) + + # setup + flow = _OAuthFlow(completed_flow_state, mocker.Mock()) # test actual_response = await flow.begin_or_continue_flow(activity) # verify assert actual_response == expected_response - OAuthFlow.begin_flow.assert_not_called() - OAuthFlow.continue_flow.assert_not_called() + _OAuthFlow.begin_flow.assert_called_once() + _OAuthFlow.continue_flow.assert_not_called() diff --git a/tests/hosting_core/oauth/__init__.py b/tests/hosting_core/app/oauth/__init__.py similarity index 100% rename from tests/hosting_core/oauth/__init__.py rename to tests/hosting_core/app/oauth/__init__.py diff --git a/tests/hosting_core/app/oauth/_common.py b/tests/hosting_core/app/oauth/_common.py new file mode 100644 index 00000000..81247cb8 --- /dev/null +++ b/tests/hosting_core/app/oauth/_common.py @@ -0,0 +1,76 @@ +from microsoft_agents.activity import Activity, ActivityTypes + +from microsoft_agents.hosting.core import TurnContext + +from tests._common.data import TEST_DEFAULTS +from tests._common.testing_objects import mock_UserTokenClient + +DEFAULTS = TEST_DEFAULTS() + + +def testing_Activity(): + return Activity( + type=ActivityTypes.message, + channel_id=DEFAULTS.channel_id, + from_property={"id": DEFAULTS.user_id}, + text="Hello, World!", + ) + + +def testing_TurnContext( + mocker, + channel_id=DEFAULTS.channel_id, + user_id=DEFAULTS.user_id, + user_token_client=None, + activity=None, +): + if not user_token_client: + user_token_client = mock_UserTokenClient(mocker) + + turn_context = mocker.Mock() + if not activity: + turn_context.activity.channel_id = channel_id + turn_context.activity.from_property.id = user_id + turn_context.activity.type = ActivityTypes.message + else: + turn_context.activity = activity + turn_context.adapter.USER_TOKEN_CLIENT_KEY = "__user_token_client" + turn_context.adapter.AGENT_IDENTITY_KEY = "__agent_identity_key" + agent_identity = mocker.Mock() + agent_identity.claims = {"aud": DEFAULTS.ms_app_id} + turn_context.turn_state = { + "__user_token_client": user_token_client, + "__agent_identity_key": agent_identity, + } + return turn_context + + +def testing_TurnContext_magic( + mocker, + channel_id=DEFAULTS.channel_id, + user_id=DEFAULTS.user_id, + user_token_client=None, + activity=None, +): + if not user_token_client: + user_token_client = mock_UserTokenClient(mocker) + + turn_context = mocker.MagicMock(spec=TurnContext) + turn_context.adapter = mocker.Mock() + if not activity: + turn_context.activity = mocker.Mock() + turn_context.activity.channel_id = channel_id + turn_context.activity.from_property.id = user_id + turn_context.activity.type = ActivityTypes.message + else: + turn_context.activity = activity + turn_context.adapter.USER_TOKEN_CLIENT_KEY = "__user_token_client" + turn_context.adapter.AGENT_IDENTITY_KEY = "__agent_identity_key" + agent_identity = mocker.Mock() + agent_identity.claims = {"aud": DEFAULTS.ms_app_id} + turn_context.turn_state = mocker.Mock() + turn_context.turn_state = { + "__user_token_client": user_token_client, + "__agent_identity_key": agent_identity, + } + return turn_context diff --git a/tests/hosting_core/app/oauth/_env.py b/tests/hosting_core/app/oauth/_env.py new file mode 100644 index 00000000..160373d3 --- /dev/null +++ b/tests/hosting_core/app/oauth/_env.py @@ -0,0 +1,17 @@ +from tests._common.data import TEST_DEFAULTS + +DEFAULTS = TEST_DEFAULTS() + + +def ENV_CONFIG(): + return { + "AGENTAPPLICATION": { + "USERAUTHORIZATION": { + "HANDLERS": { + DEFAULTS.connection_name: { + "SETTINGS": {AZUREBOTOAUTHCONNECTIONNAME} + } + } + } + } + } diff --git a/tests/_common/testing_objects/http/mock_abs_api.py b/tests/hosting_core/app/oauth/_handlers/__init__.py similarity index 100% rename from tests/_common/testing_objects/http/mock_abs_api.py rename to tests/hosting_core/app/oauth/_handlers/__init__.py diff --git a/tests/hosting_core/app/oauth/_handlers/test_agentic_user_authorization.py b/tests/hosting_core/app/oauth/_handlers/test_agentic_user_authorization.py new file mode 100644 index 00000000..87f8ba27 --- /dev/null +++ b/tests/hosting_core/app/oauth/_handlers/test_agentic_user_authorization.py @@ -0,0 +1,373 @@ +from math import exp +import pytest + +from microsoft_agents.activity import ( + Activity, + ChannelAccount, + RoleTypes, + TokenResponse, +) + +from microsoft_agents.authentication.msal import MsalAuth, MsalConnectionManager + +from microsoft_agents.hosting.core.app.oauth import AgenticUserAuthorization +from microsoft_agents.hosting.core.storage import MemoryStorage +from microsoft_agents.hosting.core._oauth import _FlowStateTag + +from tests._common.data import TEST_DEFAULTS, TEST_AGENTIC_ENV_DICT +from tests._common.mock_utils import mock_class + +from .._common import ( + testing_TurnContext_magic, +) + +DEFAULTS = TEST_DEFAULTS() +AGENTIC_ENV_DICT = TEST_AGENTIC_ENV_DICT() + + +class TestUtils: + def setup_method(self, mocker): + self.TurnContext = testing_TurnContext_magic + + @pytest.fixture + def storage(self): + return MemoryStorage() + + @pytest.fixture + def connection_manager(self, mocker): + return MsalConnectionManager(**AGENTIC_ENV_DICT) + + @pytest.fixture + def auth_handler_settings(self): + return AGENTIC_ENV_DICT["AGENTAPPLICATION"]["USERAUTHORIZATION"]["HANDLERS"][ + DEFAULTS.agentic_auth_handler_id + ]["SETTINGS"] + + @pytest.fixture + def agentic_auth(self, storage, connection_manager, auth_handler_settings): + return AgenticUserAuthorization( + storage, + connection_manager, + auth_handler_settings=auth_handler_settings, + auth_handler_id=DEFAULTS.agentic_auth_handler_id, + ) + + @pytest.fixture(params=[RoleTypes.user, RoleTypes.skill, RoleTypes.agent]) + def non_agentic_role(self, request): + return request.param + + @pytest.fixture(params=[RoleTypes.agentic_user, RoleTypes.agentic_identity]) + def agentic_role(self, request): + return request.param + + def mock_provider( + self, mocker, app_token="bot_token", instance_token=None, user_token=None + ): + mock_provider = mocker.Mock(spec=MsalAuth) + mock_provider.get_agentic_instance_token = mocker.AsyncMock( + return_value=[instance_token, app_token] + ) + mock_provider.get_agentic_user_token = mocker.AsyncMock(return_value=user_token) + return mock_provider + + def mock_class_provider( + self, mocker, app_token="bot_token", instance_token=None, user_token=None + ): + instance = self.mock_provider(mocker, app_token, instance_token, user_token) + mock_class(mocker, MsalAuth, instance) + + +class TestAgenticUserAuthorization(TestUtils): + + @pytest.mark.asyncio + async def test_get_agentic_instance_token_not_agentic( + self, mocker, non_agentic_role, agentic_auth + ): + activity = Activity( + type="message", + recipient=ChannelAccount( + id=DEFAULTS.agentic_user_id, + agentic_app_id=DEFAULTS.agentic_instance_id, + role=non_agentic_role, + ), + ) + context = self.TurnContext(mocker, activity=activity) + assert await agentic_auth.get_agentic_instance_token(context) == TokenResponse() + + @pytest.mark.asyncio + async def test_get_agentic_user_token_not_agentic( + self, mocker, non_agentic_role, agentic_auth + ): + activity = Activity( + type="message", + recipient=ChannelAccount( + id=DEFAULTS.agentic_user_id, + agentic_app_id=DEFAULTS.agentic_instance_id, + role=non_agentic_role, + ), + ) + context = self.TurnContext(mocker, activity=activity) + assert ( + await agentic_auth.get_agentic_user_token(context, ["user.Read"]) + == TokenResponse() + ) + + @pytest.mark.asyncio + async def test_get_agentic_user_token_agentic_no_user_id( + self, mocker, agentic_role, agentic_auth + ): + activity = Activity( + type="message", + recipient=ChannelAccount( + agentic_app_id=DEFAULTS.agentic_instance_id, role=agentic_role + ), + ) + context = self.TurnContext(mocker, activity=activity) + assert ( + await agentic_auth.get_agentic_user_token(context, ["user.Read"]) + == TokenResponse() + ) + + @pytest.mark.asyncio + async def test_get_agentic_instance_token_is_agentic( + self, mocker, agentic_role, agentic_auth, auth_handler_settings + ): + mock_provider = self.mock_provider(mocker, instance_token=DEFAULTS.token) + connection_manager = mocker.Mock(spec=MsalConnectionManager) + connection_manager.get_token_provider = mocker.Mock(return_value=mock_provider) + + agentic_auth = AgenticUserAuthorization( + MemoryStorage(), + connection_manager, + auth_handler_settings=auth_handler_settings, + auth_handler_id=DEFAULTS.agentic_auth_handler_id, + ) + + activity = Activity( + type="message", + recipient=ChannelAccount( + id="some_id", + agentic_app_id=DEFAULTS.agentic_instance_id, + role=agentic_role, + ), + ) + context = self.TurnContext(mocker, activity=activity) + + token = await agentic_auth.get_agentic_instance_token(context) + assert token == TokenResponse(token=DEFAULTS.token) + mock_provider.get_agentic_instance_token.assert_called_once_with( + DEFAULTS.agentic_instance_id + ) + + @pytest.mark.asyncio + async def test_get_agentic_user_token_is_agentic( + self, mocker, agentic_role, agentic_auth, auth_handler_settings + ): + mock_provider = self.mock_provider(mocker, user_token=DEFAULTS.token) + + connection_manager = mocker.Mock(spec=MsalConnectionManager) + connection_manager.get_token_provider = mocker.Mock(return_value=mock_provider) + + agentic_auth = AgenticUserAuthorization( + MemoryStorage(), + connection_manager, + auth_handler_settings=auth_handler_settings, + auth_handler_id=DEFAULTS.agentic_auth_handler_id, + ) + + activity = Activity( + type="message", + recipient=ChannelAccount( + id="some_id", + agentic_app_id=DEFAULTS.agentic_instance_id, + role=agentic_role, + ), + ) + context = self.TurnContext(mocker, activity=activity) + + token = await agentic_auth.get_agentic_user_token(context, ["user.Read"]) + assert token == TokenResponse(token=DEFAULTS.token) + mock_provider.get_agentic_user_token.assert_called_once_with( + DEFAULTS.agentic_instance_id, "some_id", ["user.Read"] + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "scopes_list, expected_scopes_list", + [ + (["user.Read"], ["user.Read"]), + ([], ["user.Read", "Mail.Read"]), + (None, ["user.Read", "Mail.Read"]), + ], + ) + async def test_sign_in_success( + self, + mocker, + scopes_list, + agentic_role, + expected_scopes_list, + auth_handler_settings, + ): + mock_provider = self.mock_provider(mocker, user_token="my_token") + + connection_manager = mocker.Mock(spec=MsalConnectionManager) + connection_manager.get_token_provider = mocker.Mock(return_value=mock_provider) + + agentic_auth = AgenticUserAuthorization( + MemoryStorage(), + connection_manager, + auth_handler_settings=auth_handler_settings, + auth_handler_id=DEFAULTS.agentic_auth_handler_id, + ) + activity = Activity( + type="message", + recipient=ChannelAccount( + id="some_id", + agentic_app_id=DEFAULTS.agentic_instance_id, + role=agentic_role, + ), + ) + context = self.TurnContext(mocker, activity=activity) + res = await agentic_auth._sign_in(context, "my_connection", scopes_list) + assert res.token_response.token == "my_token" + assert res.tag == _FlowStateTag.COMPLETE + + mock_provider.get_agentic_user_token.assert_called_once_with( + DEFAULTS.agentic_instance_id, "some_id", expected_scopes_list + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "scopes_list, expected_scopes_list", + [ + (["user.Read"], ["user.Read"]), + ([], ["user.Read", "Mail.Read"]), + (None, ["user.Read", "Mail.Read"]), + ], + ) + async def test_sign_in_failure( + self, + mocker, + scopes_list, + agentic_role, + expected_scopes_list, + auth_handler_settings, + ): + mock_provider = self.mock_provider(mocker, user_token=None) + + connection_manager = mocker.Mock(spec=MsalConnectionManager) + connection_manager.get_token_provider = mocker.Mock(return_value=mock_provider) + + agentic_auth = AgenticUserAuthorization( + MemoryStorage(), + connection_manager, + auth_handler_settings=auth_handler_settings, + auth_handler_id=DEFAULTS.agentic_auth_handler_id, + ) + activity = Activity( + type="message", + recipient=ChannelAccount( + id="some_id", + agentic_app_id=DEFAULTS.agentic_instance_id, + role=agentic_role, + ), + ) + context = self.TurnContext(mocker, activity=activity) + res = await agentic_auth._sign_in(context, "my_connection", scopes_list) + assert not res.token_response + assert res.tag == _FlowStateTag.FAILURE + + mock_provider.get_agentic_user_token.assert_called_once_with( + DEFAULTS.agentic_instance_id, "some_id", expected_scopes_list + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "scopes_list, expected_scopes_list", + [ + (["user.Read"], ["user.Read"]), + ([], ["user.Read", "Mail.Read"]), + (None, ["user.Read", "Mail.Read"]), + ], + ) + async def test_get_refreshed_token_success( + self, + mocker, + scopes_list, + agentic_role, + expected_scopes_list, + auth_handler_settings, + ): + mock_provider = self.mock_provider(mocker, user_token="my_token") + + connection_manager = mocker.Mock(spec=MsalConnectionManager) + connection_manager.get_token_provider = mocker.Mock(return_value=mock_provider) + + agentic_auth = AgenticUserAuthorization( + MemoryStorage(), + connection_manager, + auth_handler_settings=auth_handler_settings, + auth_handler_id=DEFAULTS.agentic_auth_handler_id, + ) + activity = Activity( + type="message", + recipient=ChannelAccount( + id="some_id", + agentic_app_id=DEFAULTS.agentic_instance_id, + role=agentic_role, + ), + ) + context = self.TurnContext(mocker, activity=activity) + res = await agentic_auth.get_refreshed_token( + context, "my_connection", scopes_list + ) + assert res == TokenResponse(token="my_token") + + mock_provider.get_agentic_user_token.assert_called_once_with( + DEFAULTS.agentic_instance_id, "some_id", expected_scopes_list + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "scopes_list, expected_scopes_list", + [ + (["user.Read"], ["user.Read"]), + ([], ["user.Read", "Mail.Read"]), + (None, ["user.Read", "Mail.Read"]), + ], + ) + async def test_get_refreshed_token_failure( + self, + mocker, + scopes_list, + agentic_role, + expected_scopes_list, + auth_handler_settings, + ): + mock_provider = self.mock_provider(mocker, user_token=None) + + connection_manager = mocker.Mock(spec=MsalConnectionManager) + connection_manager.get_token_provider = mocker.Mock(return_value=mock_provider) + + agentic_auth = AgenticUserAuthorization( + MemoryStorage(), + connection_manager, + auth_handler_settings=auth_handler_settings, + auth_handler_id=DEFAULTS.agentic_auth_handler_id, + ) + activity = Activity( + type="message", + recipient=ChannelAccount( + id="some_id", + agentic_app_id=DEFAULTS.agentic_instance_id, + role=agentic_role, + ), + ) + context = self.TurnContext(mocker, activity=activity) + res = await agentic_auth.get_refreshed_token( + context, "my_connection", scopes_list + ) + assert res == TokenResponse() + mock_provider.get_agentic_user_token.assert_called_once_with( + DEFAULTS.agentic_instance_id, "some_id", expected_scopes_list + ) diff --git a/tests/hosting_core/app/oauth/_handlers/test_user_authorization.py b/tests/hosting_core/app/oauth/_handlers/test_user_authorization.py new file mode 100644 index 00000000..f2764125 --- /dev/null +++ b/tests/hosting_core/app/oauth/_handlers/test_user_authorization.py @@ -0,0 +1,367 @@ +import pytest +import jwt + +from microsoft_agents.activity import ActivityTypes, TokenResponse + +from microsoft_agents.authentication.msal import MsalAuth, MsalConnectionManager + +from microsoft_agents.hosting.core import MemoryStorage +from microsoft_agents.hosting.core.app.oauth import _UserAuthorization, _SignInResponse +from microsoft_agents.hosting.core._oauth import ( + _FlowStorageClient, + _FlowStateTag, + _FlowState, + _FlowResponse, + _OAuthFlow, +) + +# test constants +from tests._common.data import ( + TEST_FLOW_DATA, + TEST_AUTH_DATA, + TEST_STORAGE_DATA, + TEST_DEFAULTS, + TEST_AGENTIC_ENV_DICT, +) +from tests._common.mock_utils import mock_instance +from tests._common.fixtures import FlowStateFixtures +from tests._common.testing_objects import ( + mock_class_OAuthFlow, + mock_UserTokenClient, +) +from tests.hosting_core._common import flow_state_eq + +DEFAULTS = TEST_DEFAULTS() +FLOW_DATA = TEST_FLOW_DATA() +STORAGE_DATA = TEST_STORAGE_DATA() +AGENTIC_ENV_DICT = TEST_AGENTIC_ENV_DICT() + + +def make_jwt(token: str = DEFAULTS.token, aud="api://default"): + if aud: + return jwt.encode({"aud": aud}, token, algorithm="HS256") + else: + return jwt.encode({}, token, algorithm="HS256") + + +class MyUserAuthorization(_UserAuthorization): + async def _handle_flow_response(self, *args, **kwargs): + pass + + +def testing_TurnContext( + mocker, + channel_id=DEFAULTS.channel_id, + user_id=DEFAULTS.user_id, + user_token_client=None, +): + if not user_token_client: + user_token_client = mock_UserTokenClient(mocker) + + turn_context = mocker.Mock() + turn_context.activity.channel_id = channel_id + turn_context.activity.from_property.id = user_id + turn_context.activity.type = ActivityTypes.message + turn_context.adapter.USER_TOKEN_CLIENT_KEY = "__user_token_client" + turn_context.adapter.AGENT_IDENTITY_KEY = "__agent_identity_key" + agent_identity = mocker.Mock() + agent_identity.claims = {"aud": DEFAULTS.ms_app_id} + turn_context.turn_state = { + "__user_token_client": user_token_client, + "__agent_identity_key": agent_identity, + } + return turn_context + + +async def read_state( + storage, + channel_id=DEFAULTS.channel_id, + user_id=DEFAULTS.user_id, + auth_handler_id=DEFAULTS.auth_handler_id, +): + storage_client = _FlowStorageClient(channel_id, user_id, storage) + key = storage_client.key(auth_handler_id) + return (await storage.read([key], target_cls=_FlowState)).get(key) + + +def mock_provider(mocker, exchange_token=None): + instance = mock_instance( + mocker, MsalAuth, {"acquire_token_on_behalf_of": exchange_token} + ) + mocker.patch.object(MsalConnectionManager, "get_connection", return_value=instance) + return instance + + +class TestEnv(FlowStateFixtures): + def setup_method(self): + self.TurnContext = testing_TurnContext + + @pytest.fixture + def context(self, mocker): + return self.TurnContext(mocker) + + @pytest.fixture + def storage(self): + return MemoryStorage(STORAGE_DATA.get_init_data()) + + @pytest.fixture + def connection_manager(self): + return MsalConnectionManager(**AGENTIC_ENV_DICT) + + @pytest.fixture + def auth_handlers(self): + return TEST_AUTH_DATA().auth_handlers + + @pytest.fixture + def auth_handler_settings(self): + return AGENTIC_ENV_DICT["AGENTAPPLICATION"]["USERAUTHORIZATION"]["HANDLERS"][ + DEFAULTS.auth_handler_id + ]["SETTINGS"] + + @pytest.fixture + def user_authorization(self, connection_manager, storage, auth_handler_settings): + return MyUserAuthorization( + storage, + connection_manager, + auth_handler_settings=auth_handler_settings, + auth_handler_id=DEFAULTS.auth_handler_id, + ) + + @pytest.fixture + def exchangeable_token(self): + jwt.encode({"aud": "exchange_audience"}, "secret", algorithm="HS256") + + @pytest.fixture( + params=[ + [None, ["scope1", "scope2"]], + [[], ["scope1", "scope2"]], + [["scope1"], ["scope1"]], + ] + ) + def scope_set(self, request): + return request.param + + @pytest.fixture( + params=[ + ["AGENTIC", "AGENTIC"], + [None, DEFAULTS.obo_connection_name], + ["", DEFAULTS.obo_connection_name], + ] + ) + def connection_set(self, request): + return request.param + + +class TestUserAuthorization(TestEnv): + + # TODO -> test init + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "flow_response, exchange_attempted, token_exchange_response, expected_response", + [ + [ + _FlowResponse( + token_response=TokenResponse(token=make_jwt()), + flow_state=_FlowState( + tag=_FlowStateTag.COMPLETE, + auth_handler_id=DEFAULTS.auth_handler_id, + ), + ), + True, + "wow", + _SignInResponse( + token_response=TokenResponse(token="wow"), + tag=_FlowStateTag.COMPLETE, + ), + ], + [ + _FlowResponse( + token_response=TokenResponse(token=make_jwt(aud=None)), + flow_state=_FlowState( + tag=_FlowStateTag.COMPLETE, + auth_handler_id=DEFAULTS.auth_handler_id, + ), + ), + False, + "wow", + _SignInResponse( + token_response=TokenResponse(token=make_jwt(aud=None)), + tag=_FlowStateTag.COMPLETE, + ), + ], + [ + _FlowResponse( + token_response=TokenResponse( + token=make_jwt(token="some_value", aud="other") + ), + flow_state=_FlowState( + tag=_FlowStateTag.COMPLETE, + auth_handler_id=DEFAULTS.auth_handler_id, + ), + ), + False, + DEFAULTS.token, + _SignInResponse( + token_response=TokenResponse( + token=make_jwt("some_value", aud="other") + ), + tag=_FlowStateTag.COMPLETE, + ), + ], + [ + _FlowResponse( + token_response=TokenResponse(token=make_jwt(token="some_value")), + flow_state=_FlowState( + tag=_FlowStateTag.COMPLETE, + auth_handler_id=DEFAULTS.auth_handler_id, + ), + ), + True, + None, + _SignInResponse(tag=_FlowStateTag.FAILURE), + ], + [ + _FlowResponse( + flow_state=_FlowState( + tag=_FlowStateTag.BEGIN, + auth_handler_id=DEFAULTS.auth_handler_id, + ), + ), + False, + None, + _SignInResponse(tag=_FlowStateTag.BEGIN), + ], + [ + _FlowResponse( + flow_state=_FlowState( + tag=_FlowStateTag.CONTINUE, + auth_handler_id=DEFAULTS.auth_handler_id, + ), + ), + False, + None, + _SignInResponse(tag=_FlowStateTag.CONTINUE), + ], + [ + _FlowResponse( + flow_state=_FlowState( + tag=_FlowStateTag.FAILURE, + auth_handler_id=DEFAULTS.auth_handler_id, + ), + ), + False, + None, + _SignInResponse(tag=_FlowStateTag.FAILURE), + ], + ], + ) + async def test_sign_in( + self, + mocker, + user_authorization, + context, + storage, + flow_response, + exchange_attempted, + token_exchange_response, + expected_response, + scope_set, + connection_set, + ): + request_scopes, expected_scopes = scope_set + request_connection, expected_connection = connection_set + mock_class_OAuthFlow(mocker, begin_or_continue_flow_return=flow_response) + provider = mock_provider(mocker, exchange_token=token_exchange_response) + + sign_in_response = await user_authorization._sign_in( + context, request_connection, request_scopes + ) + assert sign_in_response.token_response == expected_response.token_response + assert sign_in_response.tag == expected_response.tag + + state = await read_state(storage, auth_handler_id=DEFAULTS.auth_handler_id) + assert flow_state_eq(state, flow_response.flow_state) + if exchange_attempted: + MsalConnectionManager.get_connection.assert_called_once_with( + expected_connection + ) + provider.acquire_token_on_behalf_of.assert_called_once_with( + scopes=expected_scopes, + user_assertion=flow_response.token_response.token, + ) + + @pytest.mark.asyncio + async def test_sign_out_individual( + self, mocker, storage, user_authorization, context + ): + mock_class_OAuthFlow(mocker) + await user_authorization._sign_out(context) + assert ( + await read_state(storage, auth_handler_id=DEFAULTS.auth_handler_id) is None + ) + _OAuthFlow.sign_out.assert_called_once() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "get_user_token_return, exchange_attempted, token_exchange_response, expected_response", + [ + [TokenResponse(token=make_jwt()), True, "wow", TokenResponse(token="wow")], + [ + TokenResponse(token=make_jwt(aud=None)), + False, + "wow", + TokenResponse(token=make_jwt(aud=None)), + ], + [ + TokenResponse(token=make_jwt(token="some_value", aud="other")), + False, + DEFAULTS.token, + TokenResponse(token=make_jwt("some_value", aud="other")), + ], + [ + TokenResponse(token=make_jwt(token="some_value")), + True, + None, + TokenResponse(), + ], + [TokenResponse(), False, None, TokenResponse()], + ], + ) + async def test_get_refreshed_token( + self, + mocker, + user_authorization, + context, + storage, + get_user_token_return, + exchange_attempted, + token_exchange_response, + expected_response, + scope_set, + connection_set, + ): + request_scopes, expected_scopes = scope_set + request_connection, expected_connection = connection_set + mock_class_OAuthFlow(mocker, get_user_token_return=get_user_token_return) + provider = mock_provider(mocker, exchange_token=token_exchange_response) + + state_before = await read_state( + storage, auth_handler_id=DEFAULTS.auth_handler_id + ) + token_response = await user_authorization.get_refreshed_token( + context, request_connection, request_scopes + ) + assert token_response == expected_response + + state = await read_state(storage, auth_handler_id=DEFAULTS.auth_handler_id) + + if state: + assert flow_state_eq(state, state_before) + if exchange_attempted: + MsalConnectionManager.get_connection.assert_called_once_with( + expected_connection + ) + provider.acquire_token_on_behalf_of.assert_called_once_with( + scopes=expected_scopes, user_assertion=get_user_token_return.token + ) diff --git a/tests/hosting_core/app/oauth/test_auth_handler.py b/tests/hosting_core/app/oauth/test_auth_handler.py new file mode 100644 index 00000000..ccaf15ec --- /dev/null +++ b/tests/hosting_core/app/oauth/test_auth_handler.py @@ -0,0 +1,47 @@ +import pytest + +from microsoft_agents.hosting.core import AuthHandler + +from tests._common.data import TEST_DEFAULTS, TEST_ENV_DICT, TEST_AGENTIC_ENV_DICT + +DEFAULTS = TEST_DEFAULTS() +ENV_DICT = TEST_ENV_DICT() +AGENTIC_ENV_DICT = TEST_AGENTIC_ENV_DICT() + + +class TestAuthHandler: + @pytest.fixture + def auth_setting(self): + return ENV_DICT["AGENTAPPLICATION"]["USERAUTHORIZATION"]["HANDLERS"][ + DEFAULTS.auth_handler_id + ]["SETTINGS"] + + @pytest.fixture + def agentic_auth_setting(self): + return AGENTIC_ENV_DICT["AGENTAPPLICATION"]["USERAUTHORIZATION"]["HANDLERS"][ + DEFAULTS.agentic_auth_handler_id + ]["SETTINGS"] + + def test_init(self, auth_setting): + auth_handler = AuthHandler(DEFAULTS.auth_handler_id, **auth_setting) + assert auth_handler.name == DEFAULTS.auth_handler_id + assert auth_handler.title == DEFAULTS.auth_handler_title + assert auth_handler.text == DEFAULTS.auth_handler_text + assert auth_handler.obo_connection_name == DEFAULTS.obo_connection_name + assert ( + auth_handler.abs_oauth_connection_name == DEFAULTS.abs_oauth_connection_name + ) + + def test_init_agentic(self, agentic_auth_setting): + auth_handler = AuthHandler( + DEFAULTS.agentic_auth_handler_id, **agentic_auth_setting + ) + assert auth_handler.name == DEFAULTS.agentic_auth_handler_id + assert auth_handler.title == DEFAULTS.agentic_auth_handler_title + assert auth_handler.text == DEFAULTS.agentic_auth_handler_text + assert auth_handler.obo_connection_name == DEFAULTS.agentic_obo_connection_name + assert auth_handler.scopes == ["user.Read", "Mail.Read"] + assert ( + auth_handler.abs_oauth_connection_name + == DEFAULTS.agentic_abs_oauth_connection_name + ) diff --git a/tests/hosting_core/app/oauth/test_authorization.py b/tests/hosting_core/app/oauth/test_authorization.py new file mode 100644 index 00000000..d3905fdf --- /dev/null +++ b/tests/hosting_core/app/oauth/test_authorization.py @@ -0,0 +1,722 @@ +from mimetypes import init +from re import A +import pytest +import jwt + +from typing import Optional + +from microsoft_agents.activity import Activity, ActivityTypes, TokenResponse + +from microsoft_agents.hosting.core.app.oauth import ( + _SignInResponse, + _SignInState, + _UserAuthorization, + Authorization, + AgenticUserAuthorization, +) + +from microsoft_agents.hosting.core._oauth import _FlowStateTag + +from microsoft_agents.hosting.core import ( + AuthHandler, + Storage, + MemoryStorage, + TurnContext, +) + +from tests._common.storage.utils import StorageBaseline + +# test constants +from tests._common.data import ( + TEST_FLOW_DATA, + TEST_AUTH_DATA, + TEST_STORAGE_DATA, + TEST_DEFAULTS, + TEST_ENV_DICT, + TEST_AGENTIC_ENV_DICT, +) +from tests._common.fixtures import FlowStateFixtures +from tests._common.testing_objects import ( + TestingConnectionManager as MockConnectionManager, + mock_UserTokenClient, + mock_class_UserAuthorization, + mock_class_AgenticUserAuthorization, + mock_class_Authorization, +) + +from ._common import testing_TurnContext, testing_Activity + +DEFAULTS = TEST_DEFAULTS() +FLOW_DATA = TEST_FLOW_DATA() +STORAGE_DATA = TEST_STORAGE_DATA() +ENV_DICT = TEST_ENV_DICT() +AGENTIC_ENV_DICT = TEST_AGENTIC_ENV_DICT() + + +def make_jwt(token: str = DEFAULTS.token, aud="api://default"): + if aud: + return jwt.encode({"aud": aud}, token, algorithm="HS256") + else: + return jwt.encode({}, token, algorithm="HS256") + + +def mock_variants(mocker, sign_in_return=None, get_refreshed_token_return=None): + mock_class_UserAuthorization( + mocker, + sign_in_return=sign_in_return, + get_refreshed_token_return=get_refreshed_token_return, + ) + mock_class_AgenticUserAuthorization( + mocker, + sign_in_return=sign_in_return, + get_refreshed_token_return=get_refreshed_token_return, + ) + + +def sign_in_state_eq(a: Optional[_SignInState], b: Optional[_SignInState]) -> bool: + if a is None and b is None: + return True + if a is None or b is None: + return False + return ( + a.active_handler_id == b.active_handler_id + and a.continuation_activity == b.continuation_activity + ) + + +def create_turn_state(context, token_cache: dict): + + d = {**context.turn_state} + d.update( + { + Authorization._cache_key(context, k): TokenResponse(token=v) + for k, v in token_cache.items() + } + ) + return d + + +def copy_sign_in_state(state: _SignInState) -> _SignInState: + return _SignInState( + active_handler_id=state.active_handler_id, + continuation_activity=( + state.continuation_activity.model_copy() + if state.continuation_activity + else None + ), + ) + + +class TestEnv(FlowStateFixtures): + def setup_method(self): + self.TurnContext = testing_TurnContext + self.UserTokenClient = mock_UserTokenClient + self.ConnectionManager = lambda mocker: MockConnectionManager() + + @pytest.fixture + def context(self, mocker): + return self.TurnContext(mocker) + + @pytest.fixture + def activity(self): + return testing_Activity() + + @pytest.fixture + def baseline_storage(self): + return StorageBaseline(TEST_STORAGE_DATA().dict) + + @pytest.fixture + def storage(self): + return MemoryStorage(STORAGE_DATA.get_init_data()) + + @pytest.fixture + def connection_manager(self, mocker): + return self.ConnectionManager(mocker) + + @pytest.fixture + def auth_handlers(self): + return TEST_AUTH_DATA().auth_handlers + + @pytest.fixture + def authorization(self, connection_manager, storage): + return Authorization(storage, connection_manager, **AGENTIC_ENV_DICT) + + @pytest.fixture(params=[ENV_DICT, AGENTIC_ENV_DICT]) + def env_dict(self, request): + return request.param + + @pytest.fixture(params=[DEFAULTS.auth_handler_id, DEFAULTS.agentic_auth_handler_id]) + def auth_handler_id(self, request): + return request.param + + +class TestAuthorizationSetup(TestEnv): + def test_init_user_auth(self, connection_manager, storage, env_dict): + auth = Authorization(storage, connection_manager, **env_dict) + assert auth._resolve_handler(DEFAULTS.auth_handler_id) is not None + assert isinstance( + auth._resolve_handler(DEFAULTS.auth_handler_id), _UserAuthorization + ) + + def test_init_agentic_auth_not_configured(self, connection_manager, storage): + auth = Authorization(storage, connection_manager, **ENV_DICT) + with pytest.raises(ValueError): + auth._resolve_handler(DEFAULTS.agentic_auth_handler_id) + + def test_init_agentic_auth(self, connection_manager, storage): + auth = Authorization(storage, connection_manager, **AGENTIC_ENV_DICT) + assert auth._resolve_handler(DEFAULTS.agentic_auth_handler_id) is not None + assert isinstance( + auth._resolve_handler(DEFAULTS.agentic_auth_handler_id), + AgenticUserAuthorization, + ) + + @pytest.mark.parametrize( + "auth_handler_id", [DEFAULTS.auth_handler_id, DEFAULTS.agentic_auth_handler_id] + ) + def test_resolve_handler(self, connection_manager, storage, auth_handler_id): + auth = Authorization(storage, connection_manager, **AGENTIC_ENV_DICT) + handler_config = AGENTIC_ENV_DICT["AGENTAPPLICATION"]["USERAUTHORIZATION"][ + "HANDLERS" + ][auth_handler_id] + auth._resolve_handler(auth_handler_id) == AuthHandler( + auth_handler_id, **handler_config + ) + + def test_sign_in_state_key(self, mocker, connection_manager, storage): + auth = Authorization(storage, connection_manager, **ENV_DICT) + context = self.TurnContext(mocker) + key = auth._sign_in_state_key(context) + assert key == f"auth:_SignInState:{DEFAULTS.channel_id}:{DEFAULTS.user_id}" + + +class TestAuthorizationUsage(TestEnv): + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "initial_turn_state, final_turn_state, initial_sign_in_state, auth_handler_id", + [ + [ + {DEFAULTS.auth_handler_id: DEFAULTS.token}, + {}, + None, + DEFAULTS.auth_handler_id, + ], + [ + {DEFAULTS.auth_handler_id: DEFAULTS.token}, + {}, + _SignInState(active_handler_id="some_value"), + DEFAULTS.auth_handler_id, + ], + [ + {DEFAULTS.agentic_auth_handler_id: DEFAULTS.token}, + {DEFAULTS.agentic_auth_handler_id: DEFAULTS.token}, + None, + DEFAULTS.auth_handler_id, + ], + [ + { + DEFAULTS.agentic_auth_handler_id: DEFAULTS.token, + DEFAULTS.auth_handler_id: "value", + }, + {DEFAULTS.auth_handler_id: "value"}, + _SignInState(active_handler_id="some_val"), + DEFAULTS.agentic_auth_handler_id, + ], + ], + ) + async def test_sign_out( + self, + mocker, + storage, + authorization, + context, + initial_turn_state, + final_turn_state, + initial_sign_in_state, + auth_handler_id, + ): + # setup + mock_variants(mocker) + expected_turn_state = create_turn_state(context, final_turn_state) + context.turn_state = create_turn_state(context, initial_turn_state) + if initial_sign_in_state: + await authorization._save_sign_in_state(context, initial_sign_in_state) + + # test + await authorization.sign_out(context, auth_handler_id) + + # verify + assert context.turn_state == expected_turn_state + assert (await authorization._load_sign_in_state(context)) is None + assert authorization._get_cached_token(context, auth_handler_id) is None + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "initial_cache, final_cache, auth_handler_id, expected_auth_handler_id, initial_sign_in_state, final_sign_in_state, sign_in_response", + [ + [ + {DEFAULTS.auth_handler_id: "old_token"}, + {DEFAULTS.auth_handler_id: "valid_token"}, + DEFAULTS.auth_handler_id, + DEFAULTS.auth_handler_id, + _SignInState(active_handler_id=DEFAULTS.auth_handler_id), + None, + _SignInResponse( + token_response=TokenResponse(token="valid_token"), + tag=_FlowStateTag.COMPLETE, + ), + ], + [ + {DEFAULTS.auth_handler_id: "old_token"}, + { + DEFAULTS.agentic_auth_handler_id: "valid_token", + DEFAULTS.auth_handler_id: "old_token", + }, + None, + DEFAULTS.agentic_auth_handler_id, + _SignInState(active_handler_id=DEFAULTS.agentic_auth_handler_id), + None, + _SignInResponse( + token_response=TokenResponse(token="valid_token"), + tag=_FlowStateTag.COMPLETE, + ), + ], + [ + {DEFAULTS.auth_handler_id: "old_token"}, + {DEFAULTS.auth_handler_id: "valid_token"}, + DEFAULTS.auth_handler_id, + DEFAULTS.auth_handler_id, + None, + None, + _SignInResponse( + token_response=TokenResponse(token="valid_token"), + tag=_FlowStateTag.COMPLETE, + ), + ], + [ + {DEFAULTS.auth_handler_id: "old_token"}, + {DEFAULTS.auth_handler_id: "valid_token"}, + None, + DEFAULTS.auth_handler_id, + None, + None, + _SignInResponse( + token_response=TokenResponse(token="valid_token"), + tag=_FlowStateTag.COMPLETE, + ), + ], + [ + { + DEFAULTS.agentic_auth_handler_id: "old_token", + DEFAULTS.auth_handler_id: "old_token", + }, + { + DEFAULTS.agentic_auth_handler_id: "valid_token", + DEFAULTS.auth_handler_id: "old_token", + }, + DEFAULTS.agentic_auth_handler_id, + DEFAULTS.agentic_auth_handler_id, + _SignInState(active_handler_id=DEFAULTS.agentic_auth_handler_id), + None, + _SignInResponse( + token_response=TokenResponse(token="valid_token"), + tag=_FlowStateTag.COMPLETE, + ), + ], + [ + { + DEFAULTS.agentic_auth_handler_id: "old_token", + DEFAULTS.auth_handler_id: "old_token", + }, + { + DEFAULTS.agentic_auth_handler_id: "old_token", + DEFAULTS.auth_handler_id: "old_token", + }, + DEFAULTS.agentic_auth_handler_id, + DEFAULTS.agentic_auth_handler_id, + _SignInState(active_handler_id=DEFAULTS.agentic_auth_handler_id), + None, + _SignInResponse( + token_response=TokenResponse(), + tag=_FlowStateTag.FAILURE, + ), + ], + [ + { + DEFAULTS.agentic_auth_handler_id: "old_token", + DEFAULTS.auth_handler_id: "old_token", + }, + { + DEFAULTS.agentic_auth_handler_id: "old_token", + DEFAULTS.auth_handler_id: "old_token", + }, + None, + DEFAULTS.auth_handler_id, + None, + None, + _SignInResponse( + token_response=TokenResponse(), + tag=_FlowStateTag.FAILURE, + ), + ], + ], + ) + async def test_start_or_continue_sign_in_complete_or_failure( + self, + mocker, + storage, + authorization, + context, + initial_cache, + final_cache, + auth_handler_id, + expected_auth_handler_id, + initial_sign_in_state, + final_sign_in_state, + sign_in_response, + ): + # setup + mock_variants(mocker, sign_in_return=sign_in_response) + expected_turn_state = create_turn_state(context, final_cache) + context.turn_state = create_turn_state(context, initial_cache) + if not initial_sign_in_state: + await authorization._delete_sign_in_state(context) + else: + await authorization._save_sign_in_state(context, initial_sign_in_state) + + # test + + res = await authorization._start_or_continue_sign_in( + context, None, auth_handler_id + ) + + # verify + assert res.tag == sign_in_response.tag + assert res.token_response == sign_in_response.token_response + + authorization._resolve_handler( + expected_auth_handler_id + )._sign_in.assert_called_once_with(context) + assert (await authorization._load_sign_in_state(context)) is None + assert context.turn_state == expected_turn_state + + @pytest.fixture(params=[_FlowStateTag.BEGIN, _FlowStateTag.CONTINUE]) + def pending_tag(self, request): + return request.param + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "initial_cache, auth_handler_id, expected_auth_handler_id, initial_sign_in_state", + [ + [ + {DEFAULTS.agentic_auth_handler_id: "old_token"}, + DEFAULTS.auth_handler_id, + DEFAULTS.auth_handler_id, + _SignInState(active_handler_id=DEFAULTS.auth_handler_id), + ], + [ + {DEFAULTS.auth_handler_id: "old_token"}, + None, + DEFAULTS.agentic_auth_handler_id, + _SignInState(active_handler_id=DEFAULTS.agentic_auth_handler_id), + ], + [ + {}, + DEFAULTS.auth_handler_id, + DEFAULTS.auth_handler_id, + None, + ], + [ + {DEFAULTS.auth_handler_id: "old_token"}, + None, + DEFAULTS.auth_handler_id, + None, + ], + [ + {}, + DEFAULTS.agentic_auth_handler_id, + DEFAULTS.auth_handler_id, + _SignInState(active_handler_id=DEFAULTS.auth_handler_id), + ], + ], + ) + async def test_start_or_continue_sign_in_pending( + self, + mocker, + storage, + authorization, + context, + initial_cache, + auth_handler_id, + expected_auth_handler_id, + initial_sign_in_state, + pending_tag, + ): + # setup + mock_variants(mocker, sign_in_return=_SignInResponse(tag=pending_tag)) + expected_turn_state = create_turn_state(context, initial_cache) + context.turn_state = expected_turn_state + if not initial_sign_in_state: + await authorization._delete_sign_in_state(context) + else: + await authorization._save_sign_in_state(context, initial_sign_in_state) + + # test + + res = await authorization._start_or_continue_sign_in( + context, None, auth_handler_id + ) + + # verify + assert res.tag == pending_tag + assert not res.token_response + + authorization._resolve_handler( + expected_auth_handler_id + )._sign_in.assert_called_once_with(context) + final_sign_in_state = await authorization._load_sign_in_state(context) + assert final_sign_in_state.continuation_activity == context.activity + assert final_sign_in_state.active_handler_id == expected_auth_handler_id + assert context.turn_state == expected_turn_state + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "initial_state, initial_cache, handler_id, expected_handler_id, refresh_token, expected", + [ + [ # no cached token + _SignInState(active_handler_id="value"), + {DEFAULTS.auth_handler_id: "token"}, + DEFAULTS.agentic_auth_handler_id, + DEFAULTS.agentic_auth_handler_id, + TokenResponse(), + TokenResponse(), + ], + [ # no cached token and default handler id resolution + _SignInState(active_handler_id="value"), + {DEFAULTS.agentic_auth_handler_id: "token"}, + "", + DEFAULTS.auth_handler_id, + TokenResponse(), + TokenResponse(), + ], + [ # no cached token pt.2 + _SignInState(active_handler_id=DEFAULTS.auth_handler_id), + {DEFAULTS.agentic_auth_handler_id: "token"}, + DEFAULTS.auth_handler_id, + DEFAULTS.auth_handler_id, + TokenResponse(), + TokenResponse(), + ], + [ # refreshed, new token + _SignInState(active_handler_id="value"), + {DEFAULTS.agentic_auth_handler_id: make_jwt()}, + DEFAULTS.agentic_auth_handler_id, + DEFAULTS.agentic_auth_handler_id, + TokenResponse(token=DEFAULTS.token), + TokenResponse(token=DEFAULTS.token), + ], + ], + ) + async def test_get_token( + self, + mocker, + authorization, + context, + storage, + initial_state, + initial_cache, + handler_id, + expected_handler_id, + refresh_token, + expected, + ): + # setup + mock_variants(mocker, get_refreshed_token_return=refresh_token) + expected_turn_state = create_turn_state(context, initial_cache) + context.turn_state = expected_turn_state + if not initial_state: + await authorization._delete_sign_in_state(context) + else: + await authorization._save_sign_in_state(context, initial_state) + + # test + res = await authorization.get_token(context, handler_id) + assert res == expected + + if handler_id and refresh_token: + authorization._resolve_handler( + expected_handler_id + ).get_refreshed_token.assert_called_once_with(context, None, None) + + final_state = await authorization._load_sign_in_state(context) + assert sign_in_state_eq(initial_state, final_state) + assert context.turn_state == expected_turn_state + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "initial_state, initial_cache, handler_id, refreshed, refresh_token", + [ + [ # no cached token + None, + {DEFAULTS.auth_handler_id: "token"}, + DEFAULTS.agentic_auth_handler_id, + False, + TokenResponse(), + ], + [ # no cached token and default handler id resolution + None, + {DEFAULTS.agentic_auth_handler_id: "token"}, + "", + False, + TokenResponse(), + ], + [ # no cached token pt.2 + _SignInState(active_handler_id=DEFAULTS.auth_handler_id), + {DEFAULTS.agentic_auth_handler_id: "token"}, + DEFAULTS.auth_handler_id, + True, + TokenResponse(), + ], + [ # refreshed, new token + _SignInState(active_handler_id=DEFAULTS.auth_handler_id), + {DEFAULTS.agentic_auth_handler_id: DEFAULTS.token}, + DEFAULTS.agentic_auth_handler_id, + True, + TokenResponse(token=DEFAULTS.token), + ], + ], + ) + async def test_exchange_token( + self, + mocker, + authorization, + context, + storage, + initial_state, + initial_cache, + handler_id, + refreshed, + refresh_token, + ): + # setup + mock_variants(mocker, get_refreshed_token_return=refresh_token) + expected_turn_state = create_turn_state(context, initial_cache) + context.turn_state = expected_turn_state + if not initial_state: + await authorization._delete_sign_in_state(context) + else: + await authorization._save_sign_in_state(context, initial_state) + + res = await authorization.exchange_token( + context, + auth_handler_id=handler_id, + exchange_connection="some_connection", + scopes=["scope1", "scope2"], + ) + assert res == refresh_token + + final_state = await authorization._load_sign_in_state(context) + assert sign_in_state_eq(initial_state, final_state) + if handler_id and refresh_token: + authorization._resolve_handler( + handler_id + ).get_refreshed_token.assert_called_once_with( + context, "some_connection", ["scope1", "scope2"] + ) + + final_state = await authorization._load_sign_in_state(context) + assert sign_in_state_eq(initial_state, final_state) + assert context.turn_state == expected_turn_state + + @pytest.mark.asyncio + async def test_on_turn_auth_intercept_no_intercept( + self, storage, authorization, context + ): + await authorization._delete_sign_in_state(context) + + intercepts, continuation_activity = await authorization._on_turn_auth_intercept( + context, None + ) + + assert not continuation_activity + assert not intercepts + + final_state = await authorization._load_sign_in_state(context) + + assert sign_in_state_eq(final_state, None) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "sign_in_response", + [ + _SignInResponse(tag=_FlowStateTag.BEGIN), + _SignInResponse(tag=_FlowStateTag.CONTINUE), + _SignInResponse(tag=_FlowStateTag.FAILURE), + ], + ) + async def test_on_turn_auth_intercept_with_intercept_incomplete( + self, mocker, storage, authorization, context, sign_in_response, auth_handler_id + ): + mock_class_Authorization( + mocker, start_or_continue_sign_in_return=sign_in_response + ) + + initial_cache = {"some_handler": "old_token"} + expected_cache = create_turn_state(context, initial_cache) + context.turn_state = expected_cache + + initial_state = _SignInState( + active_handler_id=auth_handler_id, + continuation_activity=Activity( + type=ActivityTypes.message, text="old activity" + ), + ) + await authorization._save_sign_in_state( + context, copy_sign_in_state(initial_state) + ) + + intercepts, continuation_activity = await authorization._on_turn_auth_intercept( + context, auth_handler_id + ) + + assert not continuation_activity + assert intercepts + + final_state = await authorization._load_sign_in_state(context) + assert sign_in_state_eq(final_state, initial_state) + assert context.turn_state == expected_cache + + @pytest.mark.asyncio + async def test_on_turn_auth_intercept_with_intercept_complete( + self, mocker, storage, authorization, context, auth_handler_id + ): + mock_class_Authorization( + mocker, + start_or_continue_sign_in_return=_SignInResponse( + tag=_FlowStateTag.COMPLETE + ), + ) + + initial_cache = {"some_handler": "old_token"} + expected_cache = create_turn_state(context, initial_cache) + context.turn_state = expected_cache + + old_activity = Activity(type=ActivityTypes.message, text="old activity") + initial_state = _SignInState( + active_handler_id=auth_handler_id, continuation_activity=old_activity + ) + await authorization._save_sign_in_state( + context, copy_sign_in_state(initial_state) + ) + + intercepts, continuation_activity = await authorization._on_turn_auth_intercept( + context, auth_handler_id + ) + + assert continuation_activity == old_activity + assert intercepts + + final_state = await authorization._load_sign_in_state(context) + assert sign_in_state_eq(final_state, initial_state) + assert context.turn_state == expected_cache diff --git a/tests/hosting_core/app/oauth/test_sign_in_response.py b/tests/hosting_core/app/oauth/test_sign_in_response.py new file mode 100644 index 00000000..30c52fd8 --- /dev/null +++ b/tests/hosting_core/app/oauth/test_sign_in_response.py @@ -0,0 +1,11 @@ +from microsoft_agents.hosting.core.app.oauth import _SignInResponse +from microsoft_agents.hosting.core._oauth import _FlowStateTag + + +def test_sign_in_response_sign_in_complete(): + assert _SignInResponse(tag=_FlowStateTag.BEGIN).sign_in_complete() == False + assert _SignInResponse(tag=_FlowStateTag.CONTINUE).sign_in_complete() == False + assert _SignInResponse(tag=_FlowStateTag.FAILURE).sign_in_complete() == False + assert _SignInResponse().sign_in_complete() == False + assert _SignInResponse(tag=_FlowStateTag.NOT_STARTED).sign_in_complete() == True + assert _SignInResponse(tag=_FlowStateTag.COMPLETE).sign_in_complete() == True diff --git a/tests/hosting_core/app/test_agent_application.py b/tests/hosting_core/app/test_agent_application.py new file mode 100644 index 00000000..b5bad921 --- /dev/null +++ b/tests/hosting_core/app/test_agent_application.py @@ -0,0 +1,33 @@ +# from microsoft_agents.authentication.msal.msal_connection_manager import MsalConnectionManager +# from microsoft_agents.hosting.core.turn_context import TurnContext +# import pytest + +# from microsoft_agents.authentication.msal import MsalAuthentication +# from microsoft_agents.hosting.core import ( +# MemoryStorage, +# AgentApplication, +# ApplicationOptions, +# Connections +# ) + +# # def mock_send_activity(mocker): +# # mocker.patch.object(TurnContext, 'send_activity', new=) + +# class TestUtils: + +# @pytest.fixture +# def options(self): +# return ApplicationOptions() + +# @pytest.fixture +# def storage(self): +# return MemoryStorage() + +# @pytest.fixture +# def connection_manager(self): +# return MsalConnectionManager() + + +# class TestAgentApplication: + +# pass diff --git a/tests/hosting_core/app/test_authorization.py b/tests/hosting_core/app/test_authorization.py deleted file mode 100644 index effacf87..00000000 --- a/tests/hosting_core/app/test_authorization.py +++ /dev/null @@ -1,540 +0,0 @@ -import pytest -from datetime import datetime -import jwt - -from microsoft_agents.activity import ActivityTypes, TokenResponse - -from microsoft_agents.hosting.core import ( - FlowStorageClient, - FlowErrorTag, - FlowStateTag, - FlowState, - FlowResponse, - OAuthFlow, - Authorization, - MemoryStorage, -) - -from tests._common.storage.utils import StorageBaseline - -# test constants -from tests._common.data import ( - TEST_FLOW_DATA, - TEST_AUTH_DATA, - TEST_STORAGE_DATA, - TEST_DEFAULTS, - create_test_auth_handler, -) -from tests._common.fixtures import FlowStateFixtures -from tests._common.testing_objects import ( - TestingConnectionManager as MockConnectionManager, - mock_class_OAuthFlow, - mock_UserTokenClient, -) -from tests.hosting_core._common import flow_state_eq - -DEFAULTS = TEST_DEFAULTS() -FLOW_DATA = TEST_FLOW_DATA() -STORAGE_DATA = TEST_STORAGE_DATA() - - -def testing_TurnContext( - mocker, - channel_id=DEFAULTS.channel_id, - user_id=DEFAULTS.user_id, - user_token_client=None, -): - if not user_token_client: - user_token_client = mock_UserTokenClient(mocker) - - turn_context = mocker.Mock() - turn_context.activity.channel_id = channel_id - turn_context.activity.from_property.id = user_id - turn_context.activity.type = ActivityTypes.message - turn_context.adapter.USER_TOKEN_CLIENT_KEY = "__user_token_client" - turn_context.adapter.AGENT_IDENTITY_KEY = "__agent_identity_key" - agent_identity = mocker.Mock() - agent_identity.claims = {"aud": DEFAULTS.ms_app_id} - turn_context.turn_state = { - "__user_token_client": user_token_client, - "__agent_identity_key": agent_identity, - } - return turn_context - - -class TestEnv(FlowStateFixtures): - def setup_method(self): - self.TurnContext = testing_TurnContext - self.UserTokenClient = mock_UserTokenClient - self.ConnectionManager = lambda mocker: MockConnectionManager() - - @pytest.fixture - def turn_context(self, mocker): - return self.TurnContext(mocker) - - @pytest.fixture - def baseline_storage(self): - return StorageBaseline(TEST_STORAGE_DATA().dict) - - @pytest.fixture - def storage(self): - return MemoryStorage(STORAGE_DATA.get_init_data()) - - @pytest.fixture - def connection_manager(self, mocker): - return self.ConnectionManager(mocker) - - @pytest.fixture - def auth_handlers(self): - return TEST_AUTH_DATA().auth_handlers - - @pytest.fixture - def authorization(self, connection_manager, storage, auth_handlers): - return Authorization(storage, connection_manager, auth_handlers) - - -class TestAuthorization(TestEnv): - def test_init_configuration_variants( - self, storage, connection_manager, auth_handlers - ): - """Test initialization of authorization with different configuration variants.""" - AGENTAPPLICATION = { - "USERAUTHORIZATION": { - "HANDLERS": { - handler_name: { - "SETTINGS": { - "title": handler.title, - "text": handler.text, - "abs_oauth_connection_name": handler.abs_oauth_connection_name, - "obo_connection_name": handler.obo_connection_name, - } - } - for handler_name, handler in auth_handlers.items() - } - } - } - auth_with_config_obj = Authorization( - storage, - connection_manager, - auth_handlers=None, - AGENTAPPLICATION=AGENTAPPLICATION, - ) - auth_with_handlers_list = Authorization( - storage, connection_manager, auth_handlers=auth_handlers - ) - for auth_handler_name in auth_handlers.keys(): - auth_handler_a = auth_with_config_obj.resolve_handler(auth_handler_name) - auth_handler_b = auth_with_handlers_list.resolve_handler(auth_handler_name) - - assert auth_handler_a.name == auth_handler_b.name - assert auth_handler_a.title == auth_handler_b.title - assert auth_handler_a.text == auth_handler_b.text - assert ( - auth_handler_a.abs_oauth_connection_name - == auth_handler_b.abs_oauth_connection_name - ) - assert ( - auth_handler_a.obo_connection_name == auth_handler_b.obo_connection_name - ) - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "auth_handler_id, channel_id, user_id", - [["missing", "webchat", "Alice"], ["handler", "teams", "Bob"]], - ) - async def test_open_flow_value_error( - self, mocker, authorization, auth_handler_id, channel_id, user_id - ): - """Test opening a flow with a missing auth handler.""" - context = self.TurnContext(mocker, channel_id=channel_id, user_id=user_id) - with pytest.raises(ValueError): - async with authorization.open_flow(context, auth_handler_id): - pass - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "auth_handler_id, channel_id, user_id", - [ - ["", "webchat", "Alice"], - ["graph", "teams", "Bob"], - ["slack", "webchat", "Chuck"], - ], - ) - async def test_open_flow_readonly( - self, - mocker, - storage, - connection_manager, - auth_handlers, - auth_handler_id, - channel_id, - user_id, - ): - """Test opening a flow and not modifying it.""" - # setup - context = self.TurnContext(mocker, channel_id=channel_id, user_id=user_id) - auth = Authorization(storage, connection_manager, auth_handlers) - flow_storage_client = FlowStorageClient(channel_id, user_id, storage) - - # test - async with auth.open_flow(context, auth_handler_id) as flow: - expected_flow_state = flow.flow_state - - # verify - actual_flow_state = await flow_storage_client.read( - auth.resolve_handler(auth_handler_id).name - ) - assert actual_flow_state == expected_flow_state - - @pytest.mark.asyncio - async def test_open_flow_success_modified_complete_flow( - self, - mocker, - storage, - connection_manager, - auth_handlers, - ): - # mock - channel_id = "teams" - user_id = "Alice" - auth_handler_id = "graph" - - user_token_client = self.UserTokenClient( - mocker, get_token_return=DEFAULTS.token - ) - context = self.TurnContext( - mocker, - channel_id=channel_id, - user_id=user_id, - user_token_client=user_token_client, - ) - - # setup - context.activity.type = ActivityTypes.message - context.activity.text = "123456" - - auth = Authorization(storage, connection_manager, auth_handlers) - flow_storage_client = FlowStorageClient(channel_id, user_id, storage) - - # test - async with auth.open_flow(context, auth_handler_id) as flow: - expected_flow_state = flow.flow_state - expected_flow_state.tag = FlowStateTag.COMPLETE - expected_flow_state.user_token = DEFAULTS.token - - flow_response = await flow.begin_or_continue_flow(context.activity) - res_flow_state = flow_response.flow_state - - # verify - actual_flow_state = await flow_storage_client.read(auth_handler_id) - expected_flow_state.expiration = actual_flow_state.expiration - assert flow_state_eq(actual_flow_state, expected_flow_state) - assert flow_state_eq(res_flow_state, expected_flow_state) - - @pytest.mark.asyncio - async def test_open_flow_success_modified_failure( - self, - mocker, - storage, - connection_manager, - auth_handlers, - ): - # setup - channel_id = "teams" - user_id = "Bob" - auth_handler_id = "slack" - - context = self.TurnContext(mocker, channel_id=channel_id, user_id=user_id) - context.activity.text = "invalid_magic_code" - - auth = Authorization(storage, connection_manager, auth_handlers) - flow_storage_client = FlowStorageClient(channel_id, user_id, storage) - - # test - async with auth.open_flow(context, auth_handler_id) as flow: - expected_flow_state = flow.flow_state - expected_flow_state.tag = FlowStateTag.FAILURE - expected_flow_state.attempts_remaining = 0 - - flow_response = await flow.begin_or_continue_flow(context.activity) - res_flow_state = flow_response.flow_state - - # verify - actual_flow_state = await flow_storage_client.read(auth_handler_id) - - assert flow_response.flow_error_tag == FlowErrorTag.MAGIC_FORMAT - assert flow_state_eq(res_flow_state, expected_flow_state) - assert flow_state_eq(actual_flow_state, expected_flow_state) - - @pytest.mark.asyncio - async def test_open_flow_success_modified_signout( - self, mocker, storage, connection_manager, auth_handlers - ): - # setup - channel_id = "webchat" - user_id = "Alice" - auth_handler_id = "graph" - - context = self.TurnContext(mocker, channel_id=channel_id, user_id=user_id) - - auth = Authorization(storage, connection_manager, auth_handlers) - flow_storage_client = FlowStorageClient(channel_id, user_id, storage) - - # test - async with auth.open_flow(context, auth_handler_id) as flow: - expected_flow_state = flow.flow_state - expected_flow_state.tag = FlowStateTag.NOT_STARTED - expected_flow_state.user_token = "" - - await flow.sign_out() - - # verify - actual_flow_state = await flow_storage_client.read(auth_handler_id) - assert flow_state_eq(actual_flow_state, expected_flow_state) - - @pytest.mark.asyncio - async def test_get_token_success(self, mocker, authorization): - user_token_client = self.UserTokenClient(mocker, get_token_return="token") - context = self.TurnContext( - mocker, - channel_id="__channel_id", - user_id="__user_id", - user_token_client=user_token_client, - ) - assert await authorization.get_token(context, "slack") == TokenResponse( - token="token" - ) - user_token_client.user_token.get_token.assert_called_once() - - @pytest.mark.asyncio - async def test_get_token_empty_response(self, mocker, authorization): - user_token_client = self.UserTokenClient( - mocker, get_token_return=TokenResponse() - ) - context = self.TurnContext( - mocker, - channel_id="__channel_id", - user_id="__user_id", - user_token_client=user_token_client, - ) - assert await authorization.get_token(context, "graph") == TokenResponse() - user_token_client.user_token.get_token.assert_called_once() - - @pytest.mark.asyncio - async def test_get_token_error( - self, turn_context, storage, connection_manager, auth_handlers - ): - auth = Authorization(storage, connection_manager, auth_handlers) - with pytest.raises(ValueError): - await auth.get_token( - turn_context, DEFAULTS.missing_abs_oauth_connection_name - ) - - @pytest.mark.asyncio - async def test_exchange_token_no_token(self, mocker, turn_context, authorization): - mock_class_OAuthFlow(mocker, get_user_token_return=TokenResponse()) - res = await authorization.exchange_token(turn_context, ["scope"], "github") - assert res == TokenResponse() - - @pytest.mark.asyncio - async def test_exchange_token_not_exchangeable( - self, mocker, turn_context, authorization - ): - token = jwt.encode({"aud": "invalid://botframework.test.api"}, "") - mock_class_OAuthFlow( - mocker, - get_user_token_return=TokenResponse(connection_name="github", token=token), - ) - res = await authorization.exchange_token(turn_context, ["scope"], "github") - assert res == TokenResponse() - - @pytest.mark.asyncio - async def test_exchange_token_valid_exchangeable(self, mocker, authorization): - # setup - token = jwt.encode({"aud": "api://botframework.test.api"}, "") - mock_class_OAuthFlow( - mocker, - get_user_token_return=TokenResponse(connection_name="github", token=token), - ) - user_token_client = self.UserTokenClient( - mocker, get_token_return="github-obo-connection-obo-token" - ) - turn_context = self.TurnContext(mocker, user_token_client=user_token_client) - # test - res = await authorization.exchange_token(turn_context, ["scope"], "github") - assert res == TokenResponse(token="github-obo-connection-obo-token") - - @pytest.mark.asyncio - async def test_get_active_flow_state(self, mocker, authorization): - context = self.TurnContext(mocker, channel_id="webchat", user_id="Alice") - actual_flow_state = await authorization.get_active_flow_state(context) - assert actual_flow_state == STORAGE_DATA.dict["auth/webchat/Alice/github"] - - @pytest.mark.asyncio - async def test_get_active_flow_state_missing(self, mocker, authorization): - context = self.TurnContext( - mocker, channel_id="__channel_id", user_id="__user_id" - ) - res = await authorization.get_active_flow_state(context) - assert res is None - - @pytest.mark.asyncio - async def test_begin_or_continue_flow_success(self, mocker, authorization): - # robrandao: TODO -> lower priority -> more testing here - # setup - mock_class_OAuthFlow( - mocker, - begin_or_continue_flow_return=FlowResponse( - token_response=TokenResponse(token="token"), - flow_state=FlowState( - tag=FlowStateTag.COMPLETE, auth_handler_id="github" - ), - ), - ) - context = self.TurnContext(mocker, channel_id="webchat", user_id="Alice") - context.dummy_val = None - - def on_sign_in_success(context, turn_state, auth_handler_id): - context.dummy_val = auth_handler_id - - def on_sign_in_failure(context, turn_state, auth_handler_id, err): - context.dummy_val = str(err) - - # test - authorization.on_sign_in_success(on_sign_in_success) - authorization.on_sign_in_failure(on_sign_in_failure) - flow_response = await authorization.begin_or_continue_flow( - context, None, "github" - ) - assert context.dummy_val == "github" - assert flow_response.token_response == TokenResponse(token="token") - - @pytest.mark.asyncio - async def test_begin_or_continue_flow_already_completed( - self, mocker, authorization - ): - # robrandao: TODO -> lower priority -> more testing here - # setup - context = self.TurnContext(mocker, channel_id="webchat", user_id="Alice") - - context.dummy_val = None - - def on_sign_in_success(context, turn_state, auth_handler_id): - context.dummy_val = auth_handler_id - - def on_sign_in_failure(context, turn_state, auth_handler_id, err): - context.dummy_val = str(err) - - # test - authorization.on_sign_in_success(on_sign_in_success) - authorization.on_sign_in_failure(on_sign_in_failure) - flow_response = await authorization.begin_or_continue_flow( - context, None, "graph" - ) - assert context.dummy_val == None - assert flow_response.token_response == TokenResponse(token="test_token") - assert flow_response.continuation_activity is None - - @pytest.mark.asyncio - async def test_begin_or_continue_flow_failure(self, mocker, authorization): - # robrandao: TODO -> lower priority -> more testing here - # setup - mock_class_OAuthFlow( - mocker, - begin_or_continue_flow_return=FlowResponse( - token_response=TokenResponse(token="token"), - flow_state=FlowState( - tag=FlowStateTag.FAILURE, auth_handler_id="github" - ), - flow_error_tag=FlowErrorTag.MAGIC_FORMAT, - ), - ) - context = self.TurnContext(mocker, channel_id="webchat", user_id="Alice") - context.dummy_val = None - - def on_sign_in_success(context, turn_state, auth_handler_id): - context.dummy_val = auth_handler_id - - def on_sign_in_failure(context, turn_state, auth_handler_id, err): - context.dummy_val = str(err) - - # test - authorization.on_sign_in_success(on_sign_in_success) - authorization.on_sign_in_failure(on_sign_in_failure) - flow_response = await authorization.begin_or_continue_flow( - context, None, "github" - ) - assert context.dummy_val == "FlowErrorTag.MAGIC_FORMAT" - assert flow_response.token_response == TokenResponse(token="token") - - @pytest.mark.parametrize("auth_handler_id", ["graph", "github"]) - def test_resolve_handler_specified( - self, authorization, auth_handlers, auth_handler_id - ): - assert ( - authorization.resolve_handler(auth_handler_id) - == auth_handlers[auth_handler_id] - ) - - def test_resolve_handler_error(self, authorization): - with pytest.raises(ValueError): - authorization.resolve_handler("missing-handler") - - def test_resolve_handler_first(self, authorization, auth_handlers): - assert authorization.resolve_handler() == next(iter(auth_handlers.values())) - - @pytest.mark.asyncio - async def test_sign_out_individual( - self, - mocker, - storage, - connection_manager, - auth_handlers, - ): - # setup - mock_class_OAuthFlow(mocker) - storage_client = FlowStorageClient("teams", "Alice", storage) - context = self.TurnContext(mocker, channel_id="teams", user_id="Alice") - auth = Authorization(storage, connection_manager, auth_handlers) - - # test - await auth.sign_out(context, "graph") - - # verify - assert ( - await storage.read([storage_client.key("graph")], target_cls=FlowState) - == {} - ) - OAuthFlow.sign_out.assert_called_once() - - @pytest.mark.asyncio - async def test_sign_out_all( - self, - mocker, - storage, - connection_manager, - auth_handlers, - ): - # setup - mock_class_OAuthFlow(mocker) - context = self.TurnContext(mocker, channel_id="webchat", user_id="Alice") - storage_client = FlowStorageClient("webchat", "Alice", storage) - auth = Authorization(storage, connection_manager, auth_handlers) - - # test - await auth.sign_out(context) - - # verify - assert ( - await storage.read([storage_client.key("graph")], target_cls=FlowState) - == {} - ) - assert ( - await storage.read([storage_client.key("github")], target_cls=FlowState) - == {} - ) - assert ( - await storage.read([storage_client.key("slack")], target_cls=FlowState) - == {} - ) - OAuthFlow.sign_out.assert_called() # ignore red squiggly -> mocked diff --git a/tests/hosting_core/storage/test_transcript_logger_middleware.py b/tests/hosting_core/storage/test_transcript_logger_middleware.py index d980e3c7..f63db030 100644 --- a/tests/hosting_core/storage/test_transcript_logger_middleware.py +++ b/tests/hosting_core/storage/test_transcript_logger_middleware.py @@ -6,11 +6,21 @@ from microsoft_agents.activity import Activity, ActivityEventNames, ActivityTypes from microsoft_agents.hosting.core.authorization.claims_identity import ClaimsIdentity from microsoft_agents.hosting.core.middleware_set import TurnContext -from microsoft_agents.hosting.core.storage.transcript_logger import ConsoleTranscriptLogger, FileTranscriptLogger, TranscriptLoggerMiddleware -from microsoft_agents.hosting.core.storage.transcript_memory_store import TranscriptMemoryStore +from microsoft_agents.hosting.core.storage.transcript_logger import ( + ConsoleTranscriptLogger, + FileTranscriptLogger, + TranscriptLoggerMiddleware, +) +from microsoft_agents.hosting.core.storage.transcript_memory_store import ( + TranscriptMemoryStore, +) import pytest -from tests._common.testing_objects.adapters.testing_adapter import AgentCallbackHandler, TestingAdapter +from tests._common.testing_objects.adapters.testing_adapter import ( + AgentCallbackHandler, + TestingAdapter, +) + @pytest.mark.asyncio async def test_should_round_trip_via_middleware(): @@ -18,23 +28,23 @@ async def test_should_round_trip_via_middleware(): conversation_id = "id.1" transcript_middleware = TranscriptLoggerMiddleware(transcript_store) channelName = "Channel1" - + adapter = TestingAdapter(channelName) - adapter.use(transcript_middleware) + adapter.use(transcript_middleware) id = ClaimsIdentity({}, True) async def callback(tc): print("process callback") a1 = adapter.make_activity("some random text") - a1.conversation.id = conversation_id # Make sure the conversation ID is set + a1.conversation.id = conversation_id # Make sure the conversation ID is set await adapter.process_activity(id, a1, callback) - + transcriptAndContinuationToken = await transcript_store.get_transcript_activities( channelName, conversation_id ) - + transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] @@ -44,12 +54,13 @@ async def callback(tc): assert transcript[0].text == a1.text assert continuationToken is None + @pytest.mark.asyncio async def test_should_write_to_file(): fileName = "test_transcript.log" - if os.path.exists(fileName): # Check if the file exists - os.remove(fileName) # Delete the file + if os.path.exists(fileName): # Check if the file exists + os.remove(fileName) # Delete the file assert not os.path.exists(fileName), "file already exists." @@ -57,9 +68,9 @@ async def test_should_write_to_file(): conversation_id = "id.1" transcript_middleware = TranscriptLoggerMiddleware(file_store) channelName = "Channel1" - + adapter = TestingAdapter(channelName) - adapter.use(transcript_middleware) + adapter.use(transcript_middleware) id = ClaimsIdentity({}, True) async def callback(tc): @@ -67,8 +78,8 @@ async def callback(tc): textInActivity = "some random text" a1 = adapter.make_activity(textInActivity) - a1.conversation.id = conversation_id # Make sure the conversation ID is set - + a1.conversation.id = conversation_id # Make sure the conversation ID is set + # This round-trips out to the File logger which does the actual write await adapter.process_activity(id, a1, callback) @@ -77,6 +88,7 @@ async def callback(tc): assert os.path.isfile(fileName), "file is not a file." assert os.path.getsize(fileName) > 0, "file is empty" + @pytest.mark.asyncio async def test_should_write_to_console(): @@ -84,9 +96,9 @@ async def test_should_write_to_console(): conversation_id = "id.1" transcript_middleware = TranscriptLoggerMiddleware(store) channelName = "Channel1" - + adapter = TestingAdapter(channelName) - adapter.use(transcript_middleware) + adapter.use(transcript_middleware) id = ClaimsIdentity({}, True) async def callback(tc): @@ -94,9 +106,9 @@ async def callback(tc): textInActivity = "some random text" a1 = adapter.make_activity(textInActivity) - a1.conversation.id = conversation_id # Make sure the conversation ID is set - + a1.conversation.id = conversation_id # Make sure the conversation ID is set + # This round-trips out to the console logger which does the actual write await adapter.process_activity(id, a1, callback) - #check the console by hand. \ No newline at end of file + # check the console by hand. diff --git a/tests/hosting_core/storage/test_transcript_store_memory.py b/tests/hosting_core/storage/test_transcript_store_memory.py index 2733eb85..7d11e752 100644 --- a/tests/hosting_core/storage/test_transcript_store_memory.py +++ b/tests/hosting_core/storage/test_transcript_store_memory.py @@ -3,31 +3,39 @@ from datetime import datetime, timezone import pytest -from microsoft_agents.hosting.core.storage.transcript_memory_store import TranscriptMemoryStore +from microsoft_agents.hosting.core.storage.transcript_memory_store import ( + TranscriptMemoryStore, +) from microsoft_agents.activity import Activity, ConversationAccount + @pytest.mark.asyncio async def test_get_transcript_empty(): store = TranscriptMemoryStore() - transcriptAndContinuationToken = await store.get_transcript_activities("Channel 1", "Conversation 1") + transcriptAndContinuationToken = await store.get_transcript_activities( + "Channel 1", "Conversation 1" + ) transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] assert transcript == [] assert continuationToken is None + @pytest.mark.asyncio async def test_log_activity_add_one_activity(): store = TranscriptMemoryStore() activity = Activity.create_message_activity() activity.text = "Activity 1" activity.channel_id = "Channel 1" - activity.conversation = ConversationAccount( id="Conversation 1" ) - - # Add one activity and make sure it's there and comes back + activity.conversation = ConversationAccount(id="Conversation 1") + + # Add one activity and make sure it's there and comes back await store.log_activity(activity) # Ask for the activity we just added - transcriptAndContinuationToken = await store.get_transcript_activities("Channel 1", "Conversation 1") + transcriptAndContinuationToken = await store.get_transcript_activities( + "Channel 1", "Conversation 1" + ) transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] @@ -38,37 +46,44 @@ async def test_log_activity_add_one_activity(): assert continuationToken is None # Ask for a channel that doesn't exist and make sure we get nothing - transcriptAndContinuationToken = await store.get_transcript_activities("Invalid", "Conversation 1") + transcriptAndContinuationToken = await store.get_transcript_activities( + "Invalid", "Conversation 1" + ) transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] - assert transcript == [] + assert transcript == [] assert continuationToken is None # Ask for a ConversationID that doesn't exist and make sure we get nothing - transcriptAndContinuationToken = await store.get_transcript_activities("Channel 1", "INVALID") + transcriptAndContinuationToken = await store.get_transcript_activities( + "Channel 1", "INVALID" + ) transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] - assert transcript == [] + assert transcript == [] assert continuationToken is None + @pytest.mark.asyncio async def test_log_activity_add_two_activity_same_conversation(): store = TranscriptMemoryStore() activity1 = Activity.create_message_activity() activity1.text = "Activity 1" activity1.channel_id = "Channel 1" - activity1.conversation = ConversationAccount( id="Conversation 1" ) + activity1.conversation = ConversationAccount(id="Conversation 1") activity2 = Activity.create_message_activity() activity2.text = "Activity 2" activity2.channel_id = "Channel 1" - activity2.conversation = ConversationAccount( id="Conversation 1" ) + activity2.conversation = ConversationAccount(id="Conversation 1") await store.log_activity(activity1) - await store.log_activity(activity2) + await store.log_activity(activity2) # Ask for the activity we just added - transcriptAndContinuationToken = await store.get_transcript_activities("Channel 1", "Conversation 1") + transcriptAndContinuationToken = await store.get_transcript_activities( + "Channel 1", "Conversation 1" + ) transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] @@ -83,50 +98,57 @@ async def test_log_activity_add_two_activity_same_conversation(): assert continuationToken is None + @pytest.mark.asyncio async def test_log_activity_add_two_activity_same_conversation(): store = TranscriptMemoryStore() activity1 = Activity.create_message_activity() activity1.text = "Activity 1" activity1.channel_id = "Channel 1" - activity1.conversation = ConversationAccount( id="Conversation 1" ) - activity1.timestamp = datetime(2000, 1, 1, 12, 0, 0 , tzinfo=timezone.utc) + activity1.conversation = ConversationAccount(id="Conversation 1") + activity1.timestamp = datetime(2000, 1, 1, 12, 0, 0, tzinfo=timezone.utc) activity2 = Activity.create_message_activity() activity2.text = "Activity 2" activity2.channel_id = "Channel 1" - activity2.conversation = ConversationAccount( id="Conversation 1" ) - activity2.timestamp = datetime(2010, 1, 1, 12, 0, 1 , tzinfo=timezone.utc) + activity2.conversation = ConversationAccount(id="Conversation 1") + activity2.timestamp = datetime(2010, 1, 1, 12, 0, 1, tzinfo=timezone.utc) activity3 = Activity.create_message_activity() activity3.text = "Activity 2" activity3.channel_id = "Channel 1" - activity3.conversation = ConversationAccount( id="Conversation 1" ) - activity3.timestamp = datetime(2020, 1, 1, 12, 0, 1 , tzinfo=timezone.utc) + activity3.conversation = ConversationAccount(id="Conversation 1") + activity3.timestamp = datetime(2020, 1, 1, 12, 0, 1, tzinfo=timezone.utc) - await store.log_activity(activity1) # 2000 - await store.log_activity(activity2) # 2010 - await store.log_activity(activity3) # 2020 + await store.log_activity(activity1) # 2000 + await store.log_activity(activity2) # 2010 + await store.log_activity(activity3) # 2020 # Ask for the activities we just added - date1 = datetime(1999, 1, 1, 12, 0, 0 , tzinfo=timezone.utc) - date2 = datetime(2009, 1, 1, 12, 0, 0 , tzinfo=timezone.utc) - date3 = datetime(2019, 1, 1, 12, 0, 0 , tzinfo=timezone.utc) + date1 = datetime(1999, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + date2 = datetime(2009, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + date3 = datetime(2019, 1, 1, 12, 0, 0, tzinfo=timezone.utc) # ask for everything after 1999. Should get all 3 activities - transcriptAndContinuationToken = await store.get_transcript_activities("Channel 1", "Conversation 1", None, date1) + transcriptAndContinuationToken = await store.get_transcript_activities( + "Channel 1", "Conversation 1", None, date1 + ) transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] assert len(transcript) == 3 # ask for everything after 2009. Should get 2 activities - the 2010 and 2020 activities - transcriptAndContinuationToken = await store.get_transcript_activities("Channel 1", "Conversation 1", None, date2) + transcriptAndContinuationToken = await store.get_transcript_activities( + "Channel 1", "Conversation 1", None, date2 + ) transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] assert len(transcript) == 2 # ask for everything after 2019. Should only get the 2020 activity - transcriptAndContinuationToken = await store.get_transcript_activities("Channel 1", "Conversation 1", None, date3) + transcriptAndContinuationToken = await store.get_transcript_activities( + "Channel 1", "Conversation 1", None, date3 + ) transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] assert len(transcript) == 1 @@ -138,18 +160,20 @@ async def test_log_activity_add_two_activity_two_conversation(): activity1 = Activity.create_message_activity() activity1.text = "Activity 1 Channel 1 Conversation 1" activity1.channel_id = "Channel 1" - activity1.conversation = ConversationAccount( id="Conversation 1" ) + activity1.conversation = ConversationAccount(id="Conversation 1") activity2 = Activity.create_message_activity() activity2.text = "Activity 1 Channel 1 Conversation 2" activity2.channel_id = "Channel 1" - activity2.conversation = ConversationAccount( id="Conversation 2" ) + activity2.conversation = ConversationAccount(id="Conversation 2") await store.log_activity(activity1) - await store.log_activity(activity2) + await store.log_activity(activity2) # Ask for the activity we just added - transcriptAndContinuationToken = await store.get_transcript_activities("Channel 1", "Conversation 1") + transcriptAndContinuationToken = await store.get_transcript_activities( + "Channel 1", "Conversation 1" + ) transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] @@ -160,7 +184,9 @@ async def test_log_activity_add_two_activity_two_conversation(): assert continuationToken is None # Now grab Conversation 2 - transcriptAndContinuationToken = await store.get_transcript_activities("Channel 1", "Conversation 2") + transcriptAndContinuationToken = await store.get_transcript_activities( + "Channel 1", "Conversation 2" + ) transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] @@ -170,19 +196,22 @@ async def test_log_activity_add_two_activity_two_conversation(): assert transcript[0].text == activity2.text assert continuationToken is None + @pytest.mark.asyncio async def test_delete_one_transcript(): store = TranscriptMemoryStore() activity = Activity.create_message_activity() activity.text = "Activity 1" activity.channel_id = "Channel 1" - activity.conversation = ConversationAccount( id="Conversation 1" ) - - # Add one activity and make sure it's there and comes back + activity.conversation = ConversationAccount(id="Conversation 1") + + # Add one activity and make sure it's there and comes back await store.log_activity(activity) # Ask for the activity we just added - transcriptAndContinuationToken = await store.get_transcript_activities("Channel 1", "Conversation 1") + transcriptAndContinuationToken = await store.get_transcript_activities( + "Channel 1", "Conversation 1" + ) transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] @@ -190,25 +219,28 @@ async def test_delete_one_transcript(): # Now delete the transcript await store.delete_transcript("Channel 1", "Conversation 1") - transcriptAndContinuationToken = await store.get_transcript_activities("Channel 1", "Conversation 1") + transcriptAndContinuationToken = await store.get_transcript_activities( + "Channel 1", "Conversation 1" + ) transcript = transcriptAndContinuationToken[0] assert len(transcript) == 0 + @pytest.mark.asyncio async def test_delete_one_transcript_of_two(): store = TranscriptMemoryStore() - + activity = Activity.create_message_activity() activity.text = "Activity 1" activity.channel_id = "Channel 1" - activity.conversation = ConversationAccount( id="Conversation 1" ) + activity.conversation = ConversationAccount(id="Conversation 1") activity2 = Activity.create_message_activity() activity2.text = "Activity 2" activity2.channel_id = "Channel 2" - activity2.conversation = ConversationAccount( id="Conversation 1" ) + activity2.conversation = ConversationAccount(id="Conversation 1") - # Add one activity and make sure it's there and comes back + # Add one activity and make sure it's there and comes back await store.log_activity(activity) await store.log_activity(activity2) @@ -218,28 +250,33 @@ async def test_delete_one_transcript_of_two(): await store.delete_transcript("Channel 1", "Conversation 1") # Make sure the one we deleted is gone - transcriptAndContinuationToken = await store.get_transcript_activities("Channel 1", "Conversation 1") + transcriptAndContinuationToken = await store.get_transcript_activities( + "Channel 1", "Conversation 1" + ) transcript = transcriptAndContinuationToken[0] assert len(transcript) == 0 # Make sure the other one is still there - transcriptAndContinuationToken = await store.get_transcript_activities("Channel 2", "Conversation 1") + transcriptAndContinuationToken = await store.get_transcript_activities( + "Channel 2", "Conversation 1" + ) transcript = transcriptAndContinuationToken[0] assert len(transcript) == 1 + @pytest.mark.asyncio async def test_list_transcripts(): store = TranscriptMemoryStore() - + activity = Activity.create_message_activity() activity.text = "Activity 1" activity.channel_id = "Channel 1" - activity.conversation = ConversationAccount( id="Conversation 1" ) + activity.conversation = ConversationAccount(id="Conversation 1") activity2 = Activity.create_message_activity() activity2.text = "Activity 2" activity2.channel_id = "Channel 2" - activity2.conversation = ConversationAccount( id="Conversation 1" ) + activity2.conversation = ConversationAccount(id="Conversation 1") # Make sure a list on an empty store returns an empty set transcriptAndContinuationToken = await store.list_transcripts("Should Be Empty") @@ -251,7 +288,7 @@ async def test_list_transcripts(): # Add one activity so we can go searching await store.log_activity(activity) - transcriptAndContinuationToken = await store.list_transcripts("Channel 1") + transcriptAndContinuationToken = await store.list_transcripts("Channel 1") transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] assert len(transcript) == 1 @@ -261,15 +298,15 @@ async def test_list_transcripts(): await store.log_activity(activity2) # Check again for "Transcript 1" which is on channel 1 - transcriptAndContinuationToken = await store.list_transcripts("Channel 1") + transcriptAndContinuationToken = await store.list_transcripts("Channel 1") transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] assert len(transcript) == 1 assert continuationToken is None # Check for "Transcript 2" which is on channel 2 - transcriptAndContinuationToken = await store.list_transcripts("Channel 2") + transcriptAndContinuationToken = await store.list_transcripts("Channel 2") transcript = transcriptAndContinuationToken[0] continuationToken = transcriptAndContinuationToken[1] assert len(transcript) == 1 - assert continuationToken is None \ No newline at end of file + assert continuationToken is None diff --git a/tests/hosting_core/test_turn_context.py b/tests/hosting_core/test_turn_context.py index 01305139..263b856f 100644 --- a/tests/hosting_core/test_turn_context.py +++ b/tests/hosting_core/test_turn_context.py @@ -1,3 +1,4 @@ +from annotated_types import T import pytest from typing import Callable, List