Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ protected override async Task<AgentResponse> RunCoreAsync(IEnumerable<ChatMessag
List<ChatMessage> responseMessages = CloneAndToUpperCase(messages, this.Name).ToList();

// Notify the session of the input and output messages.
var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages, storeMessages)
var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages)
{
ResponseMessages = responseMessages
};
Expand Down Expand Up @@ -94,7 +94,7 @@ protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingA
List<ChatMessage> responseMessages = CloneAndToUpperCase(messages, this.Name).ToList();

// Notify the session of the input and output messages.
var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages, storeMessages)
var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages)
{
ResponseMessages = responseMessages
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public UserInfoMemory(IChatClient chatClient, JsonElement serializedState, JsonS

public UserInfo UserInfo { get; set; }

public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default)
protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default)
{
// Try and extract the user name and age from the message if we don't have it already and it's a user message.
if ((this.UserInfo.UserName is null || this.UserInfo.UserAge is null) && context.RequestMessages.Any(x => x.Role == ChatRole.User))
Expand All @@ -122,7 +122,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio
}
}

public override ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
protected override ValueTask<AIContext> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
StringBuilder instructions = new();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public VectorChatHistoryProvider(VectorStore vectorStore, JsonElement serialized

public string? SessionDbKey { get; private set; }

public override async ValueTask<IEnumerable<ChatMessage>> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
protected override async ValueTask<IEnumerable<ChatMessage>> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
var collection = this._vectorStore.GetCollection<string, ChatHistoryItem>("ChatHistory");
await collection.EnsureCollectionExistsAsync(cancellationToken);
Expand All @@ -107,7 +107,7 @@ public override async ValueTask<IEnumerable<ChatMessage>> InvokingAsync(Invoking
return messages;
}

public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default)
protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default)
{
// Don't store messages if the request failed.
if (context.InvokeException is not null)
Expand All @@ -122,7 +122,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio

// Add both request and response messages to the store
// Optionally messages produced by the AIContextProvider can also be persisted (not shown).
var allNewMessages = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []);
var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []);

await collection.UpsertAsync(allNewMessages.Select(x => new ChatHistoryItem()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public TodoListAIContextProvider(JsonElement jsonElement, JsonSerializerOptions?
}
}

public override ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
protected override ValueTask<AIContext> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
StringBuilder outputMessageBuilder = new();
outputMessageBuilder.AppendLine("Your todo list contains the following items:");
Expand Down Expand Up @@ -132,7 +132,7 @@ public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptio
/// </summary>
internal sealed class CalendarSearchAIContextProvider(Func<Task<string[]>> loadNextThreeCalendarEvents) : AIContextProvider
{
public override async ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
protected override async ValueTask<AIContext> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
var events = await loadNextThreeCalendarEvents();

Expand Down Expand Up @@ -179,7 +179,7 @@ public AggregatingAIContextProvider(ProviderFactory[] providerFactories, JsonEle
.ToList();
}

public override async ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
protected override async ValueTask<AIContext> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
// Invoke all the sub providers.
var tasks = this._providers.Select(provider => provider.InvokingAsync(context, cancellationToken).AsTask());
Expand Down
116 changes: 99 additions & 17 deletions dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -31,6 +32,25 @@ namespace Microsoft.Agents.AI;
/// </remarks>
public abstract class AIContextProvider
{
private readonly string _sourceName;

/// <summary>
/// Initializes a new instance of the <see cref="AIContextProvider"/> class.
/// </summary>
protected AIContextProvider()
{
this._sourceName = this.GetType().FullName!;
}

/// <summary>
/// Initializes a new instance of the <see cref="AIContextProvider"/> class with the specified source name.
/// </summary>
/// <param name="sourceName">The source name to stamp on <see cref="ChatMessage.AdditionalProperties"/> for each messages produced by the <see cref="AIContextProvider"/>.</param>
protected AIContextProvider(string sourceName)
{
this._sourceName = sourceName;
}

/// <summary>
/// Called at the start of agent invocation to provide additional context.
/// </summary>
Expand All @@ -48,7 +68,81 @@ public abstract class AIContextProvider
/// </list>
/// </para>
/// </remarks>
public abstract ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default);
public async ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
var aiContext = await this.InvokingCoreAsync(context, cancellationToken).ConfigureAwait(false);
if (aiContext.Messages is null)
{
return aiContext;
}

aiContext.Messages = aiContext.Messages.Select(message =>
{
if (message.AdditionalProperties != null
// Check if the message was already tagged with this provider's source type
&& message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out var messageSourceType)
&& messageSourceType is AgentRequestMessageSourceType typedMessageSourceType
&& typedMessageSourceType == AgentRequestMessageSourceType.AIContextProvider
// Check if the message was already tagged with this provider's source
&& message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out var messageSource)
&& messageSource is string typedMessageSource
&& typedMessageSource == this._sourceName)
{
return message;
}

message = message.Clone();
message.AdditionalProperties ??= new();
message.AdditionalProperties[AgentRequestMessageSourceType.AdditionalPropertiesKey] = AgentRequestMessageSourceType.AIContextProvider;
message.AdditionalProperties[AgentRequestMessageSource.AdditionalPropertiesKey] = this._sourceName;
return message;
}).ToList();

return aiContext;
}

/// <summary>
/// Called at the start of agent invocation to provide additional context.
/// </summary>
/// <param name="context">Contains the request context including the caller provided messages that will be used by the agent for this invocation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that represents the asynchronous operation. The task result contains the <see cref="AIContext"/> with additional context to be used by the agent during this invocation.</returns>
/// <remarks>
/// <para>
/// Implementers can load any additional context required at this time, such as:
/// <list type="bullet">
/// <item><description>Retrieving relevant information from knowledge bases</description></item>
/// <item><description>Adding system instructions or prompts</description></item>
/// <item><description>Providing function tools for the current invocation</description></item>
/// <item><description>Injecting contextual messages from conversation history</description></item>
/// </list>
/// </para>
/// </remarks>
protected abstract ValueTask<AIContext> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default);

/// <summary>
/// Called at the end of the agent invocation to process the invocation results.
/// </summary>
/// <param name="context">Contains the invocation context including request messages, response messages, and any exception that occurred.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that represents the asynchronous operation.</returns>
/// <remarks>
/// <para>
/// Implementers can use the request and response messages in the provided <paramref name="context"/> to:
/// <list type="bullet">
/// <item><description>Update internal state based on conversation outcomes</description></item>
/// <item><description>Extract and store memories or preferences from user messages</description></item>
/// <item><description>Log or audit conversation details</description></item>
/// <item><description>Perform cleanup or finalization tasks</description></item>
/// </list>
/// </para>
/// <para>
/// This method is called regardless of whether the invocation succeeded or failed.
/// To check if the invocation was successful, inspect the <see cref="InvokedContext.InvokeException"/> property.
/// </para>
/// </remarks>
public ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default)
=> this.InvokedCoreAsync(context, cancellationToken);

/// <summary>
/// Called at the end of the agent invocation to process the invocation results.
Expand All @@ -71,7 +165,7 @@ public abstract class AIContextProvider
/// To check if the invocation was successful, inspect the <see cref="InvokedContext.InvokeException"/> property.
/// </para>
/// </remarks>
public virtual ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default)
protected virtual ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default)
=> default;

/// <summary>
Expand Down Expand Up @@ -117,7 +211,7 @@ public virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOption
=> this.GetService(typeof(TService), serviceKey) is TService service ? service : default;

/// <summary>
/// Contains the context information provided to <see cref="InvokingAsync(InvokingContext, CancellationToken)"/>.
/// Contains the context information provided to <see cref="InvokingCoreAsync(InvokingContext, CancellationToken)"/>.
/// </summary>
/// <remarks>
/// This class provides context about the invocation before the underlying AI model is invoked, including the messages
Expand Down Expand Up @@ -163,7 +257,7 @@ public InvokingContext(
}

/// <summary>
/// Contains the context information provided to <see cref="InvokedAsync(InvokedContext, CancellationToken)"/>.
/// Contains the context information provided to <see cref="InvokedCoreAsync(InvokedContext, CancellationToken)"/>.
/// </summary>
/// <remarks>
/// This class provides context about a completed agent invocation, including both the
Expand All @@ -178,18 +272,15 @@ public sealed class InvokedContext
/// <param name="agent">The agent being invoked.</param>
/// <param name="session">The session associated with the agent invocation.</param>
/// <param name="requestMessages">The caller provided messages that were used by the agent for this invocation.</param>
/// <param name="aiContextProviderMessages">The messages provided by the <see cref="AIContextProvider"/> for this invocation, if any.</param>
/// <exception cref="ArgumentNullException"><paramref name="requestMessages"/> is <see langword="null"/>.</exception>
public InvokedContext(
AIAgent agent,
AgentSession? session,
IEnumerable<ChatMessage> requestMessages,
IEnumerable<ChatMessage>? aiContextProviderMessages)
IEnumerable<ChatMessage> requestMessages)
{
this.Agent = Throw.IfNull(agent);
this.Session = session;
this.RequestMessages = Throw.IfNull(requestMessages);
this.AIContextProviderMessages = aiContextProviderMessages;
}

/// <summary>
Expand All @@ -211,15 +302,6 @@ public InvokedContext(
/// </value>
public IEnumerable<ChatMessage> RequestMessages { get; set { field = Throw.IfNull(value); } }

/// <summary>
/// Gets the messages provided by the <see cref="AIContextProvider"/> for this invocation, if any.
/// </summary>
/// <value>
/// A collection of <see cref="ChatMessage"/> instances that were provided by the <see cref="AIContextProvider"/>,
/// and were used by the agent as part of the invocation.
/// </value>
public IEnumerable<ChatMessage>? AIContextProviderMessages { get; set; }

/// <summary>
/// Gets the collection of response messages generated during this invocation if the invocation succeeded.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.AI;

namespace Microsoft.Agents.AI;

/// <summary>
/// Provides a constant for the key used to store the source of the agent request message.
/// </summary>
public static class AgentRequestMessageSource
{
/// <summary>
/// Provides the key used in <see cref="ChatMessage.AdditionalProperties"/> to store the source of the agent request message.
/// </summary>
public static readonly string AdditionalPropertiesKey = "Agent.RequestMessageSource";
}
Loading
Loading