diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/channel_service_adapter.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/channel_service_adapter.py index 5eed20d0..33fd7aa0 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/channel_service_adapter.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/channel_service_adapter.py @@ -341,6 +341,19 @@ async def process_proactive( await connector_client.close() await user_token_client.close() + def _resolve_if_connector_client_is_needed(self, activity: Activity) -> bool: + """Determine if a connector client is needed based on the activity's delivery mode and service URL. + + :param activity: The activity to evaluate. + :type activity: :class:`microsoft_agents.activity.Activity` + """ + if activity.delivery_mode in [ + DeliveryModes.expect_replies, + DeliveryModes.stream, + ]: + return False + return True + async def process_activity( self, claims_identity: ClaimsIdentity, @@ -403,21 +416,24 @@ async def process_activity( context.turn_state[self.USER_TOKEN_CLIENT_KEY] = user_token_client # Create the connector client to use for outbound requests. - connector_client: ConnectorClient = ( - await self._channel_service_client_factory.create_connector_client( - context, - claims_identity, - activity.service_url, - outgoing_audience, - scopes, - use_anonymous_auth_callback, + connector_client: Optional[ConnectorClient] = None + if self._resolve_if_connector_client_is_needed(activity): + connector_client = ( + await self._channel_service_client_factory.create_connector_client( + context, + claims_identity, + activity.service_url, + outgoing_audience, + scopes, + use_anonymous_auth_callback, + ) ) - ) - context.turn_state[self._AGENT_CONNECTOR_CLIENT_KEY] = connector_client + context.turn_state[self._AGENT_CONNECTOR_CLIENT_KEY] = connector_client await self.run_pipeline(context, callback) - await connector_client.close() + if connector_client: + await connector_client.close() await user_token_client.close() # If there are any results they will have been left on the TurnContext. diff --git a/tests/hosting_core/test_channel_service_adapter.py b/tests/hosting_core/test_channel_service_adapter.py index 777f7cef..ccf32b9e 100644 --- a/tests/hosting_core/test_channel_service_adapter.py +++ b/tests/hosting_core/test_channel_service_adapter.py @@ -1,8 +1,10 @@ import pytest from microsoft_agents.activity import ( + Activity, ConversationResourceResponse, ConversationParameters, + DeliveryModes, ) from microsoft_agents.hosting.core import ( ChannelServiceAdapter, @@ -14,6 +16,7 @@ TeamsConnectorClient, UserTokenClient, Connections, + ClaimsIdentity, ) from microsoft_agents.hosting.core.connector.conversations_base import ConversationsBase @@ -94,3 +97,152 @@ async def callback(context: TurnContext): assert context_arg.activity.conversation.id == "conversation123" assert context_arg.activity.channel_id == "channel_id" assert context_arg.activity.service_url == "service_url" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "delivery_mode, service_url", + [ + [DeliveryModes.expect_replies, None], + [DeliveryModes.stream, None], + [DeliveryModes.expect_replies, "https://service.url"], + [DeliveryModes.stream, "https://service.url"], + ], + ) + async def test_process_activity_expect_replies_and_stream( + self, mocker, user_token_client, adapter, delivery_mode, service_url + ): + user_token_client.get_access_token = mocker.AsyncMock( + return_value="user_token_value" + ) + adapter.run_pipeline = mocker.AsyncMock() + + async def callback(context: TurnContext): + return None + + activity = Activity( # type: ignore + type="message", + conversation={"id": "conversation123"}, + channel_id="channel_id", + delivery_mode=delivery_mode, + ) + activity.service_url = service_url + + claims_identity = ClaimsIdentity( + { + "aud": "agent_app_id", + "ver": "2.0", + "azp": "outgoing_app_id", + }, + is_authenticated=True, + ) + + await adapter.process_activity( + claims_identity, + activity, + callback, + ) + + adapter.run_pipeline.assert_awaited_once() + + context_arg, callback_arg = adapter.run_pipeline.call_args[0] + assert callback_arg == callback + assert context_arg.activity == activity + + assert context_arg.activity.conversation.id == "conversation123" + assert context_arg.activity.channel_id == "channel_id" + assert context_arg.activity.service_url == service_url + assert ( + context_arg.turn_state[ChannelServiceAdapter.USER_TOKEN_CLIENT_KEY] + is user_token_client + ) + assert ( + ChannelServiceAdapter._AGENT_CONNECTOR_CLIENT_KEY + not in context_arg.turn_state + ) + + @pytest.mark.asyncio + async def test_process_activity_normal_no_service_url( + self, mocker, user_token_client, adapter + ): + user_token_client.get_access_token = mocker.AsyncMock( + return_value="user_token_value" + ) + adapter.run_pipeline = mocker.AsyncMock() + + async def callback(context: TurnContext): + return None + + activity = Activity( # type: ignore + type="message", + conversation={"id": "conversation123"}, + channel_id="channel_id", + ) + + claims_identity = ClaimsIdentity( + { + "aud": "agent_app_id", + "ver": "2.0", + "azp": "outgoing_app_id", + }, + is_authenticated=True, + ) + + with pytest.raises(Exception) as exc_info: + await adapter.process_activity( + claims_identity, + activity, + callback, + ) + + @pytest.mark.asyncio + async def test_process_proactive( + self, mocker, user_token_client, connector_client, adapter + ): + user_token_client.get_access_token = mocker.AsyncMock( + return_value="user_token_value" + ) + adapter.run_pipeline = mocker.AsyncMock() + + async def callback(context: TurnContext): + return None + + activity = Activity( # type: ignore + type="message", + conversation={"id": "conversation123"}, + channel_id="channel_id", + service_url="service_url", + ) + + claims_identity = ClaimsIdentity( + { + "aud": "agent_app_id", + "ver": "2.0", + "azp": "outgoing_app_id", + }, + is_authenticated=True, + ) + + await adapter.process_proactive( + claims_identity, + activity, + "audience", + callback, + ) + + adapter.run_pipeline.assert_awaited_once() + + context_arg, callback_arg = adapter.run_pipeline.call_args[0] + assert callback_arg == callback + assert context_arg.activity == activity + + assert context_arg.activity.conversation.id == "conversation123" + assert context_arg.activity.channel_id == "channel_id" + assert context_arg.activity.service_url == "service_url" + assert ( + context_arg.turn_state[ChannelServiceAdapter.USER_TOKEN_CLIENT_KEY] + is user_token_client + ) + assert ( + context_arg.turn_state[ChannelServiceAdapter._AGENT_CONNECTOR_CLIENT_KEY] + is connector_client + )