diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/__init__.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/__init__.py index 81bc1c83..b18336b7 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/__init__.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/__init__.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + from .agents_model import AgentsModel from .action_types import ActionTypes from .activity import Activity @@ -17,6 +20,8 @@ from .card_image import CardImage from .channels import Channels from .channel_account import ChannelAccount +from ._channel_id_field_mixin import _ChannelIdFieldMixin +from .channel_id import ChannelId from .conversation_account import ConversationAccount from .conversation_members import ConversationMembers from .conversation_parameters import ConversationParameters @@ -26,6 +31,7 @@ from .expected_replies import ExpectedReplies from .entity import ( Entity, + EntityTypes, AIEntity, ClientCitation, ClientCitationAppearance, @@ -36,6 +42,7 @@ SensitivityPattern, GeoCoordinates, Place, + ProductInfo, Thing, ) from .error import Error @@ -115,6 +122,8 @@ "CardImage", "Channels", "ChannelAccount", + "ChannelId", + "_ChannelIdFieldMixin", "ConversationAccount", "ConversationMembers", "ConversationParameters", @@ -145,6 +154,7 @@ "OAuthCard", "PagedMembersResult", "Place", + "ProductInfo", "ReceiptCard", "ReceiptItem", "ResourceResponse", diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/_channel_id_field_mixin.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/_channel_id_field_mixin.py new file mode 100644 index 00000000..77f60c02 --- /dev/null +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/_channel_id_field_mixin.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from __future__ import annotations + +import logging +from typing import Optional, Any + +from pydantic import ( + ModelWrapValidatorHandler, + SerializerFunctionWrapHandler, + computed_field, + model_validator, + model_serializer, +) + +from .channel_id import ChannelId + +logger = logging.getLogger(__name__) + + +# can be generalized in the future, if needed +class _ChannelIdFieldMixin: + """A mixin to add a computed field channel_id of type ChannelId to a Pydantic model.""" + + _channel_id: Optional[ChannelId] = None + + # required to define the setter below + @computed_field(return_type=Optional[ChannelId], alias="channelId") + @property + def channel_id(self) -> Optional[ChannelId]: + """Gets the _channel_id field""" + return self._channel_id + + # necessary for backward compatibility + # previously, channel_id was directly assigned with strings + @channel_id.setter + def channel_id(self, value: Any): + """Sets the channel_id after validating it as a ChannelId model.""" + if isinstance(value, ChannelId): + self._channel_id = value + elif isinstance(value, str): + self._channel_id = ChannelId(value) + else: + raise ValueError( + f"Invalid type for channel_id: {type(value)}. " + "Expected ChannelId or str." + ) + + def _set_validated_channel_id(self, data: Any) -> None: + """Sets the channel_id after validating it as a ChannelId model.""" + if "channelId" in data: + self.channel_id = data["channelId"] + elif "channel_id" in data: + self.channel_id = data["channel_id"] + + @model_validator(mode="wrap") + @classmethod + def _validate_channel_id( + cls, data: Any, handler: ModelWrapValidatorHandler + ) -> _ChannelIdFieldMixin: + """Validate the _channel_id field after model initialization. + + :return: The model instance itself. + """ + try: + model = handler(data) + model._set_validated_channel_id(data) + return model + except Exception: + logging.error("Model %s failed to validate with data %s", cls, data) + raise + + def _remove_serialized_unset_channel_id( + self, serialized: dict[str, object] + ) -> None: + """Remove the _channel_id field if it is not set.""" + if not self._channel_id: + if "channelId" in serialized: + del serialized["channelId"] + elif "channel_id" in serialized: + del serialized["channel_id"] + + @model_serializer(mode="wrap") + def _serialize_channel_id( + self, handler: SerializerFunctionWrapHandler + ) -> dict[str, object]: + """Serialize the model using Pydantic's standard serialization. + + :param handler: The serialization handler provided by Pydantic. + :return: A dictionary representing the serialized model. + """ + serialized = handler(self) + if self: # serialization can be called with None + self._remove_serialized_unset_channel_id(serialized) + return serialized diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/activity.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/activity.py index b3189742..6b610579 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/activity.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/activity.py @@ -1,10 +1,24 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations + +import logging from copy import copy from datetime import datetime, timezone -from typing import Optional -from pydantic import Field, SerializeAsAny +from typing import Optional, Any + +from pydantic import ( + Field, + SerializeAsAny, + model_serializer, + model_validator, + SerializerFunctionWrapHandler, + ModelWrapValidatorHandler, + computed_field, + ValidationError, +) + from .activity_types import ActivityTypes from .channel_account import ChannelAccount from .conversation_account import ConversationAccount @@ -14,9 +28,11 @@ from .attachment import Attachment from .entity import ( Entity, + EntityTypes, Mention, AIEntity, ClientCitation, + ProductInfo, SensitivityUsageInfo, ) from .conversation_reference import ConversationReference @@ -24,12 +40,16 @@ from .semantic_action import SemanticAction from .agents_model import AgentsModel from .role_types import RoleTypes +from ._channel_id_field_mixin import _ChannelIdFieldMixin +from .channel_id import ChannelId from ._model_utils import pick_model, SkipNone from ._type_aliases import NonEmptyString +logger = logging.getLogger(__name__) + # TODO: A2A Agent 2 is responding with None as id, had to mark it as optional (investigate) -class Activity(AgentsModel): +class Activity(AgentsModel, _ChannelIdFieldMixin): """An Activity is the basic communication type for the protocol. :param type: Contains the activity type. Possible values include: @@ -50,8 +70,8 @@ class Activity(AgentsModel): :type local_timezone: str :param service_url: Contains the URL that specifies the channel's service endpoint. Set by the channel. :type service_url: str - :param channel_id: Contains an ID that uniquely identifies the channel. Set by the channel. - :type channel_id: str + :param channel_id: Contains an ID that uniquely identifies the channel (and possibly the sub-channel). Set by the channel. + :type channel_id: ~microsoft_agents.activity.ChannelId :param from_property: Identifies the sender of the message. :type from_property: ~microsoft_agents.activity.ChannelAccount :param conversation: Identifies the conversation to which the activity belongs. @@ -136,7 +156,6 @@ class Activity(AgentsModel): local_timestamp: datetime = None local_timezone: NonEmptyString = None service_url: NonEmptyString = None - channel_id: NonEmptyString = None from_property: ChannelAccount = Field(None, alias="from") conversation: ConversationAccount = None recipient: ChannelAccount = None @@ -173,6 +192,92 @@ class Activity(AgentsModel): semantic_action: SemanticAction = None caller_id: NonEmptyString = None + @model_validator(mode="wrap") + @classmethod + def _validate_channel_id( + cls, data: Any, handler: ModelWrapValidatorHandler[Activity] + ) -> Activity: + """Validate the Activity, ensuring consistency between channel_id.sub_channel and productInfo entity. + + :param data: The input data to validate. + :param handler: The validation handler provided by Pydantic. + :return: The validated Activity instance. + """ + try: + # run Pydantic's standard validation first + activity = handler(data) + + # needed to assign to a computed field + # needed because we override the mixin validator + activity._set_validated_channel_id(data) + + # sync sub_channel with productInfo entity + product_info = activity.get_product_info_entity() + if product_info and activity.channel_id: + if ( + activity.channel_id.sub_channel + and activity.channel_id.sub_channel != product_info.id + ): + raise Exception( + "Conflict between channel_id.sub_channel and productInfo entity" + ) + activity.channel_id = ChannelId( + channel=activity.channel_id.channel, + sub_channel=product_info.id, + ) + + return activity + except ValidationError as exc: + logger.error("Validation error for Activity: %s", exc, exc_info=True) + raise + + @model_serializer(mode="wrap") + def _serialize_sub_channel_data( + self, handler: SerializerFunctionWrapHandler + ) -> dict[str, object]: + """Serialize the Activity, ensuring consistency between channel_id.sub_channel and productInfo entity. + + :param handler: The serialization handler provided by Pydantic. + :return: A dictionary representing the serialized Activity. + """ + + # run Pydantic's standard serialization first + serialized = handler(self) + if not self: # serialization can be called with None + return serialized + + # find the ProductInfo entity + product_info = None + for i, entity in enumerate(serialized.get("entities") or []): + if entity.get("type", "") == EntityTypes.PRODUCT_INFO: + product_info = entity + break + + # maintain consistency between ProductInfo entity and sub channel + if self.channel_id and self.channel_id.sub_channel: + if product_info and product_info.get("id") != self.channel_id.sub_channel: + raise Exception( + "Conflict between channel_id.sub_channel and productInfo entity" + ) + elif not product_info: + if not serialized.get("entities"): + serialized["entities"] = [] + serialized["entities"].append( + { + "type": EntityTypes.PRODUCT_INFO, + "id": self.channel_id.sub_channel, + } + ) + elif product_info: # remove productInfo entity if sub_channel is not set + del serialized["entities"][i] + if not serialized["entities"]: # after removal above, list may be empty + del serialized["entities"] + + # necessary due to computed_field serialization + self._remove_serialized_unset_channel_id(serialized) + + return serialized + def apply_conversation_reference( self, reference: ConversationReference, is_incoming: bool = False ): @@ -531,6 +636,14 @@ def get_conversation_reference(self) -> ConversationReference: service_url=self.service_url, ) + def get_product_info_entity(self) -> Optional[ProductInfo]: + if not self.entities: + return None + target = EntityTypes.PRODUCT_INFO.lower() + # validated entities can be Entity, and that prevents us from + # making assumptions about the casing of the 'type' attribute + return next(filter(lambda e: e.type.lower() == target, self.entities), None) + def get_mentions(self) -> list[Mention]: """ Resolves the mentions from the entities of this activity. @@ -543,7 +656,7 @@ def get_mentions(self) -> list[Mention]: """ if not self.entities: return [] - return [x for x in self.entities if x.type.lower() == "mention"] + return [x for x in self.entities if x.type.lower() == EntityTypes.MENTION] def get_reply_conversation_reference( self, reply: ResourceResponse diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/channel_id.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/channel_id.py new file mode 100644 index 00000000..ad4890fe --- /dev/null +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/channel_id.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Optional, Any + +from pydantic_core import CoreSchema, core_schema +from pydantic import GetCoreSchemaHandler + + +class ChannelId(str): + """A ChannelId represents a channel and optional sub-channel in the format 'channel:sub_channel'.""" + + def __init__( + self, + value: Optional[str] = None, + *, + channel: Optional[str] = None, + sub_channel: Optional[str] = None, + ) -> None: + """Initialize a ChannelId instance. + + :param value: The full channel ID string in the format 'channel:sub_channel'. Must be provided if channel is not provided. + :param channel: The main channel string. Must be provided if value is not provided. + :param sub_channel: The sub-channel string. + :raises ValueError: If the input parameters are invalid. value and channel cannot both be provided. + """ + super().__init__() + if not channel: + split = self.strip().split(":", 1) + self._channel = split[0].strip() + self._sub_channel = split[1].strip() if len(split) == 2 else None + else: + self._channel = channel + self._sub_channel = sub_channel + + def __new__( + cls, + value: Optional[str] = None, + *, + channel: Optional[str] = None, + sub_channel: Optional[str] = None, + ) -> ChannelId: + """Create a new ChannelId instance. + + :param value: The full channel ID string in the format 'channel:sub_channel'. Must be provided if channel is not provided. + :param channel: The main channel string. Must be provided if value is not provided. Must not contain ':', as it delimits channels and sub channels. + :param sub_channel: The sub-channel string. + :return: A new ChannelId instance. + :raises ValueError: If the input parameters are invalid. value and channel cannot both be provided. + """ + if isinstance(value, str): + if channel or sub_channel: + raise ValueError( + "If value is provided, channel and sub_channel must be None" + ) + + value = value.strip() + if value: + return str.__new__(cls, value) + raise TypeError("value must be a non empty string if provided") + else: + if ( + not isinstance(channel, str) + or len(channel.strip()) == 0 + or ":" in channel + ): + raise TypeError( + "channel must be a non empty string, and must not contain the ':' character" + ) + if sub_channel is not None and (not isinstance(sub_channel, str)): + raise TypeError("sub_channel must be a string if provided") + channel = channel.strip() + sub_channel = sub_channel.strip() if sub_channel else None + if sub_channel: + return str.__new__(cls, f"{channel}:{sub_channel}") + return str.__new__(cls, channel) + + @property + def channel(self) -> str: + """The main channel, e.g. 'email' in 'email:work'.""" + return self._channel # type: ignore[return-value] + + @property + def sub_channel(self) -> Optional[str]: + """The sub-channel, e.g. 'work' in 'email:work'. May be None.""" + return self._sub_channel + + # https://docs.pydantic.dev/dev/concepts/types/#customizing-validation-with-__get_pydantic_core_schema__ + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + return core_schema.no_info_after_validator_function(cls, handler(str)) diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/conversation_reference.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/conversation_reference.py index 0b680ea3..4ec1b4a8 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/conversation_reference.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/conversation_reference.py @@ -1,20 +1,27 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations + from uuid import uuid4 as uuid from typing import Optional +import logging from pydantic import Field from .channel_account import ChannelAccount +from ._channel_id_field_mixin import _ChannelIdFieldMixin +from .channel_id import ChannelId from .conversation_account import ConversationAccount from .agents_model import AgentsModel from ._type_aliases import NonEmptyString from .activity_types import ActivityTypes from .activity_event_names import ActivityEventNames +logger = logging.getLogger(__name__) + -class ConversationReference(AgentsModel): +class ConversationReference(AgentsModel, _ChannelIdFieldMixin): """An object relating to a particular point in a conversation. :param activity_id: (Optional) ID of the activity to refer to @@ -26,7 +33,7 @@ class ConversationReference(AgentsModel): :param conversation: Conversation reference :type conversation: ~microsoft_agents.activity.ConversationAccount :param channel_id: Channel ID - :type channel_id: str + :type channel_id: ~microsoft_agents.activity.ChannelId :param locale: A locale name for the contents of the text field. The locale name is a combination of an ISO 639 two- or three-letter culture code associated with a language and an ISO 3166 two-letter @@ -43,7 +50,6 @@ class ConversationReference(AgentsModel): user: Optional[ChannelAccount] = None agent: ChannelAccount = Field(None, alias="bot") conversation: ConversationAccount - channel_id: NonEmptyString locale: Optional[NonEmptyString] = None service_url: NonEmptyString = None diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/__init__.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/__init__.py index b35460d8..42fe69fd 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/__init__.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/__init__.py @@ -1,5 +1,9 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + from .mention import Mention from .entity import Entity +from .entity_types import EntityTypes from .ai_entity import ( ClientCitation, ClientCitationAppearance, @@ -11,10 +15,12 @@ ) from .geo_coordinates import GeoCoordinates from .place import Place +from .product_info import ProductInfo from .thing import Thing __all__ = [ "Entity", + "EntityTypes", "AIEntity", "ClientCitation", "ClientCitationAppearance", @@ -25,5 +31,6 @@ "SensitivityPattern", "GeoCoordinates", "Place", + "ProductInfo", "Thing", ] diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/entity.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/entity.py index e7352609..74b35142 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/entity.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/entity.py @@ -1,14 +1,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from typing import Any, Optional -from enum import Enum +from typing import Any from pydantic import model_serializer, model_validator from pydantic.alias_generators import to_camel, to_snake from ..agents_model import AgentsModel, ConfigDict -from .._type_aliases import NonEmptyString class Entity(AgentsModel): diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/entity_types.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/entity_types.py new file mode 100644 index 00000000..4af74397 --- /dev/null +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/entity_types.py @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from enum import Enum + + +class EntityTypes(str, Enum): + """Well-known enumeration of entity types.""" + + GEO_COORDINATES = "GeoCoordinates" + MENTION = "mention" + PLACE = "Place" + THING = "Thing" + PRODUCT_INFO = "ProductInfo" diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/geo_coordinates.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/geo_coordinates.py index 15a9e0a5..1f758a5c 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/geo_coordinates.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/geo_coordinates.py @@ -1,11 +1,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from ..agents_model import AgentsModel +from typing import Literal + from .._type_aliases import NonEmptyString +from .entity import Entity +from .entity_types import EntityTypes -class GeoCoordinates(AgentsModel): +class GeoCoordinates(Entity): """GeoCoordinates (entity type: "https://schema.org/GeoCoordinates"). :param elevation: Elevation of the location [WGS @@ -26,5 +29,5 @@ class GeoCoordinates(AgentsModel): elevation: float = None latitude: float = None longitude: float = None - type: NonEmptyString = None + type: Literal[EntityTypes.GEO_COORDINATES] = EntityTypes.GEO_COORDINATES name: NonEmptyString = None diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/mention.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/mention.py index ce8b5084..1223c045 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/mention.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/mention.py @@ -5,7 +5,7 @@ from ..channel_account import ChannelAccount from .entity import Entity -from .._type_aliases import NonEmptyString +from .entity_types import EntityTypes class Mention(Entity): @@ -21,4 +21,4 @@ class Mention(Entity): mentioned: ChannelAccount = None text: str = None - type: Literal["mention"] = "mention" + type: Literal[EntityTypes.MENTION] = EntityTypes.MENTION diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/place.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/place.py index ea0b347f..fa179d91 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/place.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/place.py @@ -1,11 +1,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from ..agents_model import AgentsModel +from typing import Literal + from .._type_aliases import NonEmptyString +from .entity import Entity +from .entity_types import EntityTypes -class Place(AgentsModel): +class Place(Entity): """Place (entity type: "https://schema.org/Place"). :param address: Address of the place (may be `string` or complex object of @@ -26,5 +29,5 @@ class Place(AgentsModel): address: object = None geo: object = None has_map: object = None - type: NonEmptyString = None + type: Literal[EntityTypes.PLACE] = EntityTypes.PLACE name: NonEmptyString = None diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/product_info.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/product_info.py new file mode 100644 index 00000000..17bbc091 --- /dev/null +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/product_info.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from typing import Literal + +from .entity import Entity +from .entity_types import EntityTypes + + +class ProductInfo(Entity): + """Product information (entity type: "productInfo"). + + :param type: The type of the entity, always "productInfo". + :type type: str + :param id: The unique identifier for the product. + :type id: str + """ + + type: Literal[EntityTypes.PRODUCT_INFO] = EntityTypes.PRODUCT_INFO + id: str = None diff --git a/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/thing.py b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/thing.py index 040b850a..73d28de4 100644 --- a/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/thing.py +++ b/libraries/microsoft-agents-activity/microsoft_agents/activity/entity/thing.py @@ -1,11 +1,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from ..agents_model import AgentsModel +from typing import Literal + from .._type_aliases import NonEmptyString +from .entity import Entity +from .entity_types import EntityTypes -class Thing(AgentsModel): +class Thing(Entity): """Thing (entity type: "https://schema.org/Thing"). :param type: The type of the thing @@ -14,5 +17,5 @@ class Thing(AgentsModel): :type name: str """ - type: NonEmptyString = None + type: Literal[EntityTypes.THING] = EntityTypes.THING name: NonEmptyString = None diff --git a/libraries/microsoft-agents-activity/pyproject.toml b/libraries/microsoft-agents-activity/pyproject.toml index e97bea23..0c6938ab 100644 --- a/libraries/microsoft-agents-activity/pyproject.toml +++ b/libraries/microsoft-agents-activity/pyproject.toml @@ -16,7 +16,7 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "pydantic>=2.10.4", + "pydantic>=2.10.4" ] [project.urls] diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/turn_context.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/turn_context.py index f39e8428..a3320b4b 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/turn_context.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/turn_context.py @@ -18,6 +18,7 @@ ResourceResponse, DeliveryModes, ) +from microsoft_agents.activity.entity.entity_types import EntityTypes from microsoft_agents.hosting.core.authorization.claims_identity import ClaimsIdentity @@ -428,7 +429,7 @@ def get_mentions(activity: Activity) -> list[Mention]: result: list[Mention] = [] if activity.entities is not None: for entity in activity.entities: - if entity.type.lower() == "mention": + if entity.type.lower() == EntityTypes.MENTION: result.append(entity) return result diff --git a/tests/activity/pydantic/__init__.py b/tests/activity/pydantic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/activity/pydantic/test_activity_io.py b/tests/activity/pydantic/test_activity_io.py new file mode 100644 index 00000000..d60b11c3 --- /dev/null +++ b/tests/activity/pydantic/test_activity_io.py @@ -0,0 +1,260 @@ +import pytest + +from pydantic import ValidationError + +from microsoft_agents.activity import ( + Activity, + ChannelId, + Entity, + EntityTypes, + ProductInfo, + ConversationReference, + ConversationAccount, +) + + +# validation / serialization tests +class TestActivityIO: + + def test_serialize_basic(self): + activity = Activity(type="message") + activity_copy = Activity( + **activity.model_dump(mode="json", exclude_unset=True, by_alias=True) + ) + assert activity_copy == activity + + @pytest.mark.parametrize( + "data, expected", + [ + ( + "msteams:subchannel", + ChannelId(channel="msteams", sub_channel="subchannel"), + ), + ("msteams/subchannel", ChannelId(channel="msteams/subchannel")), + ("channel:sub", ChannelId(channel="channel", sub_channel="sub")), + ( + ChannelId(channel="msteams", sub_channel="subchannel"), + ChannelId(channel="msteams", sub_channel="subchannel"), + ), + (ChannelId(channel="msteams"), ChannelId(channel="msteams")), + ], + ) + def test_channel_id_setter_validation(self, data, expected): + activity = Activity(type="message") + activity.channel_id = data + + assert activity.channel_id == expected + assert isinstance(activity.channel_id, ChannelId) + if not isinstance(data, dict): + assert activity.channel_id == data + + def test_channel_id_setter_validation_error(self): + activity = Activity(type="message") + with pytest.raises(Exception): + activity.channel_id = {} + with pytest.raises(Exception): + activity.channel_id = 123 + + def test_channel_id_validate_without_product_info(self): + data = {"type": "message", "channel_id": "msteams:subchannel"} + activity = Activity(**data) + assert activity.channel_id == ChannelId( + channel="msteams", sub_channel="subchannel" + ) + assert not activity.get_product_info_entity() + + @pytest.mark.parametrize( + "data, data_with_alias, expected", + [ + [ + { + "type": "message", + "channel_id": "parent:misc", + "entities": [{"type": "some_entity"}], + }, + { + "type": "message", + "channelId": "parent:misc", + "entities": [{"type": "some_entity"}], + }, + Activity( + type="message", + channel_id="parent:misc", + entities=[Entity(type="some_entity")], + ), + ], + [ + { + "type": "message", + "channel_id": "parent", + "entities": [ + {"type": "some_entity"}, + {"type": EntityTypes.PRODUCT_INFO, "id": "misc"}, + ], + }, + { + "type": "message", + "channelId": "parent", + "entities": [ + {"type": "some_entity"}, + {"type": EntityTypes.PRODUCT_INFO, "id": "misc"}, + ], + }, + Activity( + type="message", + channel_id="parent:misc", + entities=[ + Entity(type="some_entity"), + Entity(type=EntityTypes.PRODUCT_INFO, id="misc"), + ], + ), + ], + ], + ) + def test_channel_id_sub_channel_changed_with_product_info( + self, data, data_with_alias, expected + ): + activity = Activity(**data) + activity_from_alias = Activity(**data_with_alias) + assert activity == expected + assert activity_from_alias == expected + assert activity.model_copy() == activity_from_alias.model_copy() + + def test_channel_id_sub_channel_conflict_on_validation(self): + with pytest.raises(Exception): + activity = Activity( + type="message", + channel_id="parent:misc", + entities=[Entity(type="some_type"), ProductInfo(id="sub_channel")], + ) + + def test_channel_id_unset_becomes_set_at_init(self): + activity = Activity(type="message") + activity.channel_id = "channel:sub_channel" + data = activity.model_dump(mode="json", exclude_unset=True, by_alias=True) + assert data["channelId"] == "channel:sub_channel" + + def test_channel_id_unset_at_init_not_included(self): + activity = Activity(type="message") + data = activity.model_dump(mode="json", exclude_unset=True, by_alias=True) + assert "channelId" not in data + + def test_product_info_avoids_error_no_parent_channel(self): + activity = Activity(type="message", entities=[ProductInfo(id="sub_channel")]) + assert activity.channel_id is None + + @pytest.mark.parametrize( + "activity, expected, expected_no_alias", + [ + [Activity(type="message"), {"type": "message"}, {"type": "message"}], + [ + Activity(type="message", channel_id="msteams"), + {"type": "message", "channelId": "msteams"}, + {"type": "message", "channel_id": "msteams"}, + ], + [ + Activity(type="message", channel_id="msteams:subchannel"), + { + "type": "message", + "channelId": "msteams:subchannel", + "entities": [ + {"type": EntityTypes.PRODUCT_INFO.value, "id": "subchannel"} + ], + }, + { + "type": "message", + "channel_id": "msteams:subchannel", + "entities": [ + {"type": EntityTypes.PRODUCT_INFO.value, "id": "subchannel"} + ], + }, + ], + [ + Activity( + type="message", + channel_id="msteams:subchannel", + entities=[Entity(type="other")], + ), + { + "type": "message", + "channelId": "msteams:subchannel", + "entities": [ + {"type": "other"}, + {"type": EntityTypes.PRODUCT_INFO.value, "id": "subchannel"}, + ], + }, + { + "type": "message", + "channel_id": "msteams:subchannel", + "entities": [ + {"type": "other"}, + {"type": EntityTypes.PRODUCT_INFO.value, "id": "subchannel"}, + ], + }, + ], + [ + Activity( + type="message", + channel_id="msteams:misc", + entities=[{"type": "other"}, ProductInfo(id="misc")], + ), + { + "type": "message", + "channelId": "msteams:misc", + "entities": [ + {"type": "other"}, + {"type": EntityTypes.PRODUCT_INFO.value, "id": "misc"}, + ], + }, + { + "type": "message", + "channel_id": "msteams:misc", + "entities": [ + {"type": "other"}, + {"type": EntityTypes.PRODUCT_INFO.value, "id": "misc"}, + ], + }, + ], + [ + Activity(type="message", entities=[ProductInfo(id="misc")]), + {"type": "message"}, + {"type": "message"}, + ], + ], + ) + def test_serialize(self, activity, expected, expected_no_alias): + data = activity.model_dump(mode="json", exclude_unset=True, by_alias=True) + data_no_alias = activity.model_dump(exclude_unset=True, by_alias=False) + assert data == expected + assert data_no_alias == expected_no_alias + + def test_model_dump(self): + activity = Activity(type="message") + data = activity.model_dump(exclude_unset=True) + assert data == {"type": "message"} + + def test_serialize_misconfiguration_no_sub_channel(self): + activity = Activity( + type="message", channel_id="msteams", entities=[{"type": "other"}] + ) + activity.entities.append(ProductInfo(id="sub_channel")) + + data = activity.model_dump(mode="json", exclude_unset=True, by_alias=True) + assert data == { + "type": "message", + "channelId": "msteams", + "entities": [ + {"type": "other"}, + ], + } + + def test_serialize_sub_channel_conflict(self): + activity = Activity( + type="message", + channel_id="msteams:subchannel", + entities=[{"type": "other"}], + ) + activity.entities.append(ProductInfo(id="other_sub_channel")) + + with pytest.raises(Exception): + activity.model_dump(mode="json", exclude_unset=True, by_alias=True) diff --git a/tests/activity/pydantic/test_channel_id_field_mixin.py b/tests/activity/pydantic/test_channel_id_field_mixin.py new file mode 100644 index 00000000..ab51b056 --- /dev/null +++ b/tests/activity/pydantic/test_channel_id_field_mixin.py @@ -0,0 +1,82 @@ +import pytest + +from typing import Optional +from pydantic import BaseModel, ValidationError + +from microsoft_agents.activity import ChannelId, _ChannelIdFieldMixin + + +class DummyModel(BaseModel, _ChannelIdFieldMixin): ... + + +def channel_id_eq(a: Optional[ChannelId], b: Optional[ChannelId]) -> bool: + return a.channel == b.channel and a.sub_channel == b.sub_channel and a == b + + +class TestChannelIdFieldMixin: + + def test_validation_basic(self): + model = DummyModel(channel_id="email:support") + assert channel_id_eq(model.channel_id, ChannelId("email:support")) + model = DummyModel(channel_id="email") + assert channel_id_eq(model.channel_id, ChannelId("email")) + model = DummyModel(channel_id="channel:sub_channel:extra") + assert channel_id_eq(model.channel_id, ChannelId("channel:sub_channel:extra")) + + def test_validation_from_channel_id(self): + model = DummyModel(channel_id=ChannelId("email:support")) + assert channel_id_eq(model.channel_id, ChannelId("email:support")) + + def test_validation_dict(self): + model = DummyModel.model_validate({"channelId": "email:support"}) + assert channel_id_eq(model.channel_id, ChannelId("email:support")) + + def test_validation_dict_camel_case(self): + model = DummyModel.model_validate({"channel_id": "email:support"}) + assert channel_id_eq(model.channel_id, ChannelId("email:support")) + + def test_validation_none(self): + model = DummyModel.model_validate({}) + assert model.channel_id is None + + def test_validation_invalid_type(self): + with pytest.raises(ValidationError): + DummyModel.model_validate({"channelId": 123}) + with pytest.raises(ValidationError): + DummyModel.model_validate({"channel_id": 123}) + with pytest.raises(ValidationError): + DummyModel.model_validate({"channelId": None}) + with pytest.raises(ValidationError): + DummyModel(channel_id=123) + + def test_serialize(self): + model = DummyModel(channel_id="email:support") + assert model.model_dump() == {"channel_id": "email:support"} + assert model.model_dump_json() == '{"channel_id":"email:support"}' + assert model.model_dump(by_alias=True) == {"channelId": "email:support"} + assert model.model_dump_json(by_alias=True) == '{"channelId":"email:support"}' + assert model.model_dump(exclude_unset=True) == {"channel_id": "email:support"} + + def test_serialize_none(self): + model = DummyModel() + assert model.model_dump() == {} + assert model.model_dump_json() == "{}" + assert model.model_dump(by_alias=True) == {} + assert model.model_dump_json(by_alias=True) == "{}" + assert model.model_dump(exclude_unset=True) == {} + + def test_set(self): + model = DummyModel() + assert model.channel_id is None + model.channel_id = "email:support" + assert channel_id_eq(model.channel_id, ChannelId("email:support")) + model.channel_id = "a:b:c" + assert channel_id_eq(model.channel_id, ChannelId("a:b:c")) + model.channel_id = ChannelId("email") + assert channel_id_eq(model.channel_id, ChannelId("email")) + with pytest.raises(Exception): + model.channel_id = 123 + with pytest.raises(Exception): + model.channel_id = "" + with pytest.raises(Exception): + model.channel_id = None diff --git a/tests/activity/test_activity.py b/tests/activity/test_activity.py index 40d695fe..a6874186 100644 --- a/tests/activity/test_activity.py +++ b/tests/activity/test_activity.py @@ -16,6 +16,7 @@ AIEntity, Place, Thing, + ProductInfo, RoleTypes, ) @@ -352,12 +353,6 @@ def test_is_from_streaming_connection(self, service_url, expected): activity = Activity(type="message", service_url=service_url) assert activity.is_from_streaming_connection() == expected - def test_serialize_basic(self, activity): - activity_copy = Activity( - **activity.model_dump(mode="json", exclude_unset=True, by_alias=True) - ) - assert activity_copy == activity - def test_get_mentions(self): activity = Activity( type="message", @@ -373,6 +368,52 @@ def test_get_mentions(self): Entity(type="mention", text="Another mention"), ] + @pytest.mark.parametrize( + "entities, expected", + [ + [ + [ + Entity( + type="ProductInfo", + id="product_123", + ), + Entity(type="other"), + Entity(type="mention", text="Another mention"), + ], + Entity( + type="ProductInfo", + id="product_123", + ), + ], + [ + [ + Entity(type="other"), + Entity(type="mention", text="Another mention"), + ], + None, + ], + [ + [ + Entity( + type="ProductInfo", + id="product_123", + ), + Entity( + type="ProductInfo", + id="product_456", + ), + Entity(type="mention", text="Another mention"), + ], + Entity(type="ProductInfo", id="product_123"), + ], + [[], None], + ], + ) + def test_get_product_info_entity_single(self, entities, expected): + activity = Activity(type="message", entities=entities) + retrieved_product_info = activity.get_product_info_entity() + assert retrieved_product_info == expected + class TestActivityAgenticOps: diff --git a/tests/activity/test_channel_id.py b/tests/activity/test_channel_id.py new file mode 100644 index 00000000..ef592243 --- /dev/null +++ b/tests/activity/test_channel_id.py @@ -0,0 +1,61 @@ +import pytest + +from microsoft_agents.activity import ChannelId + +from tests._common.data import TEST_DEFAULTS + +DEFAULTS = TEST_DEFAULTS() + + +class TestChannelId: + + def test_init_from_str(self): + channel_id = ChannelId("email:support") + assert channel_id.channel == "email" + assert channel_id.sub_channel == "support" + assert str(channel_id) == "email:support" + assert channel_id == "email:support" + assert channel_id in ["email:support", "other"] + assert channel_id not in ["email:other", "other"] + assert channel_id != "email:other" + assert channel_id in ["wow", ChannelId("email:support")] + assert channel_id == ChannelId("email:support") + + def test_init_multiple_colons(self): + assert ChannelId("email:support:extra").channel == "email" + assert ChannelId("email:support:extra").sub_channel == "support:extra" + + def test_init_multiple_args(self): + with pytest.raises(ValueError): + ChannelId("email:support", channel="a", sub_channel="b") + + def test_init_from_parts(self): + channel_id = ChannelId(channel="email", sub_channel="support") + assert channel_id.channel == "email" + assert channel_id.sub_channel == "support" + assert str(channel_id) == "email:support" + + channel_id2 = ChannelId(channel="email") + assert channel_id2.channel == "email" + assert channel_id2.sub_channel is None + assert str(channel_id2) == "email" + + def test_init_errors(self): + with pytest.raises(Exception): + ChannelId(channel="email", sub_channel=123) + with pytest.raises(Exception): + ChannelId(channel="", sub_channel="support") + with pytest.raises(Exception): + ChannelId("") + with pytest.raises(Exception): + ChannelId() + with pytest.raises(Exception): + ChannelId(channel=None) + with pytest.raises(Exception): + ChannelId(sub_channel="sub_channel") + with pytest.raises(Exception): + ChannelId(" \t\n ") + with pytest.raises(Exception): + ChannelId("", channel=" ", sub_channel="support") + with pytest.raises(Exception): + ChannelId(channel="a:b", sub_channel="support") diff --git a/tests/activity/test_sub_channels.py b/tests/activity/test_sub_channels.py new file mode 100644 index 00000000..67c6d92c --- /dev/null +++ b/tests/activity/test_sub_channels.py @@ -0,0 +1,5 @@ +from microsoft_agents.activity import Activity, ChannelId, Entity + + +class TestSubChannels: + pass diff --git a/tests/hosting_core/_oauth/test_oauth_flow.py b/tests/hosting_core/_oauth/test_oauth_flow.py index 8e9681a1..a4684411 100644 --- a/tests/hosting_core/_oauth/test_oauth_flow.py +++ b/tests/hosting_core/_oauth/test_oauth_flow.py @@ -33,10 +33,13 @@ def create_testing_Activity( text="a", ): # mock_conversation_ref = mocker.MagicMock(ConversationReference) + conversation_reference = ConversationReference( + conversation={"id": "conv1"}, + ) mocker.patch.object( Activity, "get_conversation_reference", - return_value=mocker.MagicMock(ConversationReference), + return_value=conversation_reference, ) return Activity( type=type, @@ -44,7 +47,7 @@ def create_testing_Activity( from_property=ChannelAccount(id=DEFAULTS.user_id), channel_id=DEFAULTS.channel_id, # get_conversation_reference=mocker.Mock(return_value=conv_ref), - relates_to=mocker.MagicMock(ConversationReference), + relates_to=conversation_reference, value=value, text=text, )