From 06d6f88ef883046ff76b1e49258504483b638471 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 30 Jan 2026 19:06:25 +0000 Subject: [PATCH 1/5] Add ability to mark the source of Agent request messages and use that for filtering --- .../Program.cs | 4 +- .../Program.cs | 4 +- .../Program.cs | 6 +- .../Program.cs | 6 +- .../AIContextProvider.cs | 90 +++- .../AgentRequestMessageSource.cs | 106 ++++ .../ChatHistoryProvider.cs | 111 +++- .../ChatHistoryProviderExtensions.cs | 7 +- .../ChatHistoryProviderMessageFilter.cs | 4 +- .../ChatMessageExtensions.cs | 27 + .../InMemoryChatHistoryProvider.cs | 10 +- .../CosmosChatHistoryProvider.cs | 6 +- .../Microsoft.Agents.AI.Mem0/Mem0Provider.cs | 15 +- .../WorkflowChatHistoryProvider.cs | 6 +- .../ChatClient/ChatClientAgent.cs | 52 +- .../Memory/ChatHistoryMemoryProvider.cs | 10 +- .../Microsoft.Agents.AI/TextSearchProvider.cs | 9 +- .../AIContextProviderTests.cs | 6 +- .../AgentRequestMessageSourceTests.cs | 489 ++++++++++++++++++ .../ChatHistoryProviderExtensionsTests.cs | 37 +- .../ChatHistoryProviderMessageFilterTests.cs | 39 +- .../ChatHistoryProviderTests.cs | 4 +- .../ChatMessageExtensionsTests.cs | 197 +++++++ .../InMemoryChatHistoryProviderTests.cs | 18 +- ...AzureAIProjectChatClientExtensionsTests.cs | 6 +- .../CosmosChatHistoryProviderTests.cs | 36 +- .../Mem0ProviderTests.cs | 6 +- .../Mem0ProviderTests.cs | 8 +- .../ChatClient/ChatClientAgentTests.cs | 95 ++-- ...hatClientAgent_BackgroundResponsesTests.cs | 65 ++- ...tClientAgent_ChatHistoryManagementTests.cs | 108 ++-- .../Data/TextSearchProviderTests.cs | 16 +- .../Memory/ChatHistoryMemoryProviderTests.cs | 8 +- 33 files changed, 1305 insertions(+), 306 deletions(-) create mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSource.cs create mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs create mode 100644 dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTests.cs create mode 100644 dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index 980a4eda40..47736911a5 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -52,7 +52,7 @@ protected override async Task RunCoreAsync(IEnumerable responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the session of the input and output messages. - var invokedContext = new ChatHistoryProvider.InvokedContext(messages, storeMessages) + var invokedContext = new ChatHistoryProvider.InvokedContext(messages) { ResponseMessages = responseMessages }; @@ -84,7 +84,7 @@ protected override async IAsyncEnumerable RunCoreStreamingA List responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the session of the input and output messages. - var invokedContext = new ChatHistoryProvider.InvokedContext(messages, storeMessages) + var invokedContext = new ChatHistoryProvider.InvokedContext(messages) { ResponseMessages = responseMessages }; diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs index cf8e0dd943..c8abc45796 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs @@ -104,7 +104,7 @@ public UserInfoMemory(IChatClient chatClient, JsonElement serializedState, JsonS public UserInfo UserInfo { get; set; } - public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { // Try and extract the user name and age from the message if we don't have it already and it's a user message. if ((this.UserInfo.UserName is null || this.UserInfo.UserAge is null) && context.RequestMessages.Any(x => x.Role == ChatRole.User)) @@ -122,7 +122,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio } } - public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { StringBuilder instructions = new(); diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs index b4dc0e8e0e..8df363fbf6 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs @@ -87,7 +87,7 @@ public VectorChatHistoryProvider(VectorStore vectorStore, JsonElement serialized public string? SessionDbKey { get; private set; } - public override async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { var collection = this._vectorStore.GetCollection("ChatHistory"); await collection.EnsureCollectionExistsAsync(cancellationToken); @@ -105,7 +105,7 @@ public override async ValueTask> InvokingAsync(Invoking return messages; } - public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { // Don't store messages if the request failed. if (context.InvokeException is not null) @@ -120,7 +120,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio // Add both request and response messages to the store // Optionally messages produced by the AIContextProvider can also be persisted (not shown). - var allNewMessages = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []); + var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []); await collection.UpsertAsync(allNewMessages.Select(x => new ChatHistoryItem() { diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs index c3b4d6f979..d3d85dc70f 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs @@ -92,7 +92,7 @@ public TodoListAIContextProvider(JsonElement jsonElement, JsonSerializerOptions? } } - public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { StringBuilder outputMessageBuilder = new(); outputMessageBuilder.AppendLine("Your todo list contains the following items:"); @@ -132,7 +132,7 @@ public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptio /// internal sealed class CalendarSearchAIContextProvider(Func> loadNextThreeCalendarEvents) : AIContextProvider { - public override async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { var events = await loadNextThreeCalendarEvents(); @@ -179,7 +179,7 @@ public AggregatingAIContextProvider(ProviderFactory[] providerFactories, JsonEle .ToList(); } - public override async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { // Invoke all the sub providers. var tasks = this._providers.Select(provider => provider.InvokingAsync(context, cancellationToken).AsTask()); diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index f104f12890..53876358af 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -48,7 +49,51 @@ public abstract class AIContextProvider /// /// /// - public abstract ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default); + public async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + var aiContext = await this.InvokingCoreAsync(context, cancellationToken).ConfigureAwait(false); + if (aiContext.Messages is null) + { + return aiContext; + } + + aiContext.Messages = aiContext.Messages.Select(message => + { + if (message.AdditionalProperties != null + && message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out var source) + && source is AgentRequestMessageSource typedSource + && typedSource == AgentRequestMessageSource.AIContextProvider) + { + return message; + } + + message = message.Clone(); + message.AdditionalProperties ??= new(); + message.AdditionalProperties[AgentRequestMessageSource.AdditionalPropertiesKey] = AgentRequestMessageSource.AIContextProvider; + return message; + }).ToList(); + + return aiContext; + } + + /// + /// Called at the start of agent invocation to provide additional context. + /// + /// Contains the request context including the caller provided messages that will be used by the agent for this invocation. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. The task result contains the with additional context to be used by the agent during this invocation. + /// + /// + /// Implementers can load any additional context required at this time, such as: + /// + /// Retrieving relevant information from knowledge bases + /// Adding system instructions or prompts + /// Providing function tools for the current invocation + /// Injecting contextual messages from conversation history + /// + /// + /// + protected abstract ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default); /// /// Called at the end of the agent invocation to process the invocation results. @@ -71,7 +116,31 @@ public abstract class AIContextProvider /// To check if the invocation was successful, inspect the property. /// /// - public virtual ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + public ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + => this.InvokedCoreAsync(context, cancellationToken); + + /// + /// Called at the end of the agent invocation to process the invocation results. + /// + /// Contains the invocation context including request messages, response messages, and any exception that occurred. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + /// + /// + /// Implementers can use the request and response messages in the provided to: + /// + /// Update internal state based on conversation outcomes + /// Extract and store memories or preferences from user messages + /// Log or audit conversation details + /// Perform cleanup or finalization tasks + /// + /// + /// + /// This method is called regardless of whether the invocation succeeded or failed. + /// To check if the invocation was successful, inspect the property. + /// + /// + protected virtual ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) => default; /// @@ -117,7 +186,7 @@ public virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOption => this.GetService(typeof(TService), serviceKey) is TService service ? service : default; /// - /// Contains the context information provided to . + /// Contains the context information provided to . /// /// /// This class provides context about the invocation before the underlying AI model is invoked, including the messages @@ -146,7 +215,7 @@ public InvokingContext(IEnumerable requestMessages) } /// - /// Contains the context information provided to . + /// Contains the context information provided to . /// /// /// This class provides context about a completed agent invocation, including both the @@ -159,12 +228,10 @@ public sealed class InvokedContext /// Initializes a new instance of the class with the specified request messages. /// /// The caller provided messages that were used by the agent for this invocation. - /// The messages provided by the for this invocation, if any. /// is . - public InvokedContext(IEnumerable requestMessages, IEnumerable? aiContextProviderMessages) + public InvokedContext(IEnumerable requestMessages) { this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); - this.AIContextProviderMessages = aiContextProviderMessages; } /// @@ -176,15 +243,6 @@ public InvokedContext(IEnumerable requestMessages, IEnumerable public IEnumerable RequestMessages { get; set { field = Throw.IfNull(value); } } - /// - /// Gets the messages provided by the for this invocation, if any. - /// - /// - /// A collection of instances that were provided by the , - /// and were used by the agent as part of the invocation. - /// - public IEnumerable? AIContextProviderMessages { get; set; } - /// /// Gets the collection of response messages generated during this invocation if the invocation succeeded. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSource.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSource.cs new file mode 100644 index 0000000000..488a55d405 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSource.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.Extensions.AI; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +/// +/// An enumeration representing the source of an agent request message. +/// +/// +/// Input messages for a specific agent run can originate from various sources. +/// This enumeration helps to identify whether a message came from outside the agent pipeline, +/// whether it was produced by middleware, or came from chat history. +/// +public sealed class AgentRequestMessageSource : IEquatable +{ + /// + /// Provides the key used in to store the source of the agent request message. + /// + public static readonly string AdditionalPropertiesKey = "Agent.RequestMessageSource"; + + /// + /// Initializes a new instance of the class. + /// + /// The string value representing the source of the agent request message. + public AgentRequestMessageSource(string value) => this.Value = Throw.IfNullOrWhitespace(value); + + /// + /// Get the string value representing the source of the agent request message. + /// + public string Value { get; } + + /// + /// The message came from outside the agent pipeline (e.g., user input). + /// + public static AgentRequestMessageSource External { get; } = new AgentRequestMessageSource(nameof(External)); + + /// + /// The message was produced by middleware. + /// + public static AgentRequestMessageSource AIContextProvider { get; } = new AgentRequestMessageSource(nameof(AIContextProvider)); + + /// + /// The message came from chat history. + /// + public static AgentRequestMessageSource ChatHistory { get; } = new AgentRequestMessageSource(nameof(ChatHistory)); + + /// + /// Determines whether this instance and another specified object have the same value. + /// + /// The to compare to this instance. + /// if the value of the parameter is the same as the value of this instance; otherwise, . + public bool Equals(AgentRequestMessageSource? other) + { + if (other is null) + { + return false; + } + + if (ReferenceEquals(this, other)) + { + return true; + } + + return string.Equals(this.Value, other.Value, StringComparison.Ordinal); + } + + /// + /// Determines whether this instance and a specified object have the same value. + /// + /// The object to compare to this instance. + /// if is a and its value is the same as this instance; otherwise, . + public override bool Equals(object? obj) => this.Equals(obj as AgentRequestMessageSource); + + /// + /// Returns the hash code for this instance. + /// + /// A 32-bit signed integer hash code. + public override int GetHashCode() => this.Value?.GetHashCode() ?? 0; + + /// + /// Determines whether two specified objects have the same value. + /// + /// The first to compare. + /// The second to compare. + /// if the value of is the same as the value of ; otherwise, . + public static bool operator ==(AgentRequestMessageSource? left, AgentRequestMessageSource? right) + { + if (left is null) + { + return right is null; + } + + return left.Equals(right); + } + + /// + /// Determines whether two specified objects have different values. + /// + /// The first to compare. + /// The second to compare. + /// if the value of is different from the value of ; otherwise, . + public static bool operator !=(AgentRequestMessageSource? left, AgentRequestMessageSource? right) => !(left == right); +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs index d809582ea4..2ef630cf1a 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -65,7 +66,57 @@ public abstract class ChatHistoryProvider /// and context management. /// /// - public abstract ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default); + public async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + var messages = await this.InvokingCoreAsync(context, cancellationToken).ConfigureAwait(false); + + return messages.Select(message => + { + if (message.AdditionalProperties != null + && message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out var source) + && source is AgentRequestMessageSource typedSource + && typedSource == AgentRequestMessageSource.ChatHistory) + { + return message; + } + + message = message.Clone(); + message.AdditionalProperties ??= new(); + message.AdditionalProperties[AgentRequestMessageSource.AdditionalPropertiesKey] = AgentRequestMessageSource.ChatHistory; + return message; + }); + } + + /// + /// Called at the start of agent invocation to provide messages from the chat history as context for the next agent invocation. + /// + /// Contains the request context including the caller provided messages that will be used by the agent for this invocation. + /// The to monitor for cancellation requests. The default is . + /// + /// A task that represents the asynchronous operation. The task result contains a collection of + /// instances in ascending chronological order (oldest first). + /// + /// + /// + /// Messages are returned in chronological order to maintain proper conversation flow and context for the agent. + /// The oldest messages appear first in the collection, followed by more recent messages. + /// + /// + /// If the total message history becomes very large, implementations should apply appropriate strategies to manage + /// storage constraints, such as: + /// + /// Truncating older messages while preserving recent context + /// Summarizing message groups to maintain essential context + /// Implementing sliding window approaches for message retention + /// Archiving old messages while keeping active conversation context + /// + /// + /// + /// Each instance should be associated with a single to ensure proper message isolation + /// and context management. + /// + /// + protected abstract ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default); /// /// Called at the end of the agent invocation to add new messages to the chat history. @@ -77,7 +128,7 @@ public abstract class ChatHistoryProvider /// /// Messages should be added in the order they were generated to maintain proper chronological sequence. /// The is responsible for preserving message ordering and ensuring that subsequent calls to - /// return messages in the correct chronological order. + /// return messages in the correct chronological order. /// /// /// Implementations may perform additional processing during message addition, such as: @@ -92,7 +143,35 @@ public abstract class ChatHistoryProvider /// To check if the invocation was successful, inspect the property. /// /// - public abstract ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default); + public ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) => + this.InvokedCoreAsync(context, cancellationToken); + + /// + /// Called at the end of the agent invocation to add new messages to the chat history. + /// + /// Contains the invocation context including request messages, response messages, and any exception that occurred. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous add operation. + /// + /// + /// Messages should be added in the order they were generated to maintain proper chronological sequence. + /// The is responsible for preserving message ordering and ensuring that subsequent calls to + /// return messages in the correct chronological order. + /// + /// + /// Implementations may perform additional processing during message addition, such as: + /// + /// Validating message content and metadata + /// Applying storage optimizations or compression + /// Triggering background maintenance operations + /// + /// + /// + /// This method is called regardless of whether the invocation succeeded or failed. + /// To check if the invocation was successful, inspect the property. + /// + /// + protected abstract ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default); /// /// Serializes the current object's state to a using the specified serialization options. @@ -131,7 +210,7 @@ public abstract class ChatHistoryProvider => this.GetService(typeof(TService), serviceKey) is TService service ? service : default; /// - /// Contains the context information provided to . + /// Contains the context information provided to . /// /// /// This class provides context about the invocation including the new messages that will be used. @@ -160,7 +239,7 @@ public InvokingContext(IEnumerable requestMessages) } /// - /// Contains the context information provided to . + /// Contains the context information provided to . /// /// /// This class provides context about a completed agent invocation, including both the @@ -173,12 +252,10 @@ public sealed class InvokedContext /// Initializes a new instance of the class with the specified request messages. /// /// The caller provided messages that were used by the agent for this invocation. - /// The messages retrieved from the for this invocation. /// is . - public InvokedContext(IEnumerable requestMessages, IEnumerable? chatHistoryProviderMessages) + public InvokedContext(IEnumerable requestMessages) { this.RequestMessages = Throw.IfNull(requestMessages); - this.ChatHistoryProviderMessages = chatHistoryProviderMessages; } /// @@ -190,24 +267,6 @@ public InvokedContext(IEnumerable requestMessages, IEnumerable public IEnumerable RequestMessages { get; set { field = Throw.IfNull(value); } } - /// - /// Gets the messages retrieved from the for this invocation, if any. - /// - /// - /// A collection of instances that were retrieved from the , - /// and were used by the agent as part of the invocation. - /// - public IEnumerable? ChatHistoryProviderMessages { get; set; } - - /// - /// Gets or sets the messages provided by the for this invocation, if any. - /// - /// - /// A collection of instances that were provided by the , - /// and were used by the agent as part of the invocation. - /// - public IEnumerable? AIContextProviderMessages { get; set; } - /// /// Gets the collection of response messages generated during this invocation if the invocation succeeded. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs index 0f5d9524cb..4cd8d570db 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Linq; using Microsoft.Extensions.AI; namespace Microsoft.Agents.AI; @@ -33,8 +34,8 @@ public static ChatHistoryProvider WithMessageFilters( } /// - /// Decorates the provided chat message so that it does not add - /// messages produced by any to chat history. + /// Decorates the provided so that it does not add + /// messages with to chat history. /// /// The to add the message filter to. /// A new instance that filters out messages so they do not get added. @@ -44,7 +45,7 @@ public static ChatHistoryProvider WithAIContextProviderMessageRemoval(this ChatH innerProvider: provider, invokedMessagesFilter: (ctx) => { - ctx.AIContextProviderMessages = null; + ctx.RequestMessages = ctx.RequestMessages.Where(x => x.GetAgentRequestMessageSource() != AgentRequestMessageSource.AIContextProvider); return ctx; }); } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs index df7b536ea2..6cee80986b 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs @@ -49,14 +49,14 @@ public ChatHistoryProviderMessageFilter( } /// - public override async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { var messages = await this._innerProvider.InvokingAsync(context, cancellationToken).ConfigureAwait(false); return this._invokingMessagesFilter != null ? this._invokingMessagesFilter(messages) : messages; } /// - public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { if (this._invokedMessagesFilter != null) { diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs new file mode 100644 index 0000000000..caf472faa9 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI; + +/// +/// Conatins extension methods for +/// +public static class ChatMessageExtensions +{ + /// + /// Gets the source of the provided in the context of messages passed into an agent run. + /// + /// The for which we need the source. + /// An value indicating the source of the . Defaults to if no explicit source is defined. + public static AgentRequestMessageSource GetAgentRequestMessageSource(this ChatMessage message) + { + if (message.AdditionalProperties?.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out var source) is true && source is AgentRequestMessageSource typedSource) + { + return typedSource; + } + + return AgentRequestMessageSource.External; + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs index ab408c6a5e..001b3a3bcc 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs @@ -133,7 +133,7 @@ public ChatMessage this[int index] } /// - public override async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { _ = Throw.IfNull(context); @@ -146,7 +146,7 @@ public override async ValueTask> InvokingAsync(Invoking } /// - public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { _ = Throw.IfNull(context); @@ -155,8 +155,8 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio return; } - // Add request, AI context provider, and response messages to the provider - var allNewMessages = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []); + // Add request and response messages to the provider + var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []); this._messages.AddRange(allNewMessages); if (this.ReducerTriggerEvent is ChatReducerTriggerEvent.AfterMessageAdded && this.ChatReducer is not null) @@ -229,7 +229,7 @@ public enum ChatReducerTriggerEvent { /// /// Trigger the reducer when a new message is added. - /// will only complete when reducer processing is done. + /// will only complete when reducer processing is done. /// AfterMessageAdded, diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs index 41c9a211dc..85c5865f07 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs @@ -287,7 +287,7 @@ public static CosmosChatHistoryProvider CreateFromSerializedState(CosmosClient c } /// - public override async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { #pragma warning disable CA1513 // Use ObjectDisposedException.ThrowIf - not available on all target frameworks if (this._disposed) @@ -347,7 +347,7 @@ public override async ValueTask> InvokingAsync(Invoking } /// - public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { Throw.IfNull(context); @@ -364,7 +364,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio } #pragma warning restore CA1513 - var messageList = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []).ToList(); + var messageList = context.RequestMessages.Concat(context.ResponseMessages ?? []).ToList(); if (messageList.Count == 0) { return; diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs index 0e9b4288b1..6a4113bd42 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs @@ -131,13 +131,16 @@ public Mem0Provider(HttpClient httpClient, JsonElement serializedState, JsonSeri } /// - public override async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { Throw.IfNull(context); string queryText = string.Join( Environment.NewLine, - context.RequestMessages.Where(m => !string.IsNullOrWhiteSpace(m.Text)).Select(m => m.Text)); + context.RequestMessages + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSource.External) + .Where(m => !string.IsNullOrWhiteSpace(m.Text)) + .Select(m => m.Text)); try { @@ -202,7 +205,7 @@ public override async ValueTask InvokingAsync(InvokingContext context } /// - public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { if (context.InvokeException is not null) { @@ -212,7 +215,11 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio try { // Persist request and response messages after invocation. - await this.PersistMessagesAsync(context.RequestMessages.Concat(context.ResponseMessages ?? []), cancellationToken).ConfigureAwait(false); + await this.PersistMessagesAsync( + context.RequestMessages + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSource.External) + .Concat(context.ResponseMessages ?? []), + cancellationToken).ConfigureAwait(false); } catch (Exception ex) { diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs index afe6706553..f631de8e8a 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs @@ -46,17 +46,17 @@ internal sealed class StoreState internal void AddMessages(params IEnumerable messages) => this._chatMessages.AddRange(messages); - public override ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) => new(this._chatMessages.AsReadOnly()); - public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { if (context.InvokeException is not null) { return default; } - var allNewMessages = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []); + var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []); this._chatMessages.AddRange(allNewMessages); return default; diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 937a6d0f0b..4d9b8a1dee 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -207,8 +207,6 @@ protected override async IAsyncEnumerable RunCoreStreamingA (ChatClientAgentSession safeSession, ChatOptions? chatOptions, List inputMessagesForChatClient, - IList? aiContextProviderMessages, - IList? chatHistoryProviderMessages, ChatClientAgentContinuationToken? continuationToken) = await this.PrepareSessionAndMessagesAsync(session, inputMessages, options, cancellationToken).ConfigureAwait(false); @@ -231,8 +229,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForChatClient, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); + await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForChatClient, continuationToken), cancellationToken).ConfigureAwait(false); throw; } @@ -246,8 +244,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForChatClient, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); + await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForChatClient, continuationToken), cancellationToken).ConfigureAwait(false); throw; } @@ -273,8 +271,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForChatClient, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); + await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForChatClient, continuationToken), cancellationToken).ConfigureAwait(false); throw; } } @@ -286,10 +284,10 @@ protected override async IAsyncEnumerable RunCoreStreamingA await this.UpdateSessionWithTypeAndConversationIdAsync(safeSession, chatResponse.ConversationId, cancellationToken).ConfigureAwait(false); // To avoid inconsistent state we only notify the session of the input messages if no error occurs after the initial request. - await NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); + await NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessagesForChatClient, continuationToken), chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); // Notify the AIContextProvider of all new messages. - await NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); + await NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessagesForChatClient, continuationToken), chatResponse.Messages, cancellationToken).ConfigureAwait(false); } /// @@ -421,8 +419,6 @@ private async Task RunCoreAsync inputMessagesForChatClient, - IList? aiContextProviderMessages, - IList? chatHistoryProviderMessages, ChatClientAgentContinuationToken? _) = await this.PrepareSessionAndMessagesAsync(session, inputMessages, options, cancellationToken).ConfigureAwait(false); @@ -442,8 +438,8 @@ private async Task RunCoreAsync RunCoreAsync RunCoreAsync inputMessages, - IList? aiContextProviderMessages, IEnumerable responseMessages, CancellationToken cancellationToken) { if (session.AIContextProvider is not null) { - await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProviderMessages) { ResponseMessages = responseMessages }, + await session.AIContextProvider.InvokedAsync(new(inputMessages) { ResponseMessages = responseMessages }, cancellationToken).ConfigureAwait(false); } } @@ -496,12 +491,11 @@ private static async Task NotifyAIContextProviderOfFailureAsync( ChatClientAgentSession session, Exception ex, IEnumerable inputMessages, - IList? aiContextProviderMessages, CancellationToken cancellationToken) { if (session.AIContextProvider is not null) { - await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProviderMessages) { InvokeException = ex }, + await session.AIContextProvider.InvokedAsync(new(inputMessages) { InvokeException = ex }, cancellationToken).ConfigureAwait(false); } } @@ -671,8 +665,6 @@ private async Task ChatClientAgentSession AgentSession, ChatOptions? ChatOptions, List InputMessagesForChatClient, - IList? AIContextProviderMessages, - IList? ChatHistoryProviderMessages, ChatClientAgentContinuationToken? ContinuationToken )> PrepareSessionAndMessagesAsync( AgentSession? session, @@ -702,8 +694,6 @@ private async Task } List inputMessagesForChatClient = []; - IList? aiContextProviderMessages = null; - IList? chatHistoryProviderMessages = null; // Populate the session messages only if we are not continuing an existing response as it's not allowed if (chatOptions?.ContinuationToken is null) @@ -716,7 +706,6 @@ private async Task var invokingContext = new ChatHistoryProvider.InvokingContext(inputMessages); var providerMessages = await chatHistoryProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); inputMessagesForChatClient.AddRange(providerMessages); - chatHistoryProviderMessages = providerMessages as IList ?? providerMessages.ToList(); } // Add the input messages before getting context from AIContextProvider. @@ -731,7 +720,6 @@ private async Task if (aiContext.Messages is { Count: > 0 }) { inputMessagesForChatClient.AddRange(aiContext.Messages); - aiContextProviderMessages = aiContext.Messages; } if (aiContext.Tools is { Count: > 0 }) @@ -770,7 +758,7 @@ private async Task chatOptions.ConversationId = typedSession.ConversationId; } - return (typedSession, chatOptions, inputMessagesForChatClient, aiContextProviderMessages, chatHistoryProviderMessages, continuationToken); + return (typedSession, chatOptions, inputMessagesForChatClient, continuationToken); } private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSession session, string? responseConversationId, CancellationToken cancellationToken) @@ -803,8 +791,6 @@ private static Task NotifyChatHistoryProviderOfFailureAsync( ChatClientAgentSession session, Exception ex, IEnumerable requestMessages, - IEnumerable? chatHistoryProviderMessages, - IEnumerable? aiContextProviderMessages, ChatOptions? chatOptions, CancellationToken cancellationToken) { @@ -814,9 +800,8 @@ private static Task NotifyChatHistoryProviderOfFailureAsync( // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. if (provider is not null) { - var invokedContext = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages!) + var invokedContext = new ChatHistoryProvider.InvokedContext(requestMessages) { - AIContextProviderMessages = aiContextProviderMessages, InvokeException = ex }; @@ -829,8 +814,6 @@ private static Task NotifyChatHistoryProviderOfFailureAsync( private static Task NotifyChatHistoryProviderOfNewMessagesAsync( ChatClientAgentSession session, IEnumerable requestMessages, - IEnumerable? chatHistoryProviderMessages, - IEnumerable? aiContextProviderMessages, IEnumerable responseMessages, ChatOptions? chatOptions, CancellationToken cancellationToken) @@ -841,9 +824,8 @@ private static Task NotifyChatHistoryProviderOfNewMessagesAsync( // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. if (provider is not null) { - var invokedContext = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages!) + var invokedContext = new ChatHistoryProvider.InvokedContext(requestMessages) { - AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages }; return provider.InvokedAsync(invokedContext, cancellationToken).AsTask(); diff --git a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs index 87adc9fd7a..7476aac267 100644 --- a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs @@ -25,8 +25,8 @@ namespace Microsoft.Agents.AI; /// abstractions to work with any compatible vector store implementation. /// /// -/// Messages are stored during the method and retrieved during the -/// method using semantic similarity search. +/// Messages are stored during the method and retrieved during the +/// method using semantic similarity search. /// /// /// Behavior is configurable through . When @@ -175,7 +175,7 @@ private ChatHistoryMemoryProvider( } /// - public override async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { _ = Throw.IfNull(context); @@ -189,6 +189,7 @@ public override async ValueTask InvokingAsync(InvokingContext context { // Get the text from the current request messages var requestText = string.Join("\n", context.RequestMessages + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSource.External) .Where(m => m != null && !string.IsNullOrWhiteSpace(m.Text)) .Select(m => m.Text)); @@ -228,7 +229,7 @@ public override async ValueTask InvokingAsync(InvokingContext context } /// - public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { _ = Throw.IfNull(context); @@ -244,6 +245,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio var collection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); List> itemsToStore = context.RequestMessages + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSource.External) .Concat(context.ResponseMessages ?? []) .Select(message => new Dictionary { diff --git a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs index be9eba1365..cff79c9e7a 100644 --- a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs @@ -107,7 +107,7 @@ public TextSearchProvider( } /// - public override async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { if (this._searchTime != TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke) { @@ -117,7 +117,9 @@ public override async ValueTask InvokingAsync(InvokingContext context // Aggregate text from memory + current request messages. var sbInput = new StringBuilder(); - var requestMessagesText = context.RequestMessages.Where(x => !string.IsNullOrWhiteSpace(x?.Text)).Select(x => x.Text); + var requestMessagesText = context.RequestMessages + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSource.External) + .Where(x => !string.IsNullOrWhiteSpace(x?.Text)).Select(x => x.Text); foreach (var messageText in this._recentMessagesText.Concat(requestMessagesText)) { if (sbInput.Length > 0) @@ -166,7 +168,7 @@ public override async ValueTask InvokingAsync(InvokingContext context } /// - public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { int limit = this._recentMessageMemoryLimit; if (limit <= 0) @@ -180,6 +182,7 @@ public override ValueTask InvokedAsync(InvokedContext context, CancellationToken } var messagesText = context.RequestMessages + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSource.External) .Concat(context.ResponseMessages ?? []) .Where(m => this._recentMessageRolesIncluded.Contains(m.Role) && diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs index b287c8b304..7ce9f06bbb 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs @@ -15,7 +15,7 @@ public async Task InvokedAsync_ReturnsCompletedTaskAsync() { var provider = new TestAIContextProvider(); var messages = new ReadOnlyCollection([]); - var task = provider.InvokedAsync(new(messages, aiContextProviderMessages: null)); + var task = provider.InvokedAsync(new(messages)); Assert.Equal(default, task); } @@ -36,7 +36,7 @@ public void InvokingContext_Constructor_ThrowsForNullMessages() [Fact] public void InvokedContext_Constructor_ThrowsForNullMessages() { - Assert.Throws(() => new AIContextProvider.InvokedContext(null!, aiContextProviderMessages: null)); + Assert.Throws(() => new AIContextProvider.InvokedContext(null!)); } #region GetService Method Tests @@ -157,7 +157,7 @@ public void GetService_Generic_ReturnsNullForUnrelatedType() private sealed class TestAIContextProvider : AIContextProvider { - public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { return default; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTests.cs new file mode 100644 index 0000000000..faf2ab6fc0 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTests.cs @@ -0,0 +1,489 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; + +namespace Microsoft.Agents.AI.Abstractions.UnitTests; + +/// +/// Contains tests for the class. +/// +public sealed class AgentRequestMessageSourceTests +{ + #region Constructor Tests + + [Fact] + public void Constructor_WithValue_SetsValueProperty() + { + // Arrange + const string ExpectedValue = "CustomSource"; + + // Act + AgentRequestMessageSource source = new(ExpectedValue); + + // Assert + Assert.Equal(ExpectedValue, source.Value); + } + + [Fact] + public void Constructor_WithNullValue_Throws() + { + // Act & Assert + Assert.Throws(() => new AgentRequestMessageSource(null!)); + } + + [Fact] + public void Constructor_WithEmptyValue_Throws() + { + // Act & Assert + Assert.Throws(() => new AgentRequestMessageSource(string.Empty)); + } + + #endregion + + #region Static Properties Tests + + [Fact] + public void External_ReturnsInstanceWithExternalValue() + { + // Arrange & Act + AgentRequestMessageSource source = AgentRequestMessageSource.External; + + // Assert + Assert.NotNull(source); + Assert.Equal("External", source.Value); + } + + [Fact] + public void AIContextProvider_ReturnsInstanceWithAIContextProviderValue() + { + // Arrange & Act + AgentRequestMessageSource source = AgentRequestMessageSource.AIContextProvider; + + // Assert + Assert.NotNull(source); + Assert.Equal("AIContextProvider", source.Value); + } + + [Fact] + public void ChatHistory_ReturnsInstanceWithChatHistoryValue() + { + // Arrange & Act + AgentRequestMessageSource source = AgentRequestMessageSource.ChatHistory; + + // Assert + Assert.NotNull(source); + Assert.Equal("ChatHistory", source.Value); + } + + [Fact] + public void AdditionalPropertiesKey_ReturnsExpectedValue() + { + // Arrange & Act + string key = AgentRequestMessageSource.AdditionalPropertiesKey; + + // Assert + Assert.Equal("Agent.RequestMessageSource", key); + } + + [Fact] + public void StaticProperties_ReturnSameInstanceOnMultipleCalls() + { + // Arrange & Act + AgentRequestMessageSource external1 = AgentRequestMessageSource.External; + AgentRequestMessageSource external2 = AgentRequestMessageSource.External; + AgentRequestMessageSource aiContextProvider1 = AgentRequestMessageSource.AIContextProvider; + AgentRequestMessageSource aiContextProvider2 = AgentRequestMessageSource.AIContextProvider; + AgentRequestMessageSource chatHistory1 = AgentRequestMessageSource.ChatHistory; + AgentRequestMessageSource chatHistory2 = AgentRequestMessageSource.ChatHistory; + + // Assert + Assert.Same(external1, external2); + Assert.Same(aiContextProvider1, aiContextProvider2); + Assert.Same(chatHistory1, chatHistory2); + } + + #endregion + + #region Equals Tests + + [Fact] + public void Equals_WithSameInstance_ReturnsTrue() + { + // Arrange + AgentRequestMessageSource source = new("Test"); + + // Act + bool result = source.Equals(source); + + // Assert + Assert.True(result); + } + + [Fact] + public void Equals_WithEqualValue_ReturnsTrue() + { + // Arrange + AgentRequestMessageSource source1 = new("Test"); + AgentRequestMessageSource source2 = new("Test"); + + // Act + bool result = source1.Equals(source2); + + // Assert + Assert.True(result); + } + + [Fact] + public void Equals_WithDifferentValue_ReturnsFalse() + { + // Arrange + AgentRequestMessageSource source1 = new("Test1"); + AgentRequestMessageSource source2 = new("Test2"); + + // Act + bool result = source1.Equals(source2); + + // Assert + Assert.False(result); + } + + [Fact] + public void Equals_WithNull_ReturnsFalse() + { + // Arrange + AgentRequestMessageSource source = new("Test"); + + // Act + bool result = source.Equals(null); + + // Assert + Assert.False(result); + } + + [Fact] + public void Equals_WithDifferentCase_ReturnsFalse() + { + // Arrange + AgentRequestMessageSource source1 = new("Test"); + AgentRequestMessageSource source2 = new("test"); + + // Act + bool result = source1.Equals(source2); + + // Assert + Assert.False(result); + } + + [Fact] + public void Equals_StaticExternalWithNewInstanceHavingSameValue_ReturnsTrue() + { + // Arrange + AgentRequestMessageSource external = AgentRequestMessageSource.External; + AgentRequestMessageSource newExternal = new("External"); + + // Act + bool result = external.Equals(newExternal); + + // Assert + Assert.True(result); + } + + #endregion + + #region Object.Equals Tests + + [Fact] + public void ObjectEquals_WithEqualAgentRequestMessageSource_ReturnsTrue() + { + // Arrange + AgentRequestMessageSource source1 = new("Test"); + object source2 = new AgentRequestMessageSource("Test"); + + // Act + bool result = source1.Equals(source2); + + // Assert + Assert.True(result); + } + + [Fact] + public void ObjectEquals_WithDifferentType_ReturnsFalse() + { + // Arrange + AgentRequestMessageSource source = new("Test"); + object other = "Test"; + + // Act + bool result = source.Equals(other); + + // Assert + Assert.False(result); + } + + [Fact] + public void ObjectEquals_WithNullObject_ReturnsFalse() + { + // Arrange + AgentRequestMessageSource source = new("Test"); + object? other = null; + + // Act + bool result = source.Equals(other); + + // Assert + Assert.False(result); + } + + #endregion + + #region GetHashCode Tests + + [Fact] + public void GetHashCode_WithSameValue_ReturnsSameHashCode() + { + // Arrange + AgentRequestMessageSource source1 = new("Test"); + AgentRequestMessageSource source2 = new("Test"); + + // Act + int hashCode1 = source1.GetHashCode(); + int hashCode2 = source2.GetHashCode(); + + // Assert + Assert.Equal(hashCode1, hashCode2); + } + + [Fact] + public void GetHashCode_WithDifferentValue_ReturnsDifferentHashCode() + { + // Arrange + AgentRequestMessageSource source1 = new("Test1"); + AgentRequestMessageSource source2 = new("Test2"); + + // Act + int hashCode1 = source1.GetHashCode(); + int hashCode2 = source2.GetHashCode(); + + // Assert + Assert.NotEqual(hashCode1, hashCode2); + } + + [Fact] + public void GetHashCode_ConsistentWithEquals() + { + // Arrange + AgentRequestMessageSource source1 = new("Test"); + AgentRequestMessageSource source2 = new("Test"); + + // Act & Assert + // If two objects are equal, they must have the same hash code + Assert.True(source1.Equals(source2)); + Assert.Equal(source1.GetHashCode(), source2.GetHashCode()); + } + + #endregion + + #region Equality Operator Tests + + [Fact] + public void EqualityOperator_WithEqualValues_ReturnsTrue() + { + // Arrange + AgentRequestMessageSource source1 = new("Test"); + AgentRequestMessageSource source2 = new("Test"); + + // Act + bool result = source1 == source2; + + // Assert + Assert.True(result); + } + + [Fact] + public void EqualityOperator_WithDifferentValues_ReturnsFalse() + { + // Arrange + AgentRequestMessageSource source1 = new("Test1"); + AgentRequestMessageSource source2 = new("Test2"); + + // Act + bool result = source1 == source2; + + // Assert + Assert.False(result); + } + + [Fact] + public void EqualityOperator_WithBothNull_ReturnsTrue() + { + // Arrange + AgentRequestMessageSource? source1 = null; + AgentRequestMessageSource? source2 = null; + + // Act + bool result = source1 == source2; + + // Assert + Assert.True(result); + } + + [Fact] + public void EqualityOperator_WithLeftNull_ReturnsFalse() + { + // Arrange + AgentRequestMessageSource? source1 = null; + AgentRequestMessageSource source2 = new("Test"); + + // Act + bool result = source1 == source2; + + // Assert + Assert.False(result); + } + + [Fact] + public void EqualityOperator_WithRightNull_ReturnsFalse() + { + // Arrange + AgentRequestMessageSource source1 = new("Test"); + AgentRequestMessageSource? source2 = null; + + // Act + bool result = source1 == source2; + + // Assert + Assert.False(result); + } + + [Fact] + public void EqualityOperator_WithStaticInstances_ReturnsTrue() + { + // Arrange + AgentRequestMessageSource external1 = AgentRequestMessageSource.External; + AgentRequestMessageSource external2 = AgentRequestMessageSource.External; + + // Act + bool result = external1 == external2; + + // Assert + Assert.True(result); + } + + [Fact] + public void EqualityOperator_StaticWithNewInstanceHavingSameValue_ReturnsTrue() + { + // Arrange + AgentRequestMessageSource external = AgentRequestMessageSource.External; + AgentRequestMessageSource newExternal = new("External"); + + // Act + bool result = external == newExternal; + + // Assert + Assert.True(result); + } + + #endregion + + #region Inequality Operator Tests + + [Fact] + public void InequalityOperator_WithEqualValues_ReturnsFalse() + { + // Arrange + AgentRequestMessageSource source1 = new("Test"); + AgentRequestMessageSource source2 = new("Test"); + + // Act + bool result = source1 != source2; + + // Assert + Assert.False(result); + } + + [Fact] + public void InequalityOperator_WithDifferentValues_ReturnsTrue() + { + // Arrange + AgentRequestMessageSource source1 = new("Test1"); + AgentRequestMessageSource source2 = new("Test2"); + + // Act + bool result = source1 != source2; + + // Assert + Assert.True(result); + } + + [Fact] + public void InequalityOperator_WithBothNull_ReturnsFalse() + { + // Arrange + AgentRequestMessageSource? source1 = null; + AgentRequestMessageSource? source2 = null; + + // Act + bool result = source1 != source2; + + // Assert + Assert.False(result); + } + + [Fact] + public void InequalityOperator_WithLeftNull_ReturnsTrue() + { + // Arrange + AgentRequestMessageSource? source1 = null; + AgentRequestMessageSource source2 = new("Test"); + + // Act + bool result = source1 != source2; + + // Assert + Assert.True(result); + } + + [Fact] + public void InequalityOperator_WithRightNull_ReturnsTrue() + { + // Arrange + AgentRequestMessageSource source1 = new("Test"); + AgentRequestMessageSource? source2 = null; + + // Act + bool result = source1 != source2; + + // Assert + Assert.True(result); + } + + [Fact] + public void InequalityOperator_DifferentStaticInstances_ReturnsTrue() + { + // Arrange + AgentRequestMessageSource external = AgentRequestMessageSource.External; + AgentRequestMessageSource chatHistory = AgentRequestMessageSource.ChatHistory; + + // Act + bool result = external != chatHistory; + + // Assert + Assert.True(result); + } + + #endregion + + #region IEquatable Tests + + [Fact] + public void IEquatable_ImplementedCorrectly() + { + // Arrange + AgentRequestMessageSource source = new("Test"); + + // Act & Assert + Assert.IsAssignableFrom>(source); + } + + #endregion +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs index 84a0242320..db366a4b9b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs @@ -6,6 +6,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; +using Moq.Protected; namespace Microsoft.Agents.AI.Abstractions.UnitTests; @@ -38,7 +39,8 @@ public async Task WithMessageFilters_InvokingFilter_IsAppliedAsync() ChatHistoryProvider.InvokingContext context = new([new ChatMessage(ChatRole.User, "Test")]); providerMock - .Setup(p => p.InvokingAsync(context, It.IsAny())) + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(innerMessages); ChatHistoryProvider filtered = providerMock.Object.WithMessageFilters( @@ -57,16 +59,20 @@ public async Task WithMessageFilters_InvokedFilter_IsAppliedAsync() { // Arrange Mock providerMock = new(); - List requestMessages = [new(ChatRole.User, "Hello")]; - List chatHistoryProviderMessages = [new(ChatRole.System, "System")]; - ChatHistoryProvider.InvokedContext context = new(requestMessages, chatHistoryProviderMessages) + List requestMessages = + [ + new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.ChatHistory } } }, + new(ChatRole.User, "Hello") + ]; + ChatHistoryProvider.InvokedContext context = new(requestMessages) { ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")] }; ChatHistoryProvider.InvokedContext? capturedContext = null; providerMock - .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((ctx, _) => capturedContext = ctx) .Returns(default(ValueTask)); @@ -103,17 +109,18 @@ public async Task WithAIContextProviderMessageRemoval_RemovesAIContextProviderMe { // Arrange Mock providerMock = new(); - List requestMessages = [new(ChatRole.User, "Hello")]; - List chatHistoryProviderMessages = [new(ChatRole.System, "System")]; - List aiContextProviderMessages = [new(ChatRole.System, "Context")]; - ChatHistoryProvider.InvokedContext context = new(requestMessages, chatHistoryProviderMessages) - { - AIContextProviderMessages = aiContextProviderMessages - }; + List requestMessages = + [ + new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.ChatHistory } } }, + new(ChatRole.User, "Hello"), + new(ChatRole.System, "Context") { AdditionalProperties = new() { { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.AIContextProvider } } } + ]; + ChatHistoryProvider.InvokedContext context = new(requestMessages); ChatHistoryProvider.InvokedContext? capturedContext = null; providerMock - .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((ctx, _) => capturedContext = ctx) .Returns(default(ValueTask)); @@ -124,6 +131,8 @@ public async Task WithAIContextProviderMessageRemoval_RemovesAIContextProviderMe // Assert Assert.NotNull(capturedContext); - Assert.Null(capturedContext.AIContextProviderMessages); + Assert.Equal(2, capturedContext.RequestMessages.Count()); + Assert.Contains("System", capturedContext.RequestMessages.Select(x => x.Text)); + Assert.Contains("Hello", capturedContext.RequestMessages.Select(x => x.Text)); } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs index 43a3e78f10..6998a91cf2 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; +using Moq.Protected; namespace Microsoft.Agents.AI.Abstractions.UnitTests; @@ -62,7 +63,8 @@ public async Task InvokingAsync_WithNoOpFilters_ReturnsInnerProviderMessagesAsyn var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); innerProviderMock - .Setup(s => s.InvokingAsync(context, It.IsAny())) + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(expectedMessages); var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, x => x, x => x); @@ -74,7 +76,9 @@ public async Task InvokingAsync_WithNoOpFilters_ReturnsInnerProviderMessagesAsyn Assert.Equal(2, result.Count); Assert.Equal("Hello", result[0].Text); Assert.Equal("Hi there!", result[1].Text); - innerProviderMock.Verify(s => s.InvokingAsync(context, It.IsAny()), Times.Once); + innerProviderMock + .Protected() + .Verify>>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); } [Fact] @@ -91,7 +95,8 @@ public async Task InvokingAsync_WithInvokingFilter_AppliesFilterAsync() var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); innerProviderMock - .Setup(s => s.InvokingAsync(context, It.IsAny())) + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(innerMessages); // Filter to only user messages @@ -105,7 +110,9 @@ public async Task InvokingAsync_WithInvokingFilter_AppliesFilterAsync() // Assert Assert.Equal(2, result.Count); Assert.All(result, msg => Assert.Equal(ChatRole.User, msg.Role)); - innerProviderMock.Verify(s => s.InvokingAsync(context, It.IsAny()), Times.Once); + innerProviderMock + .Protected() + .Verify>>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); } [Fact] @@ -121,7 +128,8 @@ public async Task InvokingAsync_WithInvokingFilter_CanModifyMessagesAsync() var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); innerProviderMock - .Setup(s => s.InvokingAsync(context, It.IsAny())) + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(innerMessages); // Filter that transforms messages @@ -144,28 +152,31 @@ public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync() { // Arrange var innerProviderMock = new Mock(); - var requestMessages = new List { new(ChatRole.User, "Hello") }; - var chatHistoryProviderMessages = new List { new(ChatRole.System, "System") }; + List requestMessages = + [ + new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.ChatHistory } } }, + new(ChatRole.User, "Hello"), + ]; var responseMessages = new List { new(ChatRole.Assistant, "Response") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages) + var context = new ChatHistoryProvider.InvokedContext(requestMessages) { ResponseMessages = responseMessages }; ChatHistoryProvider.InvokedContext? capturedContext = null; innerProviderMock - .Setup(s => s.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((ctx, ct) => capturedContext = ctx) .Returns(default(ValueTask)); // Filter that modifies the context ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedContext ctx) { - var modifiedRequestMessages = ctx.RequestMessages.Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList(); - return new ChatHistoryProvider.InvokedContext(modifiedRequestMessages, ctx.ChatHistoryProviderMessages) + var modifiedRequestMessages = ctx.RequestMessages.Where(x => x.GetAgentRequestMessageSource() == AgentRequestMessageSource.External).Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList(); + return new ChatHistoryProvider.InvokedContext(modifiedRequestMessages) { ResponseMessages = ctx.ResponseMessages, - AIContextProviderMessages = ctx.AIContextProviderMessages, InvokeException = ctx.InvokeException }; } @@ -179,7 +190,9 @@ ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedCont Assert.NotNull(capturedContext); Assert.Single(capturedContext.RequestMessages); Assert.Equal("[FILTERED] Hello", capturedContext.RequestMessages.First().Text); - innerProviderMock.Verify(s => s.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + innerProviderMock + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); } [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs index 02955f4a25..4f981b9692 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs @@ -78,10 +78,10 @@ public void GetService_Generic_ReturnsNullForUnrelatedType() private sealed class TestChatHistoryProvider : ChatHistoryProvider { - public override ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) => new(Array.Empty()); - public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) => default; public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs new file mode 100644 index 0000000000..269ef03ee1 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs @@ -0,0 +1,197 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Abstractions.UnitTests; + +/// +/// Contains tests for the class. +/// +public sealed class ChatMessageExtensionsTests +{ + #region GetAgentRequestMessageSource Tests + + [Fact] + public void GetAgentRequestMessageSource_WithNoAdditionalProperties_ReturnsExternal() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello"); + + // Act + AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSource.External, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithNullAdditionalProperties_ReturnsExternal() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = null + }; + + // Act + AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSource.External, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithEmptyAdditionalProperties_ReturnsExternal() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary() + }; + + // Act + AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSource.External, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithExternalSource_ReturnsExternal() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.External } + } + }; + + // Act + AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSource.External, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithAIContextProviderSource_ReturnsAIContextProvider() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.AIContextProvider } + } + }; + + // Act + AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSource.AIContextProvider, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithChatHistorySource_ReturnsChatHistory() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.ChatHistory } + } + }; + + // Act + AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSource.ChatHistory, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithCustomSource_ReturnsCustomSource() + { + // Arrange + AgentRequestMessageSource customSource = new("CustomSource"); + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { AgentRequestMessageSource.AdditionalPropertiesKey, customSource } + } + }; + + // Act + AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(customSource, result); + Assert.Equal("CustomSource", result.Value); + } + + [Fact] + public void GetAgentRequestMessageSource_WithWrongKeyType_ReturnsExternal() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { AgentRequestMessageSource.AdditionalPropertiesKey, "NotAnAgentRequestMessageSource" } + } + }; + + // Act + AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSource.External, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithNullValue_ReturnsExternal() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { AgentRequestMessageSource.AdditionalPropertiesKey, null! } + } + }; + + // Act + AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSource.External, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithMultipleProperties_ReturnsCorrectSource() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { "OtherProperty", "SomeValue" }, + { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.ChatHistory }, + { "AnotherProperty", 123 } + } + }; + + // Act + AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSource.ChatHistory, result); + } + + #endregion +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs index debaff73ef..40ccee086a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs @@ -51,7 +51,8 @@ public async Task InvokedAsyncAddsMessagesAsync() { var requestMessages = new List { - new(ChatRole.User, "Hello") + new(ChatRole.User, "Hello"), + new(ChatRole.System, "additional context") { AdditionalProperties = new() { { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.ChatHistory } } }, }; var responseMessages = new List { @@ -61,16 +62,11 @@ public async Task InvokedAsyncAddsMessagesAsync() { new(ChatRole.System, "original instructions") }; - var aiContextProviderMessages = new List() - { - new(ChatRole.System, "additional context") - }; var provider = new InMemoryChatHistoryProvider(); provider.Add(providerMessages[0]); - var context = new ChatHistoryProvider.InvokedContext(requestMessages, providerMessages) + var context = new ChatHistoryProvider.InvokedContext(requestMessages) { - AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages }; await provider.InvokedAsync(context, CancellationToken.None); @@ -87,7 +83,7 @@ public async Task InvokedAsyncWithEmptyDoesNotFailAsync() { var provider = new InMemoryChatHistoryProvider(); - var context = new ChatHistoryProvider.InvokedContext([], []); + var context = new ChatHistoryProvider.InvokedContext([]); await provider.InvokedAsync(context, CancellationToken.None); Assert.Empty(provider); @@ -183,7 +179,7 @@ public async Task InvokedAsyncWithEmptyMessagesDoesNotChangeProviderAsync() var provider = new InMemoryChatHistoryProvider(); var messages = new List(); - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(messages); await provider.InvokedAsync(context, CancellationToken.None); Assert.Empty(provider); @@ -520,7 +516,7 @@ public async Task AddMessagesAsync_WithReducer_AfterMessageAdded_InvokesReducerA var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); // Act - var context = new ChatHistoryProvider.InvokedContext(originalMessages, []); + var context = new ChatHistoryProvider.InvokedContext(originalMessages); await provider.InvokedAsync(context, CancellationToken.None); // Assert @@ -579,7 +575,7 @@ public async Task AddMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); // Act - var context = new ChatHistoryProvider.InvokedContext(originalMessages, []); + var context = new ChatHistoryProvider.InvokedContext(originalMessages); await provider.InvokedAsync(context, CancellationToken.None); // Assert diff --git a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs index da65d53c30..4d178280bc 100644 --- a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs @@ -2930,7 +2930,7 @@ private sealed class TestSchema /// private sealed class TestAIContextProvider : AIContextProvider { - public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { return new ValueTask(new AIContext()); } @@ -2941,12 +2941,12 @@ public override ValueTask InvokingAsync(InvokingContext context, Canc /// private sealed class TestChatHistoryProvider : ChatHistoryProvider { - public override ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { return new ValueTask>(Array.Empty()); } - public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { return default; } diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs index ab2f58dfd5..88dcca229f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs @@ -214,7 +214,7 @@ public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); var message = new ChatMessage(ChatRole.User, "Hello, world!"); - var context = new ChatHistoryProvider.InvokedContext([message], []) + var context = new ChatHistoryProvider.InvokedContext([message]) { ResponseMessages = [] }; @@ -282,20 +282,16 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() { new ChatMessage(ChatRole.User, "First message"), new ChatMessage(ChatRole.Assistant, "Second message"), - new ChatMessage(ChatRole.User, "Third message") - }; - var aiContextProviderMessages = new[] - { - new ChatMessage(ChatRole.System, "System context message") + new ChatMessage(ChatRole.User, "Third message"), + new ChatMessage(ChatRole.System, "System context message") { AdditionalProperties = new() { { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.AIContextProvider } } } }; var responseMessages = new[] { new ChatMessage(ChatRole.Assistant, "Response message") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []) + var context = new ChatHistoryProvider.InvokedContext(requestMessages) { - AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages }; @@ -346,8 +342,8 @@ public async Task InvokingAsync_WithConversationIsolation_ShouldOnlyReturnMessag using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation1); using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation2); - var context1 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 1")], []); - var context2 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 2")], []); + var context1 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 1")]); + var context2 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 2")]); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); @@ -391,7 +387,7 @@ public async Task FullWorkflow_AddAndGet_ShouldWorkCorrectlyAsync() }; // Act 1: Add messages - var invokedContext = new ChatHistoryProvider.InvokedContext(messages, []); + var invokedContext = new ChatHistoryProvider.InvokedContext(messages); await originalStore.InvokedAsync(invokedContext); // Act 2: Verify messages were added @@ -545,7 +541,7 @@ public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWith using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); var message = new ChatMessage(ChatRole.User, "Hello from hierarchical partitioning!"); - var context = new ChatHistoryProvider.InvokedContext([message], []); + var context = new ChatHistoryProvider.InvokedContext([message]); // Act await provider.InvokedAsync(context); @@ -602,7 +598,7 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess new ChatMessage(ChatRole.User, "Third hierarchical message") }; - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(messages); // Act await provider.InvokedAsync(context); @@ -637,8 +633,8 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId2, SessionId); // Add messages to both stores - var context1 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 1")], []); - var context2 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 2")], []); + var context1 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 1")]); + var context2 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 2")]); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); @@ -675,7 +671,7 @@ public async Task SerializeDeserialize_WithHierarchicalPartitioning_ShouldPreser using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); - var context = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Test serialization message")], []); + var context = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Test serialization message")]); await originalStore.InvokedAsync(context); // Act - Serialize the provider state @@ -717,8 +713,8 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() using var hierarchicalProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-coexist", "user-coexist", SessionId); // Add messages to both - var simpleContext = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Simple partitioning message")], []); - var hierarchicalContext = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []); + var simpleContext = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Simple partitioning message")]); + var hierarchicalContext = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Hierarchical partitioning message")]); await simpleProvider.InvokedAsync(simpleContext); await hierarchicalProvider.InvokedAsync(hierarchicalContext); @@ -760,7 +756,7 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() await Task.Delay(10); // Small delay to ensure different timestamps } - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(messages); await provider.InvokedAsync(context); // Wait for eventual consistency @@ -798,7 +794,7 @@ public async Task MaxMessagesToRetrieve_Null_ShouldReturnAllMessagesAsync() messages.Add(new ChatMessage(ChatRole.User, $"Message {i}")); } - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(messages); await provider.InvokedAsync(context); // Wait for eventual consistency diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs index bacc59833a..d963585e91 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs @@ -53,7 +53,7 @@ public async Task CanAddAndRetrieveUserMemoriesAsync() Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext([input], aiContextProviderMessages: null)); + await sut.InvokedAsync(new AIContextProvider.InvokedContext([input])); var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); await sut.ClearStoredMemoriesAsync(); var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question])); @@ -77,7 +77,7 @@ public async Task CanAddAndRetrieveAgentMemoriesAsync() Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext([assistantIntro], aiContextProviderMessages: null)); + await sut.InvokedAsync(new AIContextProvider.InvokedContext([assistantIntro])); var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); await sut.ClearStoredMemoriesAsync(); var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question])); @@ -105,7 +105,7 @@ public async Task DoesNotLeakMemoriesAcrossAgentScopesAsync() Assert.DoesNotContain("Caoimhe", ctxBefore2.Messages?[0].Text ?? string.Empty); // Act - await sut1.InvokedAsync(new AIContextProvider.InvokedContext([assistantIntro], aiContextProviderMessages: null)); + await sut1.InvokedAsync(new AIContextProvider.InvokedContext([assistantIntro])); var ctxAfterAdding1 = await GetContextWithRetryAsync(sut1, question); var ctxAfterAdding2 = await GetContextWithRetryAsync(sut2, question); diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs index 832881857d..7bc0ed98a4 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs @@ -215,7 +215,7 @@ public async Task InvokedAsync_PersistsAllowedMessagesAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages) { ResponseMessages = responseMessages }); // Assert var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList(); @@ -242,7 +242,7 @@ public async Task InvokedAsync_PersistsNothingForFailedRequestAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); // Assert Assert.Empty(this._handler.Requests); @@ -268,7 +268,7 @@ public async Task InvokedAsync_ShouldNotThrow_WhenStorageFailsAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages) { ResponseMessages = responseMessages }); // Assert this._loggerMock.Verify( @@ -318,7 +318,7 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages) { ResponseMessages = responseMessages }); // Assert Assert.Equal(expectedLogCount, this._loggerMock.Invocations.Count); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index 6c2be9689a..09d2a419f0 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -9,6 +9,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; +using Moq.Protected; namespace Microsoft.Agents.AI.UnitTests; @@ -342,7 +343,8 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() var mockProvider = new Mock(); mockProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext { Messages = aiContextProviderMessages, @@ -350,7 +352,8 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() Tools = [AIFunctionFactory.Create(() => { }, "context provider function")] }); mockProvider - .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); @@ -378,12 +381,15 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() Assert.Equal("context provider message", chatHistoryProvider[1].Text); Assert.Equal("response", chatHistoryProvider[2].Text); - mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); - mockProvider.Verify(p => p.InvokedAsync(It.Is(x => - x.RequestMessages == requestMessages && - x.AIContextProviderMessages == aiContextProviderMessages && - x.ResponseMessages == responseMessages && - x.InvokeException == null), It.IsAny()), Times.Once); + mockProvider + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + mockProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => + x.RequestMessages.Count() == requestMessages.Length + aiContextProviderMessages.Length && + x.ResponseMessages == responseMessages && + x.InvokeException == null), ItExpr.IsAny()); } /// @@ -394,7 +400,6 @@ public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync() { // Arrange ChatMessage[] requestMessages = [new(ChatRole.User, "user message")]; - ChatMessage[] responseMessages = [new(ChatRole.Assistant, "response")]; ChatMessage[] aiContextProviderMessages = [new(ChatRole.System, "context provider message")]; Mock mockService = new(); mockService @@ -406,13 +411,15 @@ public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync() var mockProvider = new Mock(); mockProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext { Messages = aiContextProviderMessages, }); mockProvider - .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); @@ -421,12 +428,15 @@ public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync() await Assert.ThrowsAsync(() => agent.RunAsync(requestMessages)); // Assert - mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); - mockProvider.Verify(p => p.InvokedAsync(It.Is(x => - x.RequestMessages == requestMessages && - x.AIContextProviderMessages == aiContextProviderMessages && - x.ResponseMessages == null && - x.InvokeException is InvalidOperationException), It.IsAny()), Times.Once); + mockProvider + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + mockProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => + x.RequestMessages.Count() == requestMessages.Length + aiContextProviderMessages.Length && + x.ResponseMessages == null && + x.InvokeException is InvalidOperationException), ItExpr.IsAny()); } /// @@ -458,7 +468,8 @@ public async Task RunAsyncInvokesAIContextProviderAndSucceedsWithEmptyAIContextA var mockProvider = new Mock(); mockProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext()); ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); @@ -474,7 +485,9 @@ public async Task RunAsyncInvokesAIContextProviderAndSucceedsWithEmptyAIContextA Assert.Equal(ChatRole.User, capturedMessages[0].Role); Assert.Single(capturedTools); Assert.Contains(capturedTools, t => t.Name == "base function"); - mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); + mockProvider + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); } #endregion @@ -1371,7 +1384,8 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() var mockProvider = new Mock(); mockProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext { Messages = aiContextProviderMessages, @@ -1379,7 +1393,8 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() Tools = [AIFunctionFactory.Create(() => { }, "context provider function")] }); mockProvider - .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); ChatClientAgent agent = new( @@ -1414,13 +1429,16 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() Assert.Equal("context provider message", chatHistoryProvider[1].Text); Assert.Equal("response", chatHistoryProvider[2].Text); - mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); - mockProvider.Verify(p => p.InvokedAsync(It.Is(x => - x.RequestMessages == requestMessages && - x.AIContextProviderMessages == aiContextProviderMessages && - x.ResponseMessages!.Count() == 1 && - x.ResponseMessages!.ElementAt(0).Text == "response" && - x.InvokeException == null), It.IsAny()), Times.Once); + mockProvider + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + mockProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => + x.RequestMessages.Count() == requestMessages.Length + aiContextProviderMessages.Length && + x.ResponseMessages!.Count() == 1 && + x.ResponseMessages!.ElementAt(0).Text == "response" && + x.InvokeException == null), ItExpr.IsAny()); } /// @@ -1442,13 +1460,15 @@ public async Task RunStreamingAsyncInvokesAIContextProviderWhenGetResponseFailsA var mockProvider = new Mock(); mockProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext { Messages = aiContextProviderMessages, }); mockProvider - .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); ChatClientAgent agent = new( @@ -1467,12 +1487,15 @@ await Assert.ThrowsAsync(async () => }); // Assert - mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); - mockProvider.Verify(p => p.InvokedAsync(It.Is(x => - x.RequestMessages == requestMessages && - x.AIContextProviderMessages == aiContextProviderMessages && - x.ResponseMessages == null && - x.InvokeException is InvalidOperationException), It.IsAny()), Times.Once); + mockProvider + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + mockProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => + x.RequestMessages.Count() == requestMessages.Length + aiContextProviderMessages.Length && + x.ResponseMessages == null && + x.InvokeException is InvalidOperationException), ItExpr.IsAny()); } #endregion diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs index 87be3fb96e..2eed890292 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs @@ -7,6 +7,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; +using Moq.Protected; namespace Microsoft.Agents.AI.UnitTests; @@ -339,13 +340,15 @@ public async Task RunAsync_WhenContinuationTokenProvided_SkipsSessionMessagePopu // Create a mock chat history provider that would normally provide messages var mockChatHistoryProvider = new Mock(); mockChatHistoryProvider - .Setup(ms => ms.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync([new(ChatRole.User, "Message from chat history provider")]); // Create a mock AI context provider that would normally provide context var mockContextProvider = new Mock(); mockContextProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext { Messages = [new(ChatRole.System, "Message from AI context")], @@ -385,14 +388,14 @@ public async Task RunAsync_WhenContinuationTokenProvided_SkipsSessionMessagePopu Assert.Empty(capturedMessages); // Verify that chat history provider was never called due to continuation token - mockChatHistoryProvider.Verify( - ms => ms.InvokingAsync(It.IsAny(), It.IsAny()), - Times.Never); + mockChatHistoryProvider + .Protected() + .Verify>>("InvokingCoreAsync", Times.Never(), ItExpr.IsAny(), ItExpr.IsAny()); // Verify that AI context provider was never called due to continuation token - mockContextProvider.Verify( - p => p.InvokingAsync(It.IsAny(), It.IsAny()), - Times.Never); + mockContextProvider + .Protected() + .Verify>("InvokingCoreAsync", Times.Never(), ItExpr.IsAny(), ItExpr.IsAny()); } [Fact] @@ -404,13 +407,15 @@ public async Task RunStreamingAsync_WhenContinuationTokenProvided_SkipsSessionMe // Create a mock chat history provider that would normally provide messages var mockChatHistoryProvider = new Mock(); mockChatHistoryProvider - .Setup(ms => ms.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync([new(ChatRole.User, "Message from chat history provider")]); // Create a mock AI context provider that would normally provide context var mockContextProvider = new Mock(); mockContextProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext { Messages = [new(ChatRole.System, "Message from AI context")], @@ -449,14 +454,14 @@ public async Task RunStreamingAsync_WhenContinuationTokenProvided_SkipsSessionMe Assert.Empty(capturedMessages); // Verify that chat history provider was never called due to continuation token - mockChatHistoryProvider.Verify( - ms => ms.InvokingAsync(It.IsAny(), It.IsAny()), - Times.Never); + mockChatHistoryProvider + .Protected() + .Verify>>("InvokingCoreAsync", Times.Never(), ItExpr.IsAny(), ItExpr.IsAny()); // Verify that AI context provider was never called due to continuation token - mockContextProvider.Verify( - p => p.InvokingAsync(It.IsAny(), It.IsAny()), - Times.Never); + mockContextProvider + .Protected() + .Verify>("InvokingCoreAsync", Times.Never(), ItExpr.IsAny(), ItExpr.IsAny()); } [Fact] @@ -633,14 +638,16 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesUpdatesFromInitial List capturedMessagesAddedToProvider = []; var mockChatHistoryProvider = new Mock(); mockChatHistoryProvider - .Setup(ms => ms.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((ctx, ct) => capturedMessagesAddedToProvider.AddRange(ctx.ResponseMessages ?? [])) .Returns(new ValueTask()); AIContextProvider.InvokedContext? capturedInvokedContext = null; var mockContextProvider = new Mock(); mockContextProvider - .Setup(cp => cp.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((context, ct) => capturedInvokedContext = context) .Returns(new ValueTask()); @@ -662,11 +669,15 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesUpdatesFromInitial await agent.RunStreamingAsync(session, options: runOptions).ToListAsync(); // Assert - mockChatHistoryProvider.Verify(ms => ms.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + mockChatHistoryProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); Assert.Single(capturedMessagesAddedToProvider); Assert.Contains("once upon a time", capturedMessagesAddedToProvider[0].Text); - mockContextProvider.Verify(cp => cp.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + mockContextProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); Assert.NotNull(capturedInvokedContext?.ResponseMessages); Assert.Single(capturedInvokedContext.ResponseMessages); Assert.Contains("once upon a time", capturedInvokedContext.ResponseMessages.ElementAt(0).Text); @@ -689,14 +700,16 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesInputMessagesFromI List capturedMessagesAddedToProvider = []; var mockChatHistoryProvider = new Mock(); mockChatHistoryProvider - .Setup(ms => ms.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((ctx, ct) => capturedMessagesAddedToProvider.AddRange(ctx.RequestMessages)) .Returns(new ValueTask()); AIContextProvider.InvokedContext? capturedInvokedContext = null; var mockContextProvider = new Mock(); mockContextProvider - .Setup(cp => cp.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((context, ct) => capturedInvokedContext = context) .Returns(new ValueTask()); @@ -718,11 +731,15 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesInputMessagesFromI await agent.RunStreamingAsync(session, options: runOptions).ToListAsync(); // Assert - mockChatHistoryProvider.Verify(ms => ms.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + mockChatHistoryProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); Assert.Single(capturedMessagesAddedToProvider); Assert.Contains("Tell me a story", capturedMessagesAddedToProvider[0].Text); - mockContextProvider.Verify(cp => cp.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + mockContextProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); Assert.NotNull(capturedInvokedContext?.RequestMessages); Assert.Single(capturedInvokedContext.RequestMessages); Assert.Contains("Tell me a story", capturedInvokedContext.RequestMessages.ElementAt(0).Text); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs index e2b7313e7f..cd07e64b87 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs @@ -7,6 +7,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; +using Moq.Protected; using Xunit.Sdk; namespace Microsoft.Agents.AI.UnitTests; @@ -183,12 +184,14 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); Mock mockChatHistoryProvider = new(); - mockChatHistoryProvider.Setup(s => s.InvokingAsync( - It.IsAny(), - It.IsAny())).ReturnsAsync([new ChatMessage(ChatRole.User, "Existing Chat History")]); - mockChatHistoryProvider.Setup(s => s.InvokedAsync( - It.IsAny(), - It.IsAny())).Returns(new ValueTask()); + mockChatHistoryProvider + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .ReturnsAsync([new ChatMessage(ChatRole.User, "Existing Chat History")]); + mockChatHistoryProvider + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(new ValueTask()); Mock>> mockFactory = new(); mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockChatHistoryProvider.Object); @@ -211,14 +214,16 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve It.IsAny(), It.IsAny()), Times.Once); - mockChatHistoryProvider.Verify(s => s.InvokingAsync( - It.Is(x => x.RequestMessages.Count() == 1), - It.IsAny()), - Times.Once); - mockChatHistoryProvider.Verify(s => s.InvokedAsync( - It.Is(x => x.RequestMessages.Count() == 1 && x.ChatHistoryProviderMessages != null && x.ChatHistoryProviderMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), - It.IsAny()), - Times.Once); + mockChatHistoryProvider + .Protected() + .Verify>>("InvokingCoreAsync", Times.Once(), + ItExpr.Is(x => x.RequestMessages.Count() == 1), + ItExpr.IsAny()); + mockChatHistoryProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), + ItExpr.Is(x => x.RequestMessages.Count() == 2 && x.ResponseMessages!.Count() == 1), + ItExpr.IsAny()); mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); } @@ -253,10 +258,11 @@ public async Task RunAsync_NotifiesChatHistoryProvider_OnFailureAsync() // Assert Assert.IsType(session!.ChatHistoryProvider, exactMatch: false); - mockChatHistoryProvider.Verify(s => s.InvokedAsync( - It.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages == null && x.InvokeException!.Message == "Test Error"), - It.IsAny()), - Times.Once); + mockChatHistoryProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), + ItExpr.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages == null && x.InvokeException!.Message == "Test Error"), + ItExpr.IsAny()); mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); } @@ -308,22 +314,26 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi // Arrange a chat history provider to override the factory provided one. Mock mockOverrideChatHistoryProvider = new(); - mockOverrideChatHistoryProvider.Setup(s => s.InvokingAsync( - It.IsAny(), - It.IsAny())).ReturnsAsync([new ChatMessage(ChatRole.User, "Existing Chat History")]); - mockOverrideChatHistoryProvider.Setup(s => s.InvokedAsync( - It.IsAny(), - It.IsAny())).Returns(new ValueTask()); + mockOverrideChatHistoryProvider + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .ReturnsAsync([new ChatMessage(ChatRole.User, "Existing Chat History")]); + mockOverrideChatHistoryProvider + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(new ValueTask()); // Arrange a chat history provider to provide to the agent via a factory at construction time. // This one shouldn't be used since it is being overridden. Mock mockFactoryChatHistoryProvider = new(); - mockFactoryChatHistoryProvider.Setup(s => s.InvokingAsync( - It.IsAny(), - It.IsAny())).ThrowsAsync(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); - mockFactoryChatHistoryProvider.Setup(s => s.InvokedAsync( - It.IsAny(), - It.IsAny())).Throws(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); + mockFactoryChatHistoryProvider + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .ThrowsAsync(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); + mockFactoryChatHistoryProvider + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Throws(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); Mock>> mockFactory = new(); mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockFactoryChatHistoryProvider.Object); @@ -348,23 +358,27 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi It.IsAny(), It.IsAny()), Times.Once); - mockOverrideChatHistoryProvider.Verify(s => s.InvokingAsync( - It.Is(x => x.RequestMessages.Count() == 1), - It.IsAny()), - Times.Once); - mockOverrideChatHistoryProvider.Verify(s => s.InvokedAsync( - It.Is(x => x.RequestMessages.Count() == 1 && x.ChatHistoryProviderMessages != null && x.ChatHistoryProviderMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), - It.IsAny()), - Times.Once); - - mockFactoryChatHistoryProvider.Verify(s => s.InvokingAsync( - It.IsAny(), - It.IsAny()), - Times.Never); - mockFactoryChatHistoryProvider.Verify(s => s.InvokedAsync( - It.IsAny(), - It.IsAny()), - Times.Never); + mockOverrideChatHistoryProvider + .Protected() + .Verify>>("InvokingCoreAsync", Times.Once(), + ItExpr.Is(x => x.RequestMessages.Count() == 1), + ItExpr.IsAny()); + mockOverrideChatHistoryProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), + ItExpr.Is(x => x.RequestMessages.Count() == 2 && x.ResponseMessages!.Count() == 1), + ItExpr.IsAny()); + + mockFactoryChatHistoryProvider + .Protected() + .Verify>>("InvokingCoreAsync", Times.Never(), + ItExpr.IsAny(), + ItExpr.IsAny()); + mockFactoryChatHistoryProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Never(), + ItExpr.IsAny(), + ItExpr.IsAny()); } #endregion diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs index 3698ee7065..9863479101 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs @@ -340,7 +340,7 @@ public async Task InvokingAsync_WithPreviousFailedRequest_ShouldNotIncludeFailed new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(initialMessages, aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Request Failed") }); + await provider.InvokedAsync(new(initialMessages) { InvokeException = new InvalidOperationException("Request Failed") }); var invokingContext = new AIContextProvider.InvokingContext( [ @@ -380,7 +380,7 @@ public async Task InvokingAsync_WithRecentMessageMemory_ShouldIncludeStoredMessa new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(initialMessages, aiContextProviderMessages: null)); + await provider.InvokedAsync(new(initialMessages)); var invokingContext = new AIContextProvider.InvokingContext( [ @@ -417,7 +417,7 @@ await provider.InvokedAsync(new( [ new ChatMessage(ChatRole.User, "A"), new ChatMessage(ChatRole.Assistant, "B"), - ], aiContextProviderMessages: null)); + ])); // Second memory update (C,D,E) await provider.InvokedAsync(new( @@ -425,7 +425,7 @@ await provider.InvokedAsync(new( new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), new ChatMessage(ChatRole.User, "E"), - ], aiContextProviderMessages: null)); + ])); var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "F")]); @@ -462,7 +462,7 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles new ChatMessage(ChatRole.User, "U2"), new ChatMessage(ChatRole.Assistant, "A2"), }; - await provider.InvokedAsync(new(initialMessages, null)); + await provider.InvokedAsync(new(initialMessages)); var invokingContext = new AIContextProvider.InvokingContext( [ @@ -518,7 +518,7 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy }; // Act - await provider.InvokedAsync(new(messages, aiContextProviderMessages: null)); // Populate recent memory. + await provider.InvokedAsync(new(messages)); // Populate recent memory. var state = provider.Serialize(); // Assert @@ -547,7 +547,7 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(messages, aiContextProviderMessages: null)); + await provider.InvokedAsync(new(messages)); // Act var state = provider.Serialize(); @@ -588,7 +588,7 @@ public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsyn new ChatMessage(ChatRole.Assistant, "L4"), new ChatMessage(ChatRole.User, "L5"), }; - await initialProvider.InvokedAsync(new(messages, aiContextProviderMessages: null)); + await initialProvider.InvokedAsync(new(messages)); var state = initialProvider.Serialize(); string? capturedInput = null; diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs index f46538c8e4..c6dcd3ea51 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs @@ -116,7 +116,7 @@ public async Task InvokedAsync_UpsertsMessages_ToCollectionAsync() var requestMsgWithNulls = new ChatMessage(ChatRole.User, "request text nulls"); var responseMsg = new ChatMessage(ChatRole.Assistant, "response text") { MessageId = "resp-1", AuthorName = "assistant" }; - var invokedContext = new AIContextProvider.InvokedContext([requestMsgWithValues, requestMsgWithNulls], aiContextProviderMessages: null) + var invokedContext = new AIContextProvider.InvokedContext([requestMsgWithValues, requestMsgWithNulls]) { ResponseMessages = [responseMsg] }; @@ -174,7 +174,7 @@ public async Task InvokedAsync_DoesNotUpsertMessages_WhenInvokeFailedAsync() 1, new ChatHistoryMemoryProviderScope() { UserId = "UID" }); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null) + var invokedContext = new AIContextProvider.InvokedContext([requestMsg]) { InvokeException = new InvalidOperationException("Invoke failed") }; @@ -203,7 +203,7 @@ public async Task InvokedAsync_DoesNotThrow_WhenUpsertThrowsAsync() new ChatHistoryMemoryProviderScope() { UserId = "UID" }, loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null); + var invokedContext = new AIContextProvider.InvokedContext([requestMsg]); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -254,7 +254,7 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text"); - var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null); + var invokedContext = new AIContextProvider.InvokedContext([requestMsg]); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); From 8842acfd3fd3f9a6ba6c2ac06db80268a52e915a Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 4 Feb 2026 14:16:42 +0000 Subject: [PATCH 2/5] Add support for source, in addition to source type, and add unit tests for automatic stamping --- .../AIContextProvider.cs | 31 ++- .../AgentRequestMessageSource.cs | 94 +--------- .../AgentRequestMessageSourceType.cs | 106 +++++++++++ .../ChatHistoryProvider.cs | 31 ++- .../ChatHistoryProviderExtensions.cs | 4 +- .../ChatMessageExtensions.cs | 10 +- .../Microsoft.Agents.AI.Mem0/Mem0Provider.cs | 4 +- .../Memory/ChatHistoryMemoryProvider.cs | 4 +- .../Microsoft.Agents.AI/TextSearchProvider.cs | 4 +- .../AIContextProviderTests.cs | 176 +++++++++++++++++- .../AgentRequestMessageSourceTests.cs | 126 ++++++------- .../ChatHistoryProviderExtensionsTests.cs | 6 +- .../ChatHistoryProviderMessageFilterTests.cs | 4 +- .../ChatHistoryProviderTests.cs | 141 +++++++++++++- .../ChatMessageExtensionsTests.cs | 54 +++--- .../InMemoryChatHistoryProviderTests.cs | 2 +- .../CosmosChatHistoryProviderTests.cs | 2 +- 17 files changed, 586 insertions(+), 213 deletions(-) create mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index 53876358af..377fde1b8e 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -32,6 +32,25 @@ namespace Microsoft.Agents.AI; /// public abstract class AIContextProvider { + private readonly string _sourceName; + + /// + /// Initializes a new instance of the class. + /// + protected AIContextProvider() + { + this._sourceName = this.GetType().FullName!; + } + + /// + /// Initializes a new instance of the class with the specified source name. + /// + /// The source name to stamp on for each messages produced by the . + protected AIContextProvider(string sourceName) + { + this._sourceName = sourceName; + } + /// /// Called at the start of agent invocation to provide additional context. /// @@ -60,16 +79,20 @@ public async ValueTask InvokingAsync(InvokingContext context, Cancell aiContext.Messages = aiContext.Messages.Select(message => { if (message.AdditionalProperties != null - && message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out var source) - && source is AgentRequestMessageSource typedSource - && typedSource == AgentRequestMessageSource.AIContextProvider) + && message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out var messageSourceType) + && messageSourceType is AgentRequestMessageSourceType typedMessageSourceType + && typedMessageSourceType == AgentRequestMessageSourceType.AIContextProvider + && message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out var messageSource) + && messageSource is string typedMessageSource + && typedMessageSource == this._sourceName) { return message; } message = message.Clone(); message.AdditionalProperties ??= new(); - message.AdditionalProperties[AgentRequestMessageSource.AdditionalPropertiesKey] = AgentRequestMessageSource.AIContextProvider; + message.AdditionalProperties[AgentRequestMessageSourceType.AdditionalPropertiesKey] = AgentRequestMessageSourceType.AIContextProvider; + message.AdditionalProperties[AgentRequestMessageSource.AdditionalPropertiesKey] = this._sourceName; return message; }).ToList(); diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSource.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSource.cs index 488a55d405..127f1c1b8d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSource.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSource.cs @@ -1,106 +1,16 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using Microsoft.Extensions.AI; -using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI; /// -/// An enumeration representing the source of an agent request message. +/// Provides a constant for the key used to store the source of the agent request message. /// -/// -/// Input messages for a specific agent run can originate from various sources. -/// This enumeration helps to identify whether a message came from outside the agent pipeline, -/// whether it was produced by middleware, or came from chat history. -/// -public sealed class AgentRequestMessageSource : IEquatable +public static class AgentRequestMessageSource { /// /// Provides the key used in to store the source of the agent request message. /// public static readonly string AdditionalPropertiesKey = "Agent.RequestMessageSource"; - - /// - /// Initializes a new instance of the class. - /// - /// The string value representing the source of the agent request message. - public AgentRequestMessageSource(string value) => this.Value = Throw.IfNullOrWhitespace(value); - - /// - /// Get the string value representing the source of the agent request message. - /// - public string Value { get; } - - /// - /// The message came from outside the agent pipeline (e.g., user input). - /// - public static AgentRequestMessageSource External { get; } = new AgentRequestMessageSource(nameof(External)); - - /// - /// The message was produced by middleware. - /// - public static AgentRequestMessageSource AIContextProvider { get; } = new AgentRequestMessageSource(nameof(AIContextProvider)); - - /// - /// The message came from chat history. - /// - public static AgentRequestMessageSource ChatHistory { get; } = new AgentRequestMessageSource(nameof(ChatHistory)); - - /// - /// Determines whether this instance and another specified object have the same value. - /// - /// The to compare to this instance. - /// if the value of the parameter is the same as the value of this instance; otherwise, . - public bool Equals(AgentRequestMessageSource? other) - { - if (other is null) - { - return false; - } - - if (ReferenceEquals(this, other)) - { - return true; - } - - return string.Equals(this.Value, other.Value, StringComparison.Ordinal); - } - - /// - /// Determines whether this instance and a specified object have the same value. - /// - /// The object to compare to this instance. - /// if is a and its value is the same as this instance; otherwise, . - public override bool Equals(object? obj) => this.Equals(obj as AgentRequestMessageSource); - - /// - /// Returns the hash code for this instance. - /// - /// A 32-bit signed integer hash code. - public override int GetHashCode() => this.Value?.GetHashCode() ?? 0; - - /// - /// Determines whether two specified objects have the same value. - /// - /// The first to compare. - /// The second to compare. - /// if the value of is the same as the value of ; otherwise, . - public static bool operator ==(AgentRequestMessageSource? left, AgentRequestMessageSource? right) - { - if (left is null) - { - return right is null; - } - - return left.Equals(right); - } - - /// - /// Determines whether two specified objects have different values. - /// - /// The first to compare. - /// The second to compare. - /// if the value of is different from the value of ; otherwise, . - public static bool operator !=(AgentRequestMessageSource? left, AgentRequestMessageSource? right) => !(left == right); } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs new file mode 100644 index 0000000000..de53c39419 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.Extensions.AI; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +/// +/// An enumeration representing the source types of an agent request message. +/// +/// +/// Input messages for a specific agent run can originate from various sources. +/// This enumeration helps to identify whether a message came from outside the agent pipeline, +/// whether it was produced by middleware, or came from chat history. +/// +public sealed class AgentRequestMessageSourceType : IEquatable +{ + /// + /// Provides the key used in to store the source type of the agent request message. + /// + public static readonly string AdditionalPropertiesKey = "Agent.RequestMessageSourceType"; + + /// + /// Initializes a new instance of the class. + /// + /// The string value representing the source of the agent request message. + public AgentRequestMessageSourceType(string value) => this.Value = Throw.IfNullOrWhitespace(value); + + /// + /// Get the string value representing the source of the agent request message. + /// + public string Value { get; } + + /// + /// The message came from outside the agent pipeline (e.g., user input). + /// + public static AgentRequestMessageSourceType External { get; } = new AgentRequestMessageSourceType(nameof(External)); + + /// + /// The message was produced by middleware. + /// + public static AgentRequestMessageSourceType AIContextProvider { get; } = new AgentRequestMessageSourceType(nameof(AIContextProvider)); + + /// + /// The message came from chat history. + /// + public static AgentRequestMessageSourceType ChatHistory { get; } = new AgentRequestMessageSourceType(nameof(ChatHistory)); + + /// + /// Determines whether this instance and another specified object have the same value. + /// + /// The to compare to this instance. + /// if the value of the parameter is the same as the value of this instance; otherwise, . + public bool Equals(AgentRequestMessageSourceType? other) + { + if (other is null) + { + return false; + } + + if (ReferenceEquals(this, other)) + { + return true; + } + + return string.Equals(this.Value, other.Value, StringComparison.Ordinal); + } + + /// + /// Determines whether this instance and a specified object have the same value. + /// + /// The object to compare to this instance. + /// if is a and its value is the same as this instance; otherwise, . + public override bool Equals(object? obj) => this.Equals(obj as AgentRequestMessageSourceType); + + /// + /// Returns the hash code for this instance. + /// + /// A 32-bit signed integer hash code. + public override int GetHashCode() => this.Value?.GetHashCode() ?? 0; + + /// + /// Determines whether two specified objects have the same value. + /// + /// The first to compare. + /// The second to compare. + /// if the value of is the same as the value of ; otherwise, . + public static bool operator ==(AgentRequestMessageSourceType? left, AgentRequestMessageSourceType? right) + { + if (left is null) + { + return right is null; + } + + return left.Equals(right); + } + + /// + /// Determines whether two specified objects have different values. + /// + /// The first to compare. + /// The second to compare. + /// if the value of is different from the value of ; otherwise, . + public static bool operator !=(AgentRequestMessageSourceType? left, AgentRequestMessageSourceType? right) => !(left == right); +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs index 2ef630cf1a..762630e1dd 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs @@ -37,6 +37,25 @@ namespace Microsoft.Agents.AI; /// public abstract class ChatHistoryProvider { + private readonly string _sourceName; + + /// + /// Initializes a new instance of the class. + /// + protected ChatHistoryProvider() + { + this._sourceName = this.GetType().FullName!; + } + + /// + /// Initializes a new instance of the class with the specified source name. + /// + /// The source name to stamp on for each messages produced by the . + protected ChatHistoryProvider(string sourceName) + { + this._sourceName = sourceName; + } + /// /// Called at the start of agent invocation to provide messages from the chat history as context for the next agent invocation. /// @@ -73,16 +92,20 @@ public async ValueTask> InvokingAsync(InvokingContext c return messages.Select(message => { if (message.AdditionalProperties != null - && message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out var source) - && source is AgentRequestMessageSource typedSource - && typedSource == AgentRequestMessageSource.ChatHistory) + && message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out var messageSourceType) + && messageSourceType is AgentRequestMessageSourceType typedMessageSourceType + && typedMessageSourceType == AgentRequestMessageSourceType.ChatHistory + && message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out var messageSource) + && messageSource is string typedMessageSource + && typedMessageSource == this._sourceName) { return message; } message = message.Clone(); message.AdditionalProperties ??= new(); - message.AdditionalProperties[AgentRequestMessageSource.AdditionalPropertiesKey] = AgentRequestMessageSource.ChatHistory; + message.AdditionalProperties[AgentRequestMessageSourceType.AdditionalPropertiesKey] = AgentRequestMessageSourceType.ChatHistory; + message.AdditionalProperties[AgentRequestMessageSource.AdditionalPropertiesKey] = this._sourceName; return message; }); } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs index 4cd8d570db..c2ff8bf3e5 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs @@ -35,7 +35,7 @@ public static ChatHistoryProvider WithMessageFilters( /// /// Decorates the provided so that it does not add - /// messages with to chat history. + /// messages with to chat history. /// /// The to add the message filter to. /// A new instance that filters out messages so they do not get added. @@ -45,7 +45,7 @@ public static ChatHistoryProvider WithAIContextProviderMessageRemoval(this ChatH innerProvider: provider, invokedMessagesFilter: (ctx) => { - ctx.RequestMessages = ctx.RequestMessages.Where(x => x.GetAgentRequestMessageSource() != AgentRequestMessageSource.AIContextProvider); + ctx.RequestMessages = ctx.RequestMessages.Where(x => x.GetAgentRequestMessageSource() != AgentRequestMessageSourceType.AIContextProvider); return ctx; }); } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs index caf472faa9..7e7dfdf882 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs @@ -13,15 +13,15 @@ public static class ChatMessageExtensions /// Gets the source of the provided in the context of messages passed into an agent run. /// /// The for which we need the source. - /// An value indicating the source of the . Defaults to if no explicit source is defined. - public static AgentRequestMessageSource GetAgentRequestMessageSource(this ChatMessage message) + /// An value indicating the source of the . Defaults to if no explicit source is defined. + public static AgentRequestMessageSourceType GetAgentRequestMessageSource(this ChatMessage message) { - if (message.AdditionalProperties?.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out var source) is true && source is AgentRequestMessageSource typedSource) + if (message.AdditionalProperties?.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out var source) is true && source is AgentRequestMessageSourceType typedSource) { return typedSource; } - return AgentRequestMessageSource.External; + return AgentRequestMessageSourceType.External; } } diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs index 6a4113bd42..8a0c016f07 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs @@ -138,7 +138,7 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext string queryText = string.Join( Environment.NewLine, context.RequestMessages - .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSource.External) + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) .Where(m => !string.IsNullOrWhiteSpace(m.Text)) .Select(m => m.Text)); @@ -217,7 +217,7 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc // Persist request and response messages after invocation. await this.PersistMessagesAsync( context.RequestMessages - .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSource.External) + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) .Concat(context.ResponseMessages ?? []), cancellationToken).ConfigureAwait(false); } diff --git a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs index 7476aac267..c63e8ac682 100644 --- a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs @@ -189,7 +189,7 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext { // Get the text from the current request messages var requestText = string.Join("\n", context.RequestMessages - .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSource.External) + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) .Where(m => m != null && !string.IsNullOrWhiteSpace(m.Text)) .Select(m => m.Text)); @@ -245,7 +245,7 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc var collection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); List> itemsToStore = context.RequestMessages - .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSource.External) + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) .Concat(context.ResponseMessages ?? []) .Select(message => new Dictionary { diff --git a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs index cff79c9e7a..ee87d4f00c 100644 --- a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs @@ -118,7 +118,7 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext // Aggregate text from memory + current request messages. var sbInput = new StringBuilder(); var requestMessagesText = context.RequestMessages - .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSource.External) + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) .Where(x => !string.IsNullOrWhiteSpace(x?.Text)).Select(x => x.Text); foreach (var messageText in this._recentMessagesText.Concat(requestMessagesText)) { @@ -182,7 +182,7 @@ protected override ValueTask InvokedCoreAsync(InvokedContext context, Cancellati } var messagesText = context.RequestMessages - .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSource.External) + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) .Concat(context.ResponseMessages ?? []) .Where(m => this._recentMessageRolesIncluded.Contains(m.Role) && diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs index 7ce9f06bbb..fe712168b1 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Collections.ObjectModel; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -10,35 +12,155 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; public class AIContextProviderTests { + #region InvokingAsync Message Stamping Tests + + [Fact] + public async Task InvokingAsync_StampsMessagesWithSourceTypeAndSourceAsync() + { + // Arrange + var provider = new TestAIContextProviderWithMessages(); + var context = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + + // Act + AIContext aiContext = await provider.InvokingAsync(context); + + // Assert + Assert.NotNull(aiContext.Messages); + ChatMessage message = aiContext.Messages.Single(); + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(typeof(TestAIContextProviderWithMessages).FullName, source); + } + + [Fact] + public async Task InvokingAsync_WithCustomSourceName_StampsMessagesWithCustomSourceAsync() + { + // Arrange + const string CustomSourceName = "CustomContextSource"; + var provider = new TestAIContextProviderWithCustomSource(CustomSourceName); + var context = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + + // Act + AIContext aiContext = await provider.InvokingAsync(context); + + // Assert + Assert.NotNull(aiContext.Messages); + ChatMessage message = aiContext.Messages.Single(); + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(CustomSourceName, source); + } + + [Fact] + public async Task InvokingAsync_DoesNotReStampAlreadyStampedMessagesAsync() + { + // Arrange + var provider = new TestAIContextProviderWithPreStampedMessages(); + var context = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + + // Act + AIContext aiContext = await provider.InvokingAsync(context); + + // Assert + Assert.NotNull(aiContext.Messages); + ChatMessage message = aiContext.Messages.Single(); + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(typeof(TestAIContextProviderWithPreStampedMessages).FullName, source); + } + + [Fact] + public async Task InvokingAsync_StampsMultipleMessagesAsync() + { + // Arrange + var provider = new TestAIContextProviderWithMultipleMessages(); + var context = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + + // Act + AIContext aiContext = await provider.InvokingAsync(context); + + // Assert + Assert.NotNull(aiContext.Messages); + List messageList = aiContext.Messages.ToList(); + Assert.Equal(3, messageList.Count); + + foreach (ChatMessage message in messageList) + { + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(typeof(TestAIContextProviderWithMultipleMessages).FullName, source); + } + } + + [Fact] + public async Task InvokingAsync_WithNullMessages_ReturnsContextWithoutStampingAsync() + { + // Arrange + var provider = new TestAIContextProvider(); + var context = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + + // Act + AIContext aiContext = await provider.InvokingAsync(context); + + // Assert + Assert.Null(aiContext.Messages); + } + + #endregion + + #region Basic Tests + [Fact] public async Task InvokedAsync_ReturnsCompletedTaskAsync() { + // Arrange var provider = new TestAIContextProvider(); var messages = new ReadOnlyCollection([]); - var task = provider.InvokedAsync(new(messages)); + + // Act + ValueTask task = provider.InvokedAsync(new(messages)); + + // Assert Assert.Equal(default, task); } [Fact] public void Serialize_ReturnsEmptyElement() { + // Arrange var provider = new TestAIContextProvider(); + + // Act var actual = provider.Serialize(); + + // Assert Assert.Equal(default, actual); } [Fact] public void InvokingContext_Constructor_ThrowsForNullMessages() { + // Act & Assert Assert.Throws(() => new AIContextProvider.InvokingContext(null!)); } [Fact] public void InvokedContext_Constructor_ThrowsForNullMessages() { + // Act & Assert Assert.Throws(() => new AIContextProvider.InvokedContext(null!)); } + #endregion + #region GetService Method Tests /// @@ -158,8 +280,58 @@ public void GetService_Generic_ReturnsNullForUnrelatedType() private sealed class TestAIContextProvider : AIContextProvider { protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new(new AIContext()); + } + + private sealed class TestAIContextProviderWithMessages : AIContextProvider + { + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new(new AIContext + { + Messages = [new ChatMessage(ChatRole.System, "Context Message")] + }); + } + + private sealed class TestAIContextProviderWithCustomSource : AIContextProvider + { + public TestAIContextProviderWithCustomSource(string sourceName) : base(sourceName) { - return default; } + + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new(new AIContext + { + Messages = [new ChatMessage(ChatRole.System, "Context Message")] + }); + } + + private sealed class TestAIContextProviderWithPreStampedMessages : AIContextProvider + { + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + var message = new ChatMessage(ChatRole.System, "Pre-stamped Message"); + message.AdditionalProperties = new AdditionalPropertiesDictionary + { + [AgentRequestMessageSourceType.AdditionalPropertiesKey] = AgentRequestMessageSourceType.AIContextProvider, + [AgentRequestMessageSource.AdditionalPropertiesKey] = this.GetType().FullName! + }; + return new(new AIContext + { + Messages = [message] + }); + } + } + + private sealed class TestAIContextProviderWithMultipleMessages : AIContextProvider + { + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new(new AIContext + { + Messages = [ + new ChatMessage(ChatRole.System, "Message 1"), + new ChatMessage(ChatRole.User, "Message 2"), + new ChatMessage(ChatRole.Assistant, "Message 3") + ] + }); } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTests.cs index faf2ab6fc0..1e22e75fcd 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTests.cs @@ -5,7 +5,7 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// -/// Contains tests for the class. +/// Contains tests for the class. /// public sealed class AgentRequestMessageSourceTests { @@ -18,7 +18,7 @@ public void Constructor_WithValue_SetsValueProperty() const string ExpectedValue = "CustomSource"; // Act - AgentRequestMessageSource source = new(ExpectedValue); + AgentRequestMessageSourceType source = new(ExpectedValue); // Assert Assert.Equal(ExpectedValue, source.Value); @@ -28,14 +28,14 @@ public void Constructor_WithValue_SetsValueProperty() public void Constructor_WithNullValue_Throws() { // Act & Assert - Assert.Throws(() => new AgentRequestMessageSource(null!)); + Assert.Throws(() => new AgentRequestMessageSourceType(null!)); } [Fact] public void Constructor_WithEmptyValue_Throws() { // Act & Assert - Assert.Throws(() => new AgentRequestMessageSource(string.Empty)); + Assert.Throws(() => new AgentRequestMessageSourceType(string.Empty)); } #endregion @@ -46,7 +46,7 @@ public void Constructor_WithEmptyValue_Throws() public void External_ReturnsInstanceWithExternalValue() { // Arrange & Act - AgentRequestMessageSource source = AgentRequestMessageSource.External; + AgentRequestMessageSourceType source = AgentRequestMessageSourceType.External; // Assert Assert.NotNull(source); @@ -57,7 +57,7 @@ public void External_ReturnsInstanceWithExternalValue() public void AIContextProvider_ReturnsInstanceWithAIContextProviderValue() { // Arrange & Act - AgentRequestMessageSource source = AgentRequestMessageSource.AIContextProvider; + AgentRequestMessageSourceType source = AgentRequestMessageSourceType.AIContextProvider; // Assert Assert.NotNull(source); @@ -68,7 +68,7 @@ public void AIContextProvider_ReturnsInstanceWithAIContextProviderValue() public void ChatHistory_ReturnsInstanceWithChatHistoryValue() { // Arrange & Act - AgentRequestMessageSource source = AgentRequestMessageSource.ChatHistory; + AgentRequestMessageSourceType source = AgentRequestMessageSourceType.ChatHistory; // Assert Assert.NotNull(source); @@ -79,22 +79,22 @@ public void ChatHistory_ReturnsInstanceWithChatHistoryValue() public void AdditionalPropertiesKey_ReturnsExpectedValue() { // Arrange & Act - string key = AgentRequestMessageSource.AdditionalPropertiesKey; + string key = AgentRequestMessageSourceType.AdditionalPropertiesKey; // Assert - Assert.Equal("Agent.RequestMessageSource", key); + Assert.Equal("Agent.RequestMessageSourceType", key); } [Fact] public void StaticProperties_ReturnSameInstanceOnMultipleCalls() { // Arrange & Act - AgentRequestMessageSource external1 = AgentRequestMessageSource.External; - AgentRequestMessageSource external2 = AgentRequestMessageSource.External; - AgentRequestMessageSource aiContextProvider1 = AgentRequestMessageSource.AIContextProvider; - AgentRequestMessageSource aiContextProvider2 = AgentRequestMessageSource.AIContextProvider; - AgentRequestMessageSource chatHistory1 = AgentRequestMessageSource.ChatHistory; - AgentRequestMessageSource chatHistory2 = AgentRequestMessageSource.ChatHistory; + AgentRequestMessageSourceType external1 = AgentRequestMessageSourceType.External; + AgentRequestMessageSourceType external2 = AgentRequestMessageSourceType.External; + AgentRequestMessageSourceType aiContextProvider1 = AgentRequestMessageSourceType.AIContextProvider; + AgentRequestMessageSourceType aiContextProvider2 = AgentRequestMessageSourceType.AIContextProvider; + AgentRequestMessageSourceType chatHistory1 = AgentRequestMessageSourceType.ChatHistory; + AgentRequestMessageSourceType chatHistory2 = AgentRequestMessageSourceType.ChatHistory; // Assert Assert.Same(external1, external2); @@ -110,7 +110,7 @@ public void StaticProperties_ReturnSameInstanceOnMultipleCalls() public void Equals_WithSameInstance_ReturnsTrue() { // Arrange - AgentRequestMessageSource source = new("Test"); + AgentRequestMessageSourceType source = new("Test"); // Act bool result = source.Equals(source); @@ -123,8 +123,8 @@ public void Equals_WithSameInstance_ReturnsTrue() public void Equals_WithEqualValue_ReturnsTrue() { // Arrange - AgentRequestMessageSource source1 = new("Test"); - AgentRequestMessageSource source2 = new("Test"); + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType source2 = new("Test"); // Act bool result = source1.Equals(source2); @@ -137,8 +137,8 @@ public void Equals_WithEqualValue_ReturnsTrue() public void Equals_WithDifferentValue_ReturnsFalse() { // Arrange - AgentRequestMessageSource source1 = new("Test1"); - AgentRequestMessageSource source2 = new("Test2"); + AgentRequestMessageSourceType source1 = new("Test1"); + AgentRequestMessageSourceType source2 = new("Test2"); // Act bool result = source1.Equals(source2); @@ -151,7 +151,7 @@ public void Equals_WithDifferentValue_ReturnsFalse() public void Equals_WithNull_ReturnsFalse() { // Arrange - AgentRequestMessageSource source = new("Test"); + AgentRequestMessageSourceType source = new("Test"); // Act bool result = source.Equals(null); @@ -164,8 +164,8 @@ public void Equals_WithNull_ReturnsFalse() public void Equals_WithDifferentCase_ReturnsFalse() { // Arrange - AgentRequestMessageSource source1 = new("Test"); - AgentRequestMessageSource source2 = new("test"); + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType source2 = new("test"); // Act bool result = source1.Equals(source2); @@ -178,8 +178,8 @@ public void Equals_WithDifferentCase_ReturnsFalse() public void Equals_StaticExternalWithNewInstanceHavingSameValue_ReturnsTrue() { // Arrange - AgentRequestMessageSource external = AgentRequestMessageSource.External; - AgentRequestMessageSource newExternal = new("External"); + AgentRequestMessageSourceType external = AgentRequestMessageSourceType.External; + AgentRequestMessageSourceType newExternal = new("External"); // Act bool result = external.Equals(newExternal); @@ -196,8 +196,8 @@ public void Equals_StaticExternalWithNewInstanceHavingSameValue_ReturnsTrue() public void ObjectEquals_WithEqualAgentRequestMessageSource_ReturnsTrue() { // Arrange - AgentRequestMessageSource source1 = new("Test"); - object source2 = new AgentRequestMessageSource("Test"); + AgentRequestMessageSourceType source1 = new("Test"); + object source2 = new AgentRequestMessageSourceType("Test"); // Act bool result = source1.Equals(source2); @@ -210,7 +210,7 @@ public void ObjectEquals_WithEqualAgentRequestMessageSource_ReturnsTrue() public void ObjectEquals_WithDifferentType_ReturnsFalse() { // Arrange - AgentRequestMessageSource source = new("Test"); + AgentRequestMessageSourceType source = new("Test"); object other = "Test"; // Act @@ -224,7 +224,7 @@ public void ObjectEquals_WithDifferentType_ReturnsFalse() public void ObjectEquals_WithNullObject_ReturnsFalse() { // Arrange - AgentRequestMessageSource source = new("Test"); + AgentRequestMessageSourceType source = new("Test"); object? other = null; // Act @@ -242,8 +242,8 @@ public void ObjectEquals_WithNullObject_ReturnsFalse() public void GetHashCode_WithSameValue_ReturnsSameHashCode() { // Arrange - AgentRequestMessageSource source1 = new("Test"); - AgentRequestMessageSource source2 = new("Test"); + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType source2 = new("Test"); // Act int hashCode1 = source1.GetHashCode(); @@ -257,8 +257,8 @@ public void GetHashCode_WithSameValue_ReturnsSameHashCode() public void GetHashCode_WithDifferentValue_ReturnsDifferentHashCode() { // Arrange - AgentRequestMessageSource source1 = new("Test1"); - AgentRequestMessageSource source2 = new("Test2"); + AgentRequestMessageSourceType source1 = new("Test1"); + AgentRequestMessageSourceType source2 = new("Test2"); // Act int hashCode1 = source1.GetHashCode(); @@ -272,8 +272,8 @@ public void GetHashCode_WithDifferentValue_ReturnsDifferentHashCode() public void GetHashCode_ConsistentWithEquals() { // Arrange - AgentRequestMessageSource source1 = new("Test"); - AgentRequestMessageSource source2 = new("Test"); + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType source2 = new("Test"); // Act & Assert // If two objects are equal, they must have the same hash code @@ -289,8 +289,8 @@ public void GetHashCode_ConsistentWithEquals() public void EqualityOperator_WithEqualValues_ReturnsTrue() { // Arrange - AgentRequestMessageSource source1 = new("Test"); - AgentRequestMessageSource source2 = new("Test"); + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType source2 = new("Test"); // Act bool result = source1 == source2; @@ -303,8 +303,8 @@ public void EqualityOperator_WithEqualValues_ReturnsTrue() public void EqualityOperator_WithDifferentValues_ReturnsFalse() { // Arrange - AgentRequestMessageSource source1 = new("Test1"); - AgentRequestMessageSource source2 = new("Test2"); + AgentRequestMessageSourceType source1 = new("Test1"); + AgentRequestMessageSourceType source2 = new("Test2"); // Act bool result = source1 == source2; @@ -317,8 +317,8 @@ public void EqualityOperator_WithDifferentValues_ReturnsFalse() public void EqualityOperator_WithBothNull_ReturnsTrue() { // Arrange - AgentRequestMessageSource? source1 = null; - AgentRequestMessageSource? source2 = null; + AgentRequestMessageSourceType? source1 = null; + AgentRequestMessageSourceType? source2 = null; // Act bool result = source1 == source2; @@ -331,8 +331,8 @@ public void EqualityOperator_WithBothNull_ReturnsTrue() public void EqualityOperator_WithLeftNull_ReturnsFalse() { // Arrange - AgentRequestMessageSource? source1 = null; - AgentRequestMessageSource source2 = new("Test"); + AgentRequestMessageSourceType? source1 = null; + AgentRequestMessageSourceType source2 = new("Test"); // Act bool result = source1 == source2; @@ -345,8 +345,8 @@ public void EqualityOperator_WithLeftNull_ReturnsFalse() public void EqualityOperator_WithRightNull_ReturnsFalse() { // Arrange - AgentRequestMessageSource source1 = new("Test"); - AgentRequestMessageSource? source2 = null; + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType? source2 = null; // Act bool result = source1 == source2; @@ -359,8 +359,8 @@ public void EqualityOperator_WithRightNull_ReturnsFalse() public void EqualityOperator_WithStaticInstances_ReturnsTrue() { // Arrange - AgentRequestMessageSource external1 = AgentRequestMessageSource.External; - AgentRequestMessageSource external2 = AgentRequestMessageSource.External; + AgentRequestMessageSourceType external1 = AgentRequestMessageSourceType.External; + AgentRequestMessageSourceType external2 = AgentRequestMessageSourceType.External; // Act bool result = external1 == external2; @@ -373,8 +373,8 @@ public void EqualityOperator_WithStaticInstances_ReturnsTrue() public void EqualityOperator_StaticWithNewInstanceHavingSameValue_ReturnsTrue() { // Arrange - AgentRequestMessageSource external = AgentRequestMessageSource.External; - AgentRequestMessageSource newExternal = new("External"); + AgentRequestMessageSourceType external = AgentRequestMessageSourceType.External; + AgentRequestMessageSourceType newExternal = new("External"); // Act bool result = external == newExternal; @@ -391,8 +391,8 @@ public void EqualityOperator_StaticWithNewInstanceHavingSameValue_ReturnsTrue() public void InequalityOperator_WithEqualValues_ReturnsFalse() { // Arrange - AgentRequestMessageSource source1 = new("Test"); - AgentRequestMessageSource source2 = new("Test"); + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType source2 = new("Test"); // Act bool result = source1 != source2; @@ -405,8 +405,8 @@ public void InequalityOperator_WithEqualValues_ReturnsFalse() public void InequalityOperator_WithDifferentValues_ReturnsTrue() { // Arrange - AgentRequestMessageSource source1 = new("Test1"); - AgentRequestMessageSource source2 = new("Test2"); + AgentRequestMessageSourceType source1 = new("Test1"); + AgentRequestMessageSourceType source2 = new("Test2"); // Act bool result = source1 != source2; @@ -419,8 +419,8 @@ public void InequalityOperator_WithDifferentValues_ReturnsTrue() public void InequalityOperator_WithBothNull_ReturnsFalse() { // Arrange - AgentRequestMessageSource? source1 = null; - AgentRequestMessageSource? source2 = null; + AgentRequestMessageSourceType? source1 = null; + AgentRequestMessageSourceType? source2 = null; // Act bool result = source1 != source2; @@ -433,8 +433,8 @@ public void InequalityOperator_WithBothNull_ReturnsFalse() public void InequalityOperator_WithLeftNull_ReturnsTrue() { // Arrange - AgentRequestMessageSource? source1 = null; - AgentRequestMessageSource source2 = new("Test"); + AgentRequestMessageSourceType? source1 = null; + AgentRequestMessageSourceType source2 = new("Test"); // Act bool result = source1 != source2; @@ -447,8 +447,8 @@ public void InequalityOperator_WithLeftNull_ReturnsTrue() public void InequalityOperator_WithRightNull_ReturnsTrue() { // Arrange - AgentRequestMessageSource source1 = new("Test"); - AgentRequestMessageSource? source2 = null; + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType? source2 = null; // Act bool result = source1 != source2; @@ -461,8 +461,8 @@ public void InequalityOperator_WithRightNull_ReturnsTrue() public void InequalityOperator_DifferentStaticInstances_ReturnsTrue() { // Arrange - AgentRequestMessageSource external = AgentRequestMessageSource.External; - AgentRequestMessageSource chatHistory = AgentRequestMessageSource.ChatHistory; + AgentRequestMessageSourceType external = AgentRequestMessageSourceType.External; + AgentRequestMessageSourceType chatHistory = AgentRequestMessageSourceType.ChatHistory; // Act bool result = external != chatHistory; @@ -479,10 +479,10 @@ public void InequalityOperator_DifferentStaticInstances_ReturnsTrue() public void IEquatable_ImplementedCorrectly() { // Arrange - AgentRequestMessageSource source = new("Test"); + AgentRequestMessageSourceType source = new("Test"); // Act & Assert - Assert.IsAssignableFrom>(source); + Assert.IsAssignableFrom>(source); } #endregion diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs index db366a4b9b..de4f4e29c4 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs @@ -61,7 +61,7 @@ public async Task WithMessageFilters_InvokedFilter_IsAppliedAsync() Mock providerMock = new(); List requestMessages = [ - new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.ChatHistory } } }, + new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.ChatHistory } } }, new(ChatRole.User, "Hello") ]; ChatHistoryProvider.InvokedContext context = new(requestMessages) @@ -111,9 +111,9 @@ public async Task WithAIContextProviderMessageRemoval_RemovesAIContextProviderMe Mock providerMock = new(); List requestMessages = [ - new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.ChatHistory } } }, + new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.ChatHistory } } }, new(ChatRole.User, "Hello"), - new(ChatRole.System, "Context") { AdditionalProperties = new() { { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.AIContextProvider } } } + new(ChatRole.System, "Context") { AdditionalProperties = new() { { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.AIContextProvider } } } ]; ChatHistoryProvider.InvokedContext context = new(requestMessages); diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs index 6998a91cf2..1c1fca9621 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs @@ -154,7 +154,7 @@ public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync() var innerProviderMock = new Mock(); List requestMessages = [ - new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.ChatHistory } } }, + new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.ChatHistory } } }, new(ChatRole.User, "Hello"), ]; var responseMessages = new List { new(ChatRole.Assistant, "Response") }; @@ -173,7 +173,7 @@ public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync() // Filter that modifies the context ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedContext ctx) { - var modifiedRequestMessages = ctx.RequestMessages.Where(x => x.GetAgentRequestMessageSource() == AgentRequestMessageSource.External).Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList(); + var modifiedRequestMessages = ctx.RequestMessages.Where(x => x.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External).Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList(); return new ChatHistoryProvider.InvokedContext(modifiedRequestMessages) { ResponseMessages = ctx.ResponseMessages, diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs index 4f981b9692..d6a8bc446e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -14,6 +15,92 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public class ChatHistoryProviderTests { + #region InvokingAsync Message Stamping Tests + + [Fact] + public async Task InvokingAsync_StampsMessagesWithSourceTypeAndSourceAsync() + { + // Arrange + var provider = new TestChatHistoryProvider(); + var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + + // Act + IEnumerable messages = await provider.InvokingAsync(context); + + // Assert + ChatMessage message = messages.Single(); + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(typeof(TestChatHistoryProvider).FullName, source); + } + + [Fact] + public async Task InvokingAsync_WithCustomSourceName_StampsMessagesWithCustomSourceAsync() + { + // Arrange + const string CustomSourceName = "CustomHistorySource"; + var provider = new TestChatHistoryProviderWithCustomSource(CustomSourceName); + var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + + // Act + IEnumerable messages = await provider.InvokingAsync(context); + + // Assert + ChatMessage message = messages.Single(); + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(CustomSourceName, source); + } + + [Fact] + public async Task InvokingAsync_DoesNotReStampAlreadyStampedMessagesAsync() + { + // Arrange + var provider = new TestChatHistoryProviderWithPreStampedMessages(); + var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + + // Act + IEnumerable messages = await provider.InvokingAsync(context); + + // Assert + ChatMessage message = messages.Single(); + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(typeof(TestChatHistoryProviderWithPreStampedMessages).FullName, source); + } + + [Fact] + public async Task InvokingAsync_StampsMultipleMessagesAsync() + { + // Arrange + var provider = new TestChatHistoryProviderWithMultipleMessages(); + var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + + // Act + IEnumerable messages = await provider.InvokingAsync(context); + + // Assert + List messageList = messages.ToList(); + Assert.Equal(3, messageList.Count); + + foreach (ChatMessage message in messageList) + { + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(typeof(TestChatHistoryProviderWithMultipleMessages).FullName, source); + } + } + + #endregion + #region GetService Method Tests [Fact] @@ -79,7 +166,59 @@ public void GetService_Generic_ReturnsNullForUnrelatedType() private sealed class TestChatHistoryProvider : ChatHistoryProvider { protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) - => new(Array.Empty()); + => new([new ChatMessage(ChatRole.User, "Test Message")]); + + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) + => default; + + public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + => default; + } + + private sealed class TestChatHistoryProviderWithCustomSource : ChatHistoryProvider + { + public TestChatHistoryProviderWithCustomSource(string sourceName) : base(sourceName) + { + } + + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new([new ChatMessage(ChatRole.User, "Test Message")]); + + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) + => default; + + public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + => default; + } + + private sealed class TestChatHistoryProviderWithPreStampedMessages : ChatHistoryProvider + { + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + var message = new ChatMessage(ChatRole.User, "Pre-stamped Message"); + message.AdditionalProperties = new AdditionalPropertiesDictionary + { + [AgentRequestMessageSourceType.AdditionalPropertiesKey] = AgentRequestMessageSourceType.ChatHistory, + [AgentRequestMessageSource.AdditionalPropertiesKey] = this.GetType().FullName! + }; + return new([message]); + } + + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) + => default; + + public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + => default; + } + + private sealed class TestChatHistoryProviderWithMultipleMessages : ChatHistoryProvider + { + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new([ + new ChatMessage(ChatRole.User, "Message 1"), + new ChatMessage(ChatRole.Assistant, "Message 2"), + new ChatMessage(ChatRole.User, "Message 3") + ]); protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) => default; diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs index 269ef03ee1..f389c567d2 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs @@ -18,10 +18,10 @@ public void GetAgentRequestMessageSource_WithNoAdditionalProperties_ReturnsExter ChatMessage message = new(ChatRole.User, "Hello"); // Act - AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); // Assert - Assert.Equal(AgentRequestMessageSource.External, result); + Assert.Equal(AgentRequestMessageSourceType.External, result); } [Fact] @@ -34,10 +34,10 @@ public void GetAgentRequestMessageSource_WithNullAdditionalProperties_ReturnsExt }; // Act - AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); // Assert - Assert.Equal(AgentRequestMessageSource.External, result); + Assert.Equal(AgentRequestMessageSourceType.External, result); } [Fact] @@ -50,10 +50,10 @@ public void GetAgentRequestMessageSource_WithEmptyAdditionalProperties_ReturnsEx }; // Act - AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); // Assert - Assert.Equal(AgentRequestMessageSource.External, result); + Assert.Equal(AgentRequestMessageSourceType.External, result); } [Fact] @@ -64,15 +64,15 @@ public void GetAgentRequestMessageSource_WithExternalSource_ReturnsExternal() { AdditionalProperties = new AdditionalPropertiesDictionary { - { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.External } + { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.External } } }; // Act - AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); // Assert - Assert.Equal(AgentRequestMessageSource.External, result); + Assert.Equal(AgentRequestMessageSourceType.External, result); } [Fact] @@ -83,15 +83,15 @@ public void GetAgentRequestMessageSource_WithAIContextProviderSource_ReturnsAICo { AdditionalProperties = new AdditionalPropertiesDictionary { - { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.AIContextProvider } + { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.AIContextProvider } } }; // Act - AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); // Assert - Assert.Equal(AgentRequestMessageSource.AIContextProvider, result); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, result); } [Fact] @@ -102,32 +102,32 @@ public void GetAgentRequestMessageSource_WithChatHistorySource_ReturnsChatHistor { AdditionalProperties = new AdditionalPropertiesDictionary { - { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.ChatHistory } + { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.ChatHistory } } }; // Act - AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); // Assert - Assert.Equal(AgentRequestMessageSource.ChatHistory, result); + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, result); } [Fact] public void GetAgentRequestMessageSource_WithCustomSource_ReturnsCustomSource() { // Arrange - AgentRequestMessageSource customSource = new("CustomSource"); + AgentRequestMessageSourceType customSource = new("CustomSource"); ChatMessage message = new(ChatRole.User, "Hello") { AdditionalProperties = new AdditionalPropertiesDictionary { - { AgentRequestMessageSource.AdditionalPropertiesKey, customSource } + { AgentRequestMessageSourceType.AdditionalPropertiesKey, customSource } } }; // Act - AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); // Assert Assert.Equal(customSource, result); @@ -142,15 +142,15 @@ public void GetAgentRequestMessageSource_WithWrongKeyType_ReturnsExternal() { AdditionalProperties = new AdditionalPropertiesDictionary { - { AgentRequestMessageSource.AdditionalPropertiesKey, "NotAnAgentRequestMessageSource" } + { AgentRequestMessageSourceType.AdditionalPropertiesKey, "NotAnAgentRequestMessageSource" } } }; // Act - AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); // Assert - Assert.Equal(AgentRequestMessageSource.External, result); + Assert.Equal(AgentRequestMessageSourceType.External, result); } [Fact] @@ -161,15 +161,15 @@ public void GetAgentRequestMessageSource_WithNullValue_ReturnsExternal() { AdditionalProperties = new AdditionalPropertiesDictionary { - { AgentRequestMessageSource.AdditionalPropertiesKey, null! } + { AgentRequestMessageSourceType.AdditionalPropertiesKey, null! } } }; // Act - AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); // Assert - Assert.Equal(AgentRequestMessageSource.External, result); + Assert.Equal(AgentRequestMessageSourceType.External, result); } [Fact] @@ -181,16 +181,16 @@ public void GetAgentRequestMessageSource_WithMultipleProperties_ReturnsCorrectSo AdditionalProperties = new AdditionalPropertiesDictionary { { "OtherProperty", "SomeValue" }, - { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.ChatHistory }, + { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.ChatHistory }, { "AnotherProperty", 123 } } }; // Act - AgentRequestMessageSource result = message.GetAgentRequestMessageSource(); + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); // Assert - Assert.Equal(AgentRequestMessageSource.ChatHistory, result); + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, result); } #endregion diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs index 40ccee086a..b11f047e8e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs @@ -52,7 +52,7 @@ public async Task InvokedAsyncAddsMessagesAsync() var requestMessages = new List { new(ChatRole.User, "Hello"), - new(ChatRole.System, "additional context") { AdditionalProperties = new() { { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.ChatHistory } } }, + new(ChatRole.System, "additional context") { AdditionalProperties = new() { { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.ChatHistory } } }, }; var responseMessages = new List { diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs index 88dcca229f..f38b7b2794 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs @@ -283,7 +283,7 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() new ChatMessage(ChatRole.User, "First message"), new ChatMessage(ChatRole.Assistant, "Second message"), new ChatMessage(ChatRole.User, "Third message"), - new ChatMessage(ChatRole.System, "System context message") { AdditionalProperties = new() { { AgentRequestMessageSource.AdditionalPropertiesKey, AgentRequestMessageSource.AIContextProvider } } } + new ChatMessage(ChatRole.System, "System context message") { AdditionalProperties = new() { { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.AIContextProvider } } } }; var responseMessages = new[] { From 2c64d213dd5a2a7a7c338a794916149266bc5c92 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 4 Feb 2026 16:11:10 +0000 Subject: [PATCH 3/5] Address PR comments. --- .../AgentRequestMessageSourceType.cs | 4 ++-- .../Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs | 2 +- ...geSourceTests.cs => AgentRequestMessageSourceTypeTests.cs} | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) rename dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/{AgentRequestMessageSourceTests.cs => AgentRequestMessageSourceTypeTests.cs} (99%) diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs index de53c39419..1cca747906 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs @@ -7,11 +7,11 @@ namespace Microsoft.Agents.AI; /// -/// An enumeration representing the source types of an agent request message. +/// Represents the source of an agent request message. /// /// /// Input messages for a specific agent run can originate from various sources. -/// This enumeration helps to identify whether a message came from outside the agent pipeline, +/// This type helps to identify whether a message came from outside the agent pipeline, /// whether it was produced by middleware, or came from chat history. /// public sealed class AgentRequestMessageSourceType : IEquatable diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs index 7e7dfdf882..01edcb4eff 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs @@ -5,7 +5,7 @@ namespace Microsoft.Agents.AI; /// -/// Conatins extension methods for +/// Contains extension methods for /// public static class ChatMessageExtensions { diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTypeTests.cs similarity index 99% rename from dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTests.cs rename to dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTypeTests.cs index 1e22e75fcd..f6149092f3 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTypeTests.cs @@ -7,7 +7,7 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// /// Contains tests for the class. /// -public sealed class AgentRequestMessageSourceTests +public sealed class AgentRequestMessageSourceTypeTests { #region Constructor Tests From 8cfdc09ccb2f003b249cd2dfd192c98b9e2a7245 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 9 Feb 2026 14:58:29 +0000 Subject: [PATCH 4/5] Add merge fixes --- .../ChatHistoryProvider.cs | 2 +- .../ChatClient/ChatClientAgent.cs | 34 ++++++----- .../AIContextProviderTests.cs | 43 +++++--------- .../ChatHistoryProviderMessageFilterTests.cs | 2 +- .../ChatHistoryProviderTests.cs | 56 +++++-------------- .../InMemoryChatHistoryProviderTests.cs | 2 +- ...tClientAgent_ChatHistoryManagementTests.cs | 4 +- .../Data/TextSearchProviderTests.cs | 10 ++-- 8 files changed, 57 insertions(+), 96 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs index b042e77673..c4dc6dc9bd 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs @@ -298,7 +298,7 @@ public sealed class InvokedContext public InvokedContext( AIAgent agent, AgentSession? session, - IEnumerable requestMessages,) + IEnumerable requestMessages) { this.Agent = Throw.IfNull(agent); this.Session = session; diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 8203c06941..5878d877b2 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -206,6 +206,7 @@ protected override async IAsyncEnumerable RunCoreStreamingA (ChatClientAgentSession safeSession, ChatOptions? chatOptions, + List inputMessagesForProviders, List inputMessagesForChatClient, ChatClientAgentContinuationToken? continuationToken) = await this.PrepareSessionAndMessagesAsync(session, inputMessages, options, cancellationToken).ConfigureAwait(false); @@ -225,12 +226,12 @@ protected override async IAsyncEnumerable RunCoreStreamingA try { // Using the enumerator to ensure we consider the case where no updates are returned for notification. - responseUpdatesEnumerator = chatClient.GetStreamingResponseAsync(inputMessagesForChatClient, chatOptions, cancellationToken).GetAsyncEnumerator(cancellationToken); + responseUpdatesEnumerator = chatClient.GetStreamingResponseAsync(inputMessagesForProviders, chatOptions, cancellationToken).GetAsyncEnumerator(cancellationToken); } catch (Exception ex) { - await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); - await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), cancellationToken).ConfigureAwait(false); throw; } @@ -244,8 +245,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); - await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), cancellationToken).ConfigureAwait(false); throw; } @@ -271,8 +272,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); - await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), cancellationToken).ConfigureAwait(false); throw; } } @@ -284,10 +285,10 @@ protected override async IAsyncEnumerable RunCoreStreamingA await this.UpdateSessionWithTypeAndConversationIdAsync(safeSession, chatResponse.ConversationId, cancellationToken).ConfigureAwait(false); // To avoid inconsistent state we only notify the session of the input messages if no error occurs after the initial request. - await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessagesForProviders, continuationToken), chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); // Notify the AIContextProvider of all new messages. - await this.NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatResponse.Messages, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessagesForProviders, continuationToken), chatResponse.Messages, cancellationToken).ConfigureAwait(false); } /// @@ -431,6 +432,7 @@ private async Task RunCoreAsync inputMessagesForProviders, List inputMessagesForChatClient, ChatClientAgentContinuationToken? _) = await this.PrepareSessionAndMessagesAsync(session, inputMessages, options, cancellationToken).ConfigureAwait(false); @@ -451,8 +453,8 @@ private async Task RunCoreAsync RunCoreAsync inputMessagesForProviders, List InputMessagesForChatClient, ChatClientAgentContinuationToken? ContinuationToken )> PrepareSessionAndMessagesAsync( @@ -706,6 +709,7 @@ private async Task throw new InvalidOperationException("Input messages are not allowed when continuing a background response using a continuation token."); } + List inputMessagesForProviders = []; List inputMessagesForChatClient = []; // Populate the session messages only if we are not continuing an existing response as it's not allowed @@ -722,6 +726,7 @@ private async Task } // Add the input messages before getting context from AIContextProvider. + inputMessagesForProviders.AddRange(inputMessages); inputMessagesForChatClient.AddRange(inputMessages); // If we have an AIContextProvider, we should get context from it, and update our @@ -732,6 +737,7 @@ private async Task var aiContext = await typedSession.AIContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); if (aiContext.Messages is { Count: > 0 }) { + inputMessagesForProviders.AddRange(aiContext.Messages); inputMessagesForChatClient.AddRange(aiContext.Messages); } @@ -771,7 +777,7 @@ private async Task chatOptions.ConversationId = typedSession.ConversationId; } - return (typedSession, chatOptions, inputMessagesForChatClient, continuationToken); + return (typedSession, chatOptions, inputMessagesForProviders, inputMessagesForChatClient, continuationToken); } private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSession session, string? responseConversationId, CancellationToken cancellationToken) diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs index 9f82312b74..44d1be2e74 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs @@ -23,7 +23,7 @@ public async Task InvokingAsync_StampsMessagesWithSourceTypeAndSourceAsync() { // Arrange var provider = new TestAIContextProviderWithMessages(); - var context = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); // Act AIContext aiContext = await provider.InvokingAsync(context); @@ -44,7 +44,7 @@ public async Task InvokingAsync_WithCustomSourceName_StampsMessagesWithCustomSou // Arrange const string CustomSourceName = "CustomContextSource"; var provider = new TestAIContextProviderWithCustomSource(CustomSourceName); - var context = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); // Act AIContext aiContext = await provider.InvokingAsync(context); @@ -64,7 +64,7 @@ public async Task InvokingAsync_DoesNotReStampAlreadyStampedMessagesAsync() { // Arrange var provider = new TestAIContextProviderWithPreStampedMessages(); - var context = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); // Act AIContext aiContext = await provider.InvokingAsync(context); @@ -84,7 +84,7 @@ public async Task InvokingAsync_StampsMultipleMessagesAsync() { // Arrange var provider = new TestAIContextProviderWithMultipleMessages(); - var context = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); // Act AIContext aiContext = await provider.InvokingAsync(context); @@ -109,7 +109,7 @@ public async Task InvokingAsync_WithNullMessages_ReturnsContextWithoutStampingAs { // Arrange var provider = new TestAIContextProvider(); - var context = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); // Act AIContext aiContext = await provider.InvokingAsync(context); @@ -130,7 +130,7 @@ public async Task InvokedAsync_ReturnsCompletedTaskAsync() var messages = new ReadOnlyCollection([]); // Act - ValueTask task = provider.InvokedAsync(s_mockAgent, s_mockSession, new(messages)); + ValueTask task = provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); // Assert Assert.Equal(default, task); @@ -367,7 +367,7 @@ public void InvokedContext_RequestMessages_SetterThrowsForNull() { // Arrange var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, messages); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -379,7 +379,7 @@ public void InvokedContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages); // Act context.RequestMessages = newMessages; @@ -388,28 +388,13 @@ public void InvokedContext_RequestMessages_SetterRoundtrips() Assert.Same(newMessages, context.RequestMessages); } - [Fact] - public void InvokedContext_AIContextProviderMessages_Roundtrips() - { - // Arrange - var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - var aiContextMessages = new List { new(ChatRole.System, "AI context message") }; - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); - - // Act - context.AIContextProviderMessages = aiContextMessages; - - // Assert - Assert.Same(aiContextMessages, context.AIContextProviderMessages); - } - [Fact] public void InvokedContext_ResponseMessages_Roundtrips() { // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var responseMessages = new List { new(ChatRole.Assistant, "Response message") }; - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Act context.ResponseMessages = responseMessages; @@ -424,7 +409,7 @@ public void InvokedContext_InvokeException_Roundtrips() // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var exception = new InvalidOperationException("Test exception"); - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Act context.InvokeException = exception; @@ -440,7 +425,7 @@ public void InvokedContext_Agent_ReturnsConstructorValue() var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); // Act - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Assert Assert.Same(s_mockAgent, context.Agent); @@ -453,7 +438,7 @@ public void InvokedContext_Session_ReturnsConstructorValue() var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); // Act - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Assert Assert.Same(s_mockSession, context.Session); @@ -466,7 +451,7 @@ public void InvokedContext_Session_CanBeNull() var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); // Act - var context = new AIContextProvider.InvokedContext(s_mockAgent, null, requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, null, requestMessages); // Assert Assert.Null(context.Session); @@ -479,7 +464,7 @@ public void InvokedContext_Constructor_ThrowsForNullAgent() var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); // Act & Assert - Assert.Throws(() => new AIContextProvider.InvokedContext(null!, s_mockSession, requestMessages, aiContextProviderMessages: null)); + Assert.Throws(() => new AIContextProvider.InvokedContext(null!, s_mockSession, requestMessages)); } #endregion diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs index 006e890c19..5b48d025be 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs @@ -161,7 +161,7 @@ public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync() new(ChatRole.User, "Hello"), ]; var responseMessages = new List { new(ChatRole.Assistant, "Response") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, s_mockAgent, s_mockSession, requestMessages) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = responseMessages }; diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs index 0dea84d5ba..e158b159ca 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs @@ -26,7 +26,7 @@ public async Task InvokingAsync_StampsMessagesWithSourceTypeAndSourceAsync() { // Arrange var provider = new TestChatHistoryProvider(); - var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); // Act IEnumerable messages = await provider.InvokingAsync(context); @@ -46,7 +46,7 @@ public async Task InvokingAsync_WithCustomSourceName_StampsMessagesWithCustomSou // Arrange const string CustomSourceName = "CustomHistorySource"; var provider = new TestChatHistoryProviderWithCustomSource(CustomSourceName); - var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); // Act IEnumerable messages = await provider.InvokingAsync(context); @@ -65,7 +65,7 @@ public async Task InvokingAsync_DoesNotReStampAlreadyStampedMessagesAsync() { // Arrange var provider = new TestChatHistoryProviderWithPreStampedMessages(); - var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); // Act IEnumerable messages = await provider.InvokingAsync(context); @@ -84,7 +84,7 @@ public async Task InvokingAsync_StampsMultipleMessagesAsync() { // Arrange var provider = new TestChatHistoryProviderWithMultipleMessages(); - var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Request")]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); // Act IEnumerable messages = await provider.InvokingAsync(context); @@ -259,7 +259,7 @@ public void InvokingContext_Constructor_ThrowsForNullAgent() public void InvokedContext_Constructor_ThrowsForNullRequestMessages() { // Arrange & Act & Assert - Assert.Throws(() => new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, null!, [])); + Assert.Throws(() => new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, null!)); } [Fact] @@ -267,7 +267,7 @@ public void InvokedContext_RequestMessages_SetterThrowsForNull() { // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -279,7 +279,7 @@ public void InvokedContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new List { new(ChatRole.User, "Hello") }; var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages); // Act context.RequestMessages = newMessages; @@ -288,43 +288,13 @@ public void InvokedContext_RequestMessages_SetterRoundtrips() Assert.Same(newMessages, context.RequestMessages); } - [Fact] - public void InvokedContext_ChatHistoryProviderMessages_SetterRoundtrips() - { - // Arrange - var requestMessages = new List { new(ChatRole.User, "Hello") }; - var newProviderMessages = new List { new(ChatRole.System, "System message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); - - // Act - context.ChatHistoryProviderMessages = newProviderMessages; - - // Assert - Assert.Same(newProviderMessages, context.ChatHistoryProviderMessages); - } - - [Fact] - public void InvokedContext_AIContextProviderMessages_Roundtrips() - { - // Arrange - var requestMessages = new List { new(ChatRole.User, "Hello") }; - var aiContextMessages = new List { new(ChatRole.System, "AI context message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); - - // Act - context.AIContextProviderMessages = aiContextMessages; - - // Assert - Assert.Same(aiContextMessages, context.AIContextProviderMessages); - } - [Fact] public void InvokedContext_ResponseMessages_Roundtrips() { // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var responseMessages = new List { new(ChatRole.Assistant, "Response message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Act context.ResponseMessages = responseMessages; @@ -339,7 +309,7 @@ public void InvokedContext_InvokeException_Roundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var exception = new InvalidOperationException("Test exception"); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Act context.InvokeException = exception; @@ -355,7 +325,7 @@ public void InvokedContext_Agent_ReturnsConstructorValue() var requestMessages = new List { new(ChatRole.User, "Hello") }; // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Assert Assert.Same(s_mockAgent, context.Agent); @@ -368,7 +338,7 @@ public void InvokedContext_Session_ReturnsConstructorValue() var requestMessages = new List { new(ChatRole.User, "Hello") }; // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Assert Assert.Same(s_mockSession, context.Session); @@ -381,7 +351,7 @@ public void InvokedContext_Session_CanBeNull() var requestMessages = new List { new(ChatRole.User, "Hello") }; // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, null, requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, null, requestMessages); // Assert Assert.Null(context.Session); @@ -394,7 +364,7 @@ public void InvokedContext_Constructor_ThrowsForNullAgent() var requestMessages = new List { new(ChatRole.User, "Hello") }; // Act & Assert - Assert.Throws(() => new ChatHistoryProvider.InvokedContext(null!, s_mockSession, requestMessages, [])); + Assert.Throws(() => new ChatHistoryProvider.InvokedContext(null!, s_mockSession, requestMessages)); } #endregion diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs index 5a70602eca..75232073a6 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs @@ -626,7 +626,7 @@ public async Task InvokedAsync_WithException_DoesNotAddMessagesAsync() { new(ChatRole.Assistant, "Hi there!") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = responseMessages, InvokeException = new InvalidOperationException("Test exception") diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs index 83ee2a8983..4de8f01f8e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs @@ -222,7 +222,7 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve mockChatHistoryProvider .Protected() .Verify("InvokedCoreAsync", Times.Once(), - ItExpr.Is(x => x.RequestMessages.Count() == 2 && x.ResponseMessages!.Count() == 1), + ItExpr.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), ItExpr.IsAny()); mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); } @@ -366,7 +366,7 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi mockOverrideChatHistoryProvider .Protected() .Verify("InvokedCoreAsync", Times.Once(), - ItExpr.Is(x => x.RequestMessages.Count() == 2 && x.ResponseMessages!.Count() == 1), + ItExpr.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), ItExpr.IsAny()); mockFactoryChatHistoryProvider diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs index 466db422aa..ec8dda3c45 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs @@ -424,7 +424,7 @@ public async Task InvokingAsync_WithAccumulatedMemoryAcrossInvocations_ShouldInc // First memory update (A,B) await provider.InvokedAsync(new( s_mockAgent, - S_mockSession, + s_mockSession, [ new ChatMessage(ChatRole.User, "A"), new ChatMessage(ChatRole.Assistant, "B"), @@ -475,7 +475,7 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles new ChatMessage(ChatRole.User, "U2"), new ChatMessage(ChatRole.Assistant, "A2"), }; - await provider.InvokedAsync(s_mockAgent, s_mockSession, new(initialMessages)); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages)); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, @@ -533,7 +533,7 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy }; // Act - await provider.InvokedAsync(s_mockAgent, s_mockSession, new(messages)); // Populate recent memory. + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); // Populate recent memory. var state = provider.Serialize(); // Assert @@ -562,7 +562,7 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(s_mockAgent, s_mockSession, new(messages)); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); // Act var state = provider.Serialize(); @@ -603,7 +603,7 @@ public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsyn new ChatMessage(ChatRole.Assistant, "L4"), new ChatMessage(ChatRole.User, "L5"), }; - await initialProvider.InvokedAsync(s_mockAgent, s_mockSession, new(messages)); + await initialProvider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); var state = initialProvider.Serialize(); string? capturedInput = null; From 460bd6c9cd6da88aabf6c69b7a19218a5f470ec6 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 9 Feb 2026 15:00:26 +0000 Subject: [PATCH 5/5] Address PR comments --- .../src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs | 2 ++ .../src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs | 2 ++ 2 files changed, 4 insertions(+) diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index 9c472cee85..a4b606e6a1 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -79,9 +79,11 @@ public async ValueTask InvokingAsync(InvokingContext context, Cancell aiContext.Messages = aiContext.Messages.Select(message => { if (message.AdditionalProperties != null + // Check if the message was already tagged with this provider's source type && message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out var messageSourceType) && messageSourceType is AgentRequestMessageSourceType typedMessageSourceType && typedMessageSourceType == AgentRequestMessageSourceType.AIContextProvider + // Check if the message was already tagged with this provider's source && message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out var messageSource) && messageSource is string typedMessageSource && typedMessageSource == this._sourceName) diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs index c4dc6dc9bd..f49c5d46a7 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs @@ -92,9 +92,11 @@ public async ValueTask> InvokingAsync(InvokingContext c return messages.Select(message => { if (message.AdditionalProperties != null + // Check if the message was already tagged with this provider's source type && message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out var messageSourceType) && messageSourceType is AgentRequestMessageSourceType typedMessageSourceType && typedMessageSourceType == AgentRequestMessageSourceType.ChatHistory + // Check if the message was already tagged with this provider's source && message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out var messageSource) && messageSource is string typedMessageSource && typedMessageSource == this._sourceName)