Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -262,28 +262,12 @@ async def create_conversation( # pylint: disable=arguments-differ
claims_identity = self.create_claims_identity(agent_app_id)
claims_identity.claims[AuthenticationConstants.SERVICE_URL_CLAIM] = service_url

# Create a turn context and run the pipeline.
context = self._create_turn_context(
claims_identity,
None,
callback,
)

# Create a UserTokenClient instance for the application to use. (For example, in the OAuthPrompt.)
user_token_client: UserTokenClient = (
await self._channel_service_client_factory.create_user_token_client(
context, claims_identity
)
)
context.turn_state[self.USER_TOKEN_CLIENT_KEY] = user_token_client

# Create the connector client to use for outbound requests.
connector_client: ConnectorClient = (
await self._channel_service_client_factory.create_connector_client(
context, claims_identity, service_url, audience
None, claims_identity, service_url, audience
)
)
context.turn_state[self._AGENT_CONNECTOR_CLIENT_KEY] = connector_client

# Make the actual create conversation call using the connector.
create_conversation_result = (
Expand All @@ -297,7 +281,22 @@ async def create_conversation( # pylint: disable=arguments-differ
create_conversation_result, channel_id, service_url, conversation_parameters
)

context.activity = create_activity
# Create a turn context and run the pipeline.
context = self._create_turn_context(
claims_identity,
None,
callback,
create_activity,
)
context.turn_state[self._AGENT_CONNECTOR_CLIENT_KEY] = connector_client

# Create a UserTokenClient instance for the application to use. (For example, in the OAuthPrompt.)
user_token_client: UserTokenClient = (
await self._channel_service_client_factory.create_user_token_client(
context, claims_identity
)
)
context.turn_state[self.USER_TOKEN_CLIENT_KEY] = user_token_client

# Run the pipeline
await self.run_pipeline(context, callback)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class ChannelServiceClientFactoryBase(Protocol):
@abstractmethod
async def create_connector_client(
self,
context: TurnContext,
context: TurnContext | None,
claims_identity: ClaimsIdentity,
service_url: str,
audience: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ async def _get_agentic_token(self, context: TurnContext, service_url: str) -> st

async def create_connector_client(
self,
context: TurnContext,
context: TurnContext | None,
claims_identity: ClaimsIdentity,
service_url: str,
audience: str,
scopes: Optional[list[str]] = None,
use_anonymous: bool = False,
) -> ConnectorClientBase:
if not context or not claims_identity:
raise TypeError("context and claims_identity are required")
if not claims_identity:
raise TypeError("claims_identity is required")
if not service_url:
raise TypeError(
"RestChannelServiceClientFactory.create_connector_client: service_url can't be None or Empty"
Expand All @@ -101,7 +101,7 @@ async def create_connector_client(
"RestChannelServiceClientFactory.create_connector_client: audience can't be None or Empty"
)

if context.activity.is_agentic_request():
if context and context.activity.is_agentic_request():
token = await self._get_agentic_token(context, service_url)
else:
token_provider: AccessTokenProviderBase = (
Expand Down
1 change: 1 addition & 0 deletions tests/_common/data/default_test_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self):
self.user_id = "__user_id"
self.bot_url = "https://botframework.com"
self.ms_app_id = "__ms_app_id"
self.service_url = "https://service.url/"

# Auth Handler Settings
self.abs_oauth_connection_name = "connection_name"
Expand Down
96 changes: 96 additions & 0 deletions tests/hosting_core/test_channel_service_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import pytest

from microsoft_agents.activity import (
ConversationResourceResponse,
ConversationParameters,
)
from microsoft_agents.hosting.core import (
ChannelServiceAdapter,
TurnContext,
ConnectorClientBase,
UserTokenClientBase,
ChannelServiceClientFactoryBase,
RestChannelServiceClientFactory,
TeamsConnectorClient,
UserTokenClient,
Connections,
)

from microsoft_agents.hosting.core.connector.conversations_base import ConversationsBase


class MyChannelServiceAdapter(ChannelServiceAdapter):
pass


class TestChannelServiceAdapter:

@pytest.fixture
def connector_client(self, mocker):
connector_client = mocker.Mock(spec=TeamsConnectorClient)
mocker.patch.object(
TeamsConnectorClient, "__new__", return_value=connector_client
)
return connector_client

@pytest.fixture
def user_token_client(self, mocker):
user_token_client = mocker.Mock(spec=UserTokenClient)
mocker.patch.object(UserTokenClient, "__new__", return_value=user_token_client)
return user_token_client

@pytest.fixture
def connection_manager(self, mocker, user_token_client):
connection_manager = mocker.Mock(spec=Connections)
connection_manager.get_token_provider = mocker.Mock(
return_value=user_token_client
)
return connection_manager

@pytest.fixture
def factory(self, connection_manager):
client_factory = RestChannelServiceClientFactory(connection_manager)
return client_factory

@pytest.fixture
def adapter(self, factory):
return MyChannelServiceAdapter(factory)

@pytest.mark.asyncio
async def test_create_conversation_basic(
self, mocker, user_token_client, connector_client, adapter
):

user_token_client.get_access_token = mocker.AsyncMock(
return_value="user_token_value"
)
adapter.run_pipeline = mocker.AsyncMock()

connector_client.conversations = mocker.Mock(spec=ConversationsBase)
connector_client.conversations.create_conversation.return_value = (
ConversationResourceResponse(
activity_id="activity123",
service_url="https://service.url",
id="conversation123",
)
)

async def callback(context: TurnContext):
return None

await adapter.create_conversation(
"agent_app_id",
"channel_id",
"service_url",
"audience",
ConversationParameters(),
callback,
)

adapter.run_pipeline.assert_awaited_once()

context_arg, callback_arg = adapter.run_pipeline.call_args[0]
assert callback_arg == callback
assert context_arg.activity.conversation.id == "conversation123"
assert context_arg.activity.channel_id == "channel_id"
assert context_arg.activity.service_url == "service_url"
Loading