Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions src/ModelContextProtocol/McpChatClientBuilderExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
using System.Collections.Concurrent;
using System.Runtime.CompilerServices;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using ModelContextProtocol.Client;
#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

namespace ModelContextProtocol;

/// <summary>
/// Extension methods for adding MCP client support to chat clients.
/// </summary>
public static class McpChatClientBuilderExtensions
{
/// <summary>
/// Adds a chat client to the chat client pipeline that creates an <see cref="McpClient"/> for each <see cref="HostedMcpServerTool"/>
/// in <see cref="ChatOptions.Tools"/> and augments it with the tools from MCP servers as <see cref="AIFunction"/> instances.
/// </summary>
/// <param name="builder">The <see cref="ChatClientBuilder"/> to configure.</param>
/// <param name="httpClient">The <see cref="HttpClient"/> to use, or <see langword="null"/> to create a new instance.</param>
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use, or <see langword="null"/> to resolve from services.</param>
/// <returns>The <see cref="ChatClientBuilder"/> for method chaining.</returns>
/// <remarks>
/// <para>
/// When a <c>HostedMcpServerTool</c> is encountered in the tools collection, the client
/// connects to the MCP server, retrieves available tools, and expands them into callable AI functions.
/// Connections are cached by server address to avoid redundant connections.
/// </para>
/// <para>
/// Use this method as an alternative when working with chat providers that don't have built-in support for hosted MCP servers.
/// </para>
/// </remarks>
public static ChatClientBuilder UseMcpClient(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should mark the public surface are for this as [Experimental]. We could do so either with "MEAI001", so that it matches with the MCP types on which its based (and in which case you wouldn't need the earlier suppression) or have some new "MCP003" or something.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to align this with the ID we end up using for McpServer members. cc @jeffhandley.

this ChatClientBuilder builder,
HttpClient? httpClient = null,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's interesting to consider what should be accepted here and whether that has any impact on the shape of the options we expose in the MCP library. We can't just accept an HttpClientTransport, as that's tied to a specific endpoint, but otherwise we effectively would want most of the options exposed on that, I think.

Copy link
Contributor

@halter73 halter73 Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. You wouldn't want to share the ITokenCache, but the remaining ClientOAuthOptions could make a lot of sense. Maybe we should move the TokenCache out of ClientOAuthOptions so we could flow the rest of it through.

Headers are a bit in-between. I could see wanting to be able to configure a mapping for known MCP server URL prefixes for things like API keys, but that becomes a lot more of a complex API. At that point maybe we'd recommend putting custom logic in a DelegatingHandler.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we have a Func<HostedMcpServerTool, HttpClientTransportOptions?>? transportOptionsProvider = null? Someone could return options for specific servers, their own defaults, or null.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe Action<HostedMcpServerTool, HttpClientTransportOptions>

ILoggerFactory? loggerFactory = null)
{
return builder.Use((innerClient, services) =>
{
loggerFactory ??= (ILoggerFactory)services.GetService(typeof(ILoggerFactory))!;
var chatClient = new McpChatClient(innerClient, httpClient, loggerFactory);
return chatClient;
});
}

private class McpChatClient : DelegatingChatClient
{
private readonly ILoggerFactory? _loggerFactory;
private readonly ILogger _logger;
private readonly HttpClient _httpClient;
private readonly bool _ownsHttpClient;
private ConcurrentDictionary<string, Task<McpClient>>? _mcpClientTasks = null;

/// <summary>
/// Initializes a new instance of the <see cref="McpChatClient"/> class.
/// </summary>
/// <param name="innerClient">The underlying <see cref="IChatClient"/>, or the next instance in a chain of clients.</param>
/// <param name="httpClient">An optional <see cref="HttpClient"/> to use when connecting to MCP servers. If not provided, a new instance will be created.</param>
/// <param name="loggerFactory">An <see cref="ILoggerFactory"/> to use for logging information about function invocation.</param>
public McpChatClient(IChatClient innerClient, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null)
: base(innerClient)
{
_loggerFactory = loggerFactory;
_logger = (ILogger?)loggerFactory?.CreateLogger<McpChatClient>() ?? NullLogger.Instance;
_httpClient = httpClient ?? new HttpClient();
_ownsHttpClient = httpClient is null;
}

/// <inheritdoc/>
public override async Task<ChatResponse> GetResponseAsync(
IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
if (options?.Tools is { Count: > 0 })
{
var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools, cancellationToken).ConfigureAwait(false);
options = options.Clone();
options.Tools = downstreamTools;
}

return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc/>
public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (options?.Tools is { Count: > 0 })
{
var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools, cancellationToken).ConfigureAwait(false);
options = options.Clone();
options.Tools = downstreamTools;
}

await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
{
yield return update;
}
}

private async Task<List<AITool>?> BuildDownstreamAIToolsAsync(IList<AITool>? inputTools, CancellationToken cancellationToken)
{
List<AITool>? downstreamTools = null;
foreach (var tool in inputTools ?? [])
{
if (tool is not HostedMcpServerTool mcpTool)
{
// For other tools, we want to keep them in the list of tools.
downstreamTools ??= new List<AITool>();
downstreamTools.Add(tool);
continue;
}

if (!Uri.TryCreate(mcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) ||
(parsedAddress.Scheme != Uri.UriSchemeHttp && parsedAddress.Scheme != Uri.UriSchemeHttps))
{
throw new InvalidOperationException(
$"MCP server address must be an absolute HTTP or HTTPS URI. Invalid address: '{mcpTool.ServerAddress}'");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do OpenAI and Anthropic similarly fail if provided with an invalid URI?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I see a 424 and a 500 response from OpenAI and Anthropic respectively.

The primary intent here is to let users know that connectors won't be supported with this middleware. I don't think that's something we should simply filter-out, thoughts?

}

// List all MCP functions from the specified MCP server.
// This will need some caching in a real-world scenario to avoid repeated calls.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you thinking we wouldn't merge this and this PR is just for exploration? I'm ok with that, just trying to understand your intentions. Though I think there could be value in actually shipping this (with appropriate caching).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is an outdated comment, it should be addressed with the ConcurrentDictionary caching the mcp clients. To be clear, I did not write this code from scratch, thankfully, @westey-m provided it as a POC as part of microsoft/agent-framework#209.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree we should try to ship it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's caching the client instances but not the tools, which means it's going to make a list tools request to the server on each call. It's also never removing things from the cache.

var mcpClient = await CreateMcpClientAsync(parsedAddress, mcpTool.ServerName, mcpTool.AuthorizationToken).ConfigureAwait(false);
var mcpFunctions = await mcpClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false);

// Add the listed functions to our list of tools we'll pass to the inner client.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this approach account for the list of tools potentially changing after this initialization step?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't. We are yet to define lifespan of listed tools and how to handle protocol errors on invoked tools e.g. tool no longer exists.

foreach (var mcpFunction in mcpFunctions)
{
if (mcpTool.AllowedTools is not null && !mcpTool.AllowedTools.Contains(mcpFunction.Name))
{
_logger.LogInformation("MCP function '{FunctionName}' is not allowed by the tool configuration.", mcpFunction.Name);
continue;
}

downstreamTools ??= new List<AITool>();
switch (mcpTool.ApprovalMode)
{
case HostedMcpServerToolAlwaysRequireApprovalMode alwaysRequireApproval:
downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction));
break;
case HostedMcpServerToolNeverRequireApprovalMode neverRequireApproval:
downstreamTools.Add(mcpFunction);
break;
case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.AlwaysRequireApprovalToolNames?.Contains(mcpFunction.Name) is true:
downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction));
break;
case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.NeverRequireApprovalToolNames?.Contains(mcpFunction.Name) is true:
downstreamTools.Add(mcpFunction);
break;
default:
// Default to always require approval if no specific mode is set.
downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction));
break;
}
}
}

return downstreamTools;
}

/// <inheritdoc/>
protected override void Dispose(bool disposing)
{
if (disposing)
{
// Dispose of the HTTP client if it was created by this client.
if (_ownsHttpClient)
{
_httpClient?.Dispose();
}

if (_mcpClientTasks is not null)
{
// Dispose of all cached MCP clients.
foreach (var clientTask in _mcpClientTasks.Values)
{
#if NETSTANDARD2_0
if (clientTask.Status == TaskStatus.RanToCompletion)
#else
if (clientTask.IsCompletedSuccessfully)
#endif
{
_ = clientTask.Result.DisposeAsync();
}
}

_mcpClientTasks.Clear();
}
}

base.Dispose(disposing);
}

private Task<McpClient> CreateMcpClientAsync(Uri serverAddress, string serverName, string? authorizationToken)
{
if (_mcpClientTasks is null)
{
_mcpClientTasks = new ConcurrentDictionary<string, Task<McpClient>>(StringComparer.OrdinalIgnoreCase);
}

// Note: We don't pass cancellationToken to the factory because the cached task should not be tied to any single caller's cancellation token.
// Instead, callers can cancel waiting for the task, but the connection attempt itself will complete independently.
return _mcpClientTasks.GetOrAdd(serverAddress.ToString(), _ => CreateMcpClientCoreAsync(serverAddress, serverName, authorizationToken, CancellationToken.None));
}

private async Task<McpClient> CreateMcpClientCoreAsync(Uri serverAddress, string serverName, string? authorizationToken, CancellationToken cancellationToken)
{
var serverAddressKey = serverAddress.ToString();
try
{
var transport = new HttpClientTransport(new HttpClientTransportOptions
{
Endpoint = serverAddress,
Name = serverName,
AdditionalHeaders = authorizationToken is not null
// Update to pass all headers once https://github.com/dotnet/extensions/pull/7053 is available.
? new Dictionary<string, string>() { { "Authorization", $"Bearer {authorizationToken}" } }
: null,
}, _httpClient, _loggerFactory);

return await McpClient.CreateAsync(transport, cancellationToken: cancellationToken).ConfigureAwait(false);
}
catch
{
// Remove the failed task from cache so subsequent requests can retry
_mcpClientTasks?.TryRemove(serverAddressKey, out _);
throw;
}
}
}
}
1 change: 1 addition & 0 deletions src/ModelContextProtocol/ModelContextProtocol.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Hosting.Abstractions" />
<PackageReference Include="Microsoft.Extensions.AI" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

namespace ModelContextProtocol.AspNetCore.Tests;

public abstract class HttpServerIntegrationTests : LoggedTest, IClassFixture<SseServerIntegrationTestFixture>
public abstract class HttpServerIntegrationTests : LoggedTest, IClassFixture<SseServerWithXunitLoggerFixture>
{
protected readonly SseServerIntegrationTestFixture _fixture;
protected readonly SseServerWithXunitLoggerFixture _fixture;

public HttpServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper)
public HttpServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper)
: base(testOutputHelper)
{
_fixture = fixture;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,18 @@

namespace ModelContextProtocol.AspNetCore.Tests;

public class SseServerIntegrationTestFixture : IAsyncDisposable
public abstract class SseServerIntegrationTestFixture : IAsyncDisposable
{
private readonly KestrelInMemoryTransport _inMemoryTransport = new();

private readonly Task _serverTask;
private readonly CancellationTokenSource _stopCts = new();

// XUnit's ITestOutputHelper is created per test, while this fixture is used for
// multiple tests, so this dispatches the output to the current test.
private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new();

private HttpClientTransportOptions DefaultTransportOptions { get; set; } = new()
{
Endpoint = new("http://localhost:5000/"),
};

public SseServerIntegrationTestFixture()
protected SseServerIntegrationTestFixture()
{
var socketsHttpHandler = new SocketsHttpHandler
{
Expand All @@ -39,8 +34,10 @@ public SseServerIntegrationTestFixture()
BaseAddress = new("http://localhost:5000/"),
};

_serverTask = Program.MainAsync([], new XunitLoggerProvider(_delegatingTestOutputHelper), _inMemoryTransport, _stopCts.Token);
_serverTask = Program.MainAsync([], CreateLoggerProvider(), _inMemoryTransport, _stopCts.Token);
}

protected abstract ILoggerProvider CreateLoggerProvider();

public HttpClient HttpClient { get; }

Expand All @@ -53,21 +50,17 @@ public Task<McpClient> ConnectMcpClientAsync(McpClientOptions? options, ILoggerF
TestContext.Current.CancellationToken);
}

public void Initialize(ITestOutputHelper output, HttpClientTransportOptions clientTransportOptions)
public virtual void Initialize(ITestOutputHelper output, HttpClientTransportOptions clientTransportOptions)
{
_delegatingTestOutputHelper.CurrentTestOutputHelper = output;
DefaultTransportOptions = clientTransportOptions;
}

public void TestCompleted()
public virtual void TestCompleted()
{
_delegatingTestOutputHelper.CurrentTestOutputHelper = null;
}

public async ValueTask DisposeAsync()
public virtual async ValueTask DisposeAsync()
{
_delegatingTestOutputHelper.CurrentTestOutputHelper = null;

HttpClient.Dispose();
_stopCts.Cancel();

Expand All @@ -82,3 +75,49 @@ public async ValueTask DisposeAsync()
_stopCts.Dispose();
}
}

/// <summary>
/// SSE server fixture that routes logs to xUnit test output.
/// </summary>
public class SseServerWithXunitLoggerFixture : SseServerIntegrationTestFixture
{
// XUnit's ITestOutputHelper is created per test, while this fixture is used for
// multiple tests, so this dispatches the output to the current test.
private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new();

protected override ILoggerProvider CreateLoggerProvider()
=> new XunitLoggerProvider(_delegatingTestOutputHelper);

public override void Initialize(ITestOutputHelper output, HttpClientTransportOptions clientTransportOptions)
{
_delegatingTestOutputHelper.CurrentTestOutputHelper = output;
base.Initialize(output, clientTransportOptions);
}

public override void TestCompleted()
{
_delegatingTestOutputHelper.CurrentTestOutputHelper = null;
base.TestCompleted();
}

public override async ValueTask DisposeAsync()
{
_delegatingTestOutputHelper.CurrentTestOutputHelper = null;
await base.DisposeAsync();
}
}

/// <summary>
/// Fixture for tests that need to inspect server logs using MockLoggerProvider.
/// Use <see cref="SseServerWithXunitLoggerFixture"/> for tests that just need xUnit output.
/// </summary>
public class SseServerWithMockLoggerFixture : SseServerIntegrationTestFixture
{
private readonly MockLoggerProvider _mockLoggerProvider = new();

protected override ILoggerProvider CreateLoggerProvider()
=> _mockLoggerProvider;

public IEnumerable<(string Category, LogLevel LogLevel, EventId EventId, string Message, Exception? Exception)> ServerLogs
=> _mockLoggerProvider.LogMessages;
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace ModelContextProtocol.AspNetCore.Tests;

public class SseServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper)
public class SseServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper)
: HttpServerIntegrationTests(fixture, testOutputHelper)

{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

namespace ModelContextProtocol.AspNetCore.Tests;

public class StatelessServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper)
public class StatelessServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper)
: StreamableHttpServerIntegrationTests(fixture, testOutputHelper)
{
protected override HttpClientTransportOptions ClientTransportOptions => new()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace ModelContextProtocol.AspNetCore.Tests;

public class StreamableHttpServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper)
public class StreamableHttpServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper)
: HttpServerIntegrationTests(fixture, testOutputHelper)

{
Expand Down
Loading
Loading