Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,19 @@ async def process_proactive(
await connector_client.close()
await user_token_client.close()

def _resolve_if_connector_client_is_needed(self, activity: Activity) -> bool:
"""Determine if a connector client is needed based on the activity's delivery mode and service URL.

:param activity: The activity to evaluate.
:type activity: :class:`microsoft_agents.activity.Activity`
"""
if activity.delivery_mode in [
DeliveryModes.expect_replies,
DeliveryModes.stream,
]:
return False
return True

async def process_activity(
self,
claims_identity: ClaimsIdentity,
Expand Down Expand Up @@ -403,21 +416,24 @@ async def process_activity(
context.turn_state[self.USER_TOKEN_CLIENT_KEY] = user_token_client

# Create the connector client to use for outbound requests.
connector_client: ConnectorClient = (
await self._channel_service_client_factory.create_connector_client(
context,
claims_identity,
activity.service_url,
outgoing_audience,
scopes,
use_anonymous_auth_callback,
connector_client: Optional[ConnectorClient] = None
if self._resolve_if_connector_client_is_needed(activity):
connector_client = (
await self._channel_service_client_factory.create_connector_client(
context,
claims_identity,
activity.service_url,
outgoing_audience,
scopes,
use_anonymous_auth_callback,
)
)
)
context.turn_state[self._AGENT_CONNECTOR_CLIENT_KEY] = connector_client
context.turn_state[self._AGENT_CONNECTOR_CLIENT_KEY] = connector_client

await self.run_pipeline(context, callback)

await connector_client.close()
if connector_client:
await connector_client.close()
await user_token_client.close()

# If there are any results they will have been left on the TurnContext.
Expand Down
152 changes: 152 additions & 0 deletions tests/hosting_core/test_channel_service_adapter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest

from microsoft_agents.activity import (
Activity,
ConversationResourceResponse,
ConversationParameters,
DeliveryModes,
)
from microsoft_agents.hosting.core import (
ChannelServiceAdapter,
Expand All @@ -14,6 +16,7 @@
TeamsConnectorClient,
UserTokenClient,
Connections,
ClaimsIdentity,
)

from microsoft_agents.hosting.core.connector.conversations_base import ConversationsBase
Expand Down Expand Up @@ -94,3 +97,152 @@ async def callback(context: TurnContext):
assert context_arg.activity.conversation.id == "conversation123"
assert context_arg.activity.channel_id == "channel_id"
assert context_arg.activity.service_url == "service_url"

@pytest.mark.asyncio
@pytest.mark.parametrize(
"delivery_mode, service_url",
[
[DeliveryModes.expect_replies, None],
[DeliveryModes.stream, None],
[DeliveryModes.expect_replies, "https://service.url"],
[DeliveryModes.stream, "https://service.url"],
],
)
async def test_process_activity_expect_replies_and_stream(
self, mocker, user_token_client, adapter, delivery_mode, service_url
):
user_token_client.get_access_token = mocker.AsyncMock(
return_value="user_token_value"
)
adapter.run_pipeline = mocker.AsyncMock()

async def callback(context: TurnContext):
return None

activity = Activity( # type: ignore
type="message",
conversation={"id": "conversation123"},
channel_id="channel_id",
delivery_mode=delivery_mode,
)
activity.service_url = service_url

claims_identity = ClaimsIdentity(
{
"aud": "agent_app_id",
"ver": "2.0",
"azp": "outgoing_app_id",
},
is_authenticated=True,
)

await adapter.process_activity(
claims_identity,
activity,
callback,
)

adapter.run_pipeline.assert_awaited_once()

context_arg, callback_arg = adapter.run_pipeline.call_args[0]
assert callback_arg == callback
assert context_arg.activity == activity

assert context_arg.activity.conversation.id == "conversation123"
assert context_arg.activity.channel_id == "channel_id"
assert context_arg.activity.service_url == service_url
assert (
context_arg.turn_state[ChannelServiceAdapter.USER_TOKEN_CLIENT_KEY]
is user_token_client
)
assert (
ChannelServiceAdapter._AGENT_CONNECTOR_CLIENT_KEY
not in context_arg.turn_state
)

@pytest.mark.asyncio
async def test_process_activity_normal_no_service_url(
self, mocker, user_token_client, adapter
):
user_token_client.get_access_token = mocker.AsyncMock(
return_value="user_token_value"
)
adapter.run_pipeline = mocker.AsyncMock()

async def callback(context: TurnContext):
return None

activity = Activity( # type: ignore
type="message",
conversation={"id": "conversation123"},
channel_id="channel_id",
)

claims_identity = ClaimsIdentity(
{
"aud": "agent_app_id",
"ver": "2.0",
"azp": "outgoing_app_id",
},
is_authenticated=True,
)

with pytest.raises(Exception) as exc_info:
await adapter.process_activity(
claims_identity,
activity,
callback,
)

@pytest.mark.asyncio
async def test_process_proactive(
self, mocker, user_token_client, connector_client, adapter
):
user_token_client.get_access_token = mocker.AsyncMock(
return_value="user_token_value"
)
adapter.run_pipeline = mocker.AsyncMock()

async def callback(context: TurnContext):
return None

activity = Activity( # type: ignore
type="message",
conversation={"id": "conversation123"},
channel_id="channel_id",
service_url="service_url",
)

claims_identity = ClaimsIdentity(
{
"aud": "agent_app_id",
"ver": "2.0",
"azp": "outgoing_app_id",
},
is_authenticated=True,
)

await adapter.process_proactive(
claims_identity,
activity,
"audience",
callback,
)

adapter.run_pipeline.assert_awaited_once()

context_arg, callback_arg = adapter.run_pipeline.call_args[0]
assert callback_arg == callback
assert context_arg.activity == activity

assert context_arg.activity.conversation.id == "conversation123"
assert context_arg.activity.channel_id == "channel_id"
assert context_arg.activity.service_url == "service_url"
assert (
context_arg.turn_state[ChannelServiceAdapter.USER_TOKEN_CLIENT_KEY]
is user_token_client
)
assert (
context_arg.turn_state[ChannelServiceAdapter._AGENT_CONNECTOR_CLIENT_KEY]
is connector_client
)