Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
<ItemGroup>
<PackageVersion Include="Microsoft.Extensions.AI" Version="$(MicrosoftExtensionsVersion)" />
<PackageVersion Include="Microsoft.Extensions.AI.Abstractions" Version="$(MicrosoftExtensionsVersion)" />
<PackageVersion Include="Microsoft.Extensions.Caching.Abstractions" Version="$(System10Version)" />
<PackageVersion Include="Microsoft.Extensions.Caching.Memory" Version="$(System10Version)" />
<PackageVersion Include="Microsoft.Extensions.Hosting.Abstractions" Version="$(System10Version)" />
<PackageVersion Include="Microsoft.Extensions.Logging.Abstractions" Version="$(System10Version)" />
</ItemGroup>
Expand Down
5 changes: 5 additions & 0 deletions src/ModelContextProtocol.Core/McpJsonUtilities.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.Extensions.AI;
using ModelContextProtocol.Authentication;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Serialization;
Expand Down Expand Up @@ -158,6 +159,10 @@ internal static bool IsValidMcpToolSchema(JsonElement element)
[JsonSerializable(typeof(BlobResourceContents))]
[JsonSerializable(typeof(TextResourceContents))]

// Distributed cache event stream store
[JsonSerializable(typeof(DistributedCacheEventStreamStore.StreamMetadata))]
[JsonSerializable(typeof(DistributedCacheEventStreamStore.StoredEvent))]

// Other MCP Types
[JsonSerializable(typeof(IReadOnlyDictionary<string, object>))]
[JsonSerializable(typeof(ProgressToken))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
<PackageReference Include="Microsoft.Extensions.Caching.Abstractions" />
</ItemGroup>

<!-- Reference analyzers -->
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

// This is a shared source file included in both ModelContextProtocol.Core and the test project.
// Do not reference symbols internal to the core project, as they won't be available in tests.

using System.Text;

namespace ModelContextProtocol.Server;

/// <summary>
/// Provides methods for formatting and parsing event IDs used by <see cref="DistributedCacheEventStreamStore"/>.
/// </summary>
/// <remarks>
/// Event IDs are formatted as "{base64(sessionId)}:{base64(streamId)}:{sequence}".
/// </remarks>
internal static class DistributedCacheEventIdFormatter
{
private const char Separator = ':';

/// <summary>
/// Formats session ID, stream ID, and sequence number into an event ID string.
/// </summary>
public static string Format(string sessionId, string streamId, long sequence)
{
// Base64-encode session and stream IDs so the event ID can be parsed
// even if the original IDs contain the ':' separator character
var sessionBase64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(sessionId));
var streamBase64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(streamId));
return $"{sessionBase64}{Separator}{streamBase64}{Separator}{sequence}";
}

/// <summary>
/// Attempts to parse an event ID into its component parts.
/// </summary>
public static bool TryParse(string eventId, out string sessionId, out string streamId, out long sequence)
{
sessionId = string.Empty;
streamId = string.Empty;
sequence = 0;

var parts = eventId.Split(Separator);
if (parts.Length != 3)
{
return false;
}

try
{
sessionId = Encoding.UTF8.GetString(Convert.FromBase64String(parts[0]));
streamId = Encoding.UTF8.GetString(Convert.FromBase64String(parts[1]));
return long.TryParse(parts[2], out sequence);
}
catch
{
return false;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
using Microsoft.Extensions.Caching.Distributed;
using ModelContextProtocol.Protocol;
using System.Net.ServerSentEvents;
using System.Runtime.CompilerServices;
using System.Text.Json;

namespace ModelContextProtocol.Server;

/// <summary>
/// An <see cref="ISseEventStreamStore"/> implementation backed by <see cref="IDistributedCache"/>.
/// </summary>
/// <remarks>
/// <para>
/// This implementation stores SSE events in a distributed cache, enabling resumability across
/// multiple server instances. Event IDs are encoded with session, stream, and sequence information
/// to allow efficient retrieval of events after a given point.
/// </para>
/// <para>
/// The writer maintains in-memory state for sequence number generation, as there is guaranteed
/// to be only one writer per stream. Readers may be created from separate processes.
/// </para>
/// </remarks>
public sealed class DistributedCacheEventStreamStore : ISseEventStreamStore
{
private readonly IDistributedCache _cache;
private readonly DistributedCacheEventStreamStoreOptions _options;

/// <summary>
/// Initializes a new instance of the <see cref="DistributedCacheEventStreamStore"/> class.
/// </summary>
/// <param name="cache">The distributed cache to use for storage.</param>
/// <param name="options">Optional configuration options for the store.</param>
public DistributedCacheEventStreamStore(IDistributedCache cache, DistributedCacheEventStreamStoreOptions? options = null)
{
Throw.IfNull(cache);
_cache = cache;
_options = options ?? new();
}

/// <inheritdoc />
public ValueTask<ISseEventStreamWriter> CreateStreamAsync(SseEventStreamOptions options, CancellationToken cancellationToken = default)
{
Throw.IfNull(options);
var writer = new DistributedCacheEventStreamWriter(_cache, options.SessionId, options.StreamId, options.Mode, _options);
return new ValueTask<ISseEventStreamWriter>(writer);
}

/// <inheritdoc />
public async ValueTask<ISseEventStreamReader?> GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken = default)
{
Throw.IfNull(lastEventId);

// Parse the event ID to get session, stream, and sequence information
if (!DistributedCacheEventIdFormatter.TryParse(lastEventId, out var sessionId, out var streamId, out var sequence))
{
return null;
}

// Check if the stream exists by looking for its metadata
var metadataKey = CacheKeys.StreamMetadata(sessionId, streamId);
var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false);
if (metadataBytes is null)
{
return null;
}

var metadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata);
if (metadata is null)
{
return null;
}

var startSequence = sequence + 1;
return new DistributedCacheEventStreamReader(_cache, sessionId, streamId, startSequence, metadata, _options);
}

/// <summary>
/// Provides methods for generating cache keys.
/// </summary>
internal static class CacheKeys
{
private const string Prefix = "mcp:sse:";

public static string StreamMetadata(string sessionId, string streamId) =>
$"{Prefix}meta:{sessionId}:{streamId}";

public static string Event(string eventId) =>
$"{Prefix}event:{eventId}";

public static string StreamEventCount(string sessionId, string streamId) =>
$"{Prefix}count:{sessionId}:{streamId}";
}

/// <summary>
/// Metadata about a stream stored in the cache.
/// </summary>
internal sealed class StreamMetadata
{
public SseEventStreamMode Mode { get; set; }
public bool IsCompleted { get; set; }
public long LastSequence { get; set; }
}

/// <summary>
/// Serialized representation of an SSE event stored in the cache.
/// </summary>
internal sealed class StoredEvent
{
public string? EventType { get; set; }
public string? EventId { get; set; }
public JsonRpcMessage? Data { get; set; }
}

private sealed class DistributedCacheEventStreamWriter : ISseEventStreamWriter
{
private readonly IDistributedCache _cache;
private readonly string _sessionId;
private readonly string _streamId;
private SseEventStreamMode _mode;
private readonly DistributedCacheEventStreamStoreOptions _options;
private long _sequence;
private bool _disposed;

public DistributedCacheEventStreamWriter(
IDistributedCache cache,
string sessionId,
string streamId,
SseEventStreamMode mode,
DistributedCacheEventStreamStoreOptions options)
{
_cache = cache;
_sessionId = sessionId;
_streamId = streamId;
_mode = mode;
_options = options;
}

public async ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken cancellationToken = default)
{
_mode = mode;
await UpdateMetadataAsync(cancellationToken).ConfigureAwait(false);
}

public async ValueTask<SseItem<JsonRpcMessage?>> WriteEventAsync(SseItem<JsonRpcMessage?> sseItem, CancellationToken cancellationToken = default)
{
// Skip if already has an event ID
if (sseItem.EventId is not null)
{
return sseItem;
}

// Generate a new sequence number and event ID
var sequence = Interlocked.Increment(ref _sequence);
var eventId = DistributedCacheEventIdFormatter.Format(_sessionId, _streamId, sequence);
var newItem = sseItem with { EventId = eventId };

// Store the event in the cache
var storedEvent = new StoredEvent
{
EventType = newItem.EventType,
EventId = eventId,
Data = newItem.Data,
};

var eventBytes = JsonSerializer.SerializeToUtf8Bytes(storedEvent, McpJsonUtilities.JsonContext.Default.StoredEvent);
var eventKey = CacheKeys.Event(eventId);

await _cache.SetAsync(eventKey, eventBytes, new DistributedCacheEntryOptions
{
SlidingExpiration = _options.EventSlidingExpiration,
AbsoluteExpirationRelativeToNow = _options.EventAbsoluteExpiration,
}, cancellationToken).ConfigureAwait(false);

// Update metadata with the latest sequence
await UpdateMetadataAsync(cancellationToken).ConfigureAwait(false);

return newItem;
}

private async ValueTask UpdateMetadataAsync(CancellationToken cancellationToken)
{
var metadata = new StreamMetadata
{
Mode = _mode,
IsCompleted = _disposed,
LastSequence = Interlocked.Read(ref _sequence),
};

var metadataBytes = JsonSerializer.SerializeToUtf8Bytes(metadata, McpJsonUtilities.JsonContext.Default.StreamMetadata);
var metadataKey = CacheKeys.StreamMetadata(_sessionId, _streamId);

await _cache.SetAsync(metadataKey, metadataBytes, new DistributedCacheEntryOptions
{
SlidingExpiration = _options.MetadataSlidingExpiration,
AbsoluteExpirationRelativeToNow = _options.MetadataAbsoluteExpiration,
}, cancellationToken).ConfigureAwait(false);
}

public async ValueTask DisposeAsync()
{
if (_disposed)
{
return;
}

_disposed = true;

// Mark the stream as completed in the metadata
await UpdateMetadataAsync(CancellationToken.None).ConfigureAwait(false);
}
}

private sealed class DistributedCacheEventStreamReader : ISseEventStreamReader
{
private readonly IDistributedCache _cache;
private readonly long _startSequence;
private readonly StreamMetadata _initialMetadata;
private readonly DistributedCacheEventStreamStoreOptions _options;

public DistributedCacheEventStreamReader(
IDistributedCache cache,
string sessionId,
string streamId,
long startSequence,
StreamMetadata initialMetadata,
DistributedCacheEventStreamStoreOptions options)
{
_cache = cache;
SessionId = sessionId;
StreamId = streamId;
_startSequence = startSequence;
_initialMetadata = initialMetadata;
_options = options;
}

public string SessionId { get; }
public string StreamId { get; }

public async IAsyncEnumerable<SseItem<JsonRpcMessage?>> ReadEventsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
{
// Start from the sequence after the last received event
var currentSequence = _startSequence;

// Use the initial metadata passed to the constructor for the first read.
var lastSequence = _initialMetadata.LastSequence;
var isCompleted = _initialMetadata.IsCompleted;
var mode = _initialMetadata.Mode;

while (!cancellationToken.IsCancellationRequested)
{
// Read all available events from currentSequence + 1 to lastSequence
for (; currentSequence <= lastSequence; currentSequence++)
{
cancellationToken.ThrowIfCancellationRequested();

var eventId = DistributedCacheEventIdFormatter.Format(SessionId, StreamId, currentSequence);
var eventKey = CacheKeys.Event(eventId);
var eventBytes = await _cache.GetAsync(eventKey, cancellationToken).ConfigureAwait(false)
?? throw new McpException($"SSE event with ID '{eventId}' was not found in the cache. The event may have expired.");

var storedEvent = JsonSerializer.Deserialize(eventBytes, McpJsonUtilities.JsonContext.Default.StoredEvent);
if (storedEvent is not null)
{
yield return new SseItem<JsonRpcMessage?>(storedEvent.Data, storedEvent.EventType)
{
EventId = storedEvent.EventId,
};
}
}

// If in polling mode, stop after returning currently available events
if (mode == SseEventStreamMode.Polling)
{
yield break;
}

// If the stream is completed and we've read all events, stop
if (isCompleted)
{
yield break;
}

// Wait before polling again for new events
await Task.Delay(_options.PollingInterval, cancellationToken).ConfigureAwait(false);

// Refresh metadata to get the latest sequence and completion status
var metadataKey = CacheKeys.StreamMetadata(SessionId, StreamId);
var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false)
?? throw new McpException($"Stream metadata for session '{SessionId}' and stream '{StreamId}' was not found in the cache. The metadata may have expired.");

var currentMetadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata)
?? throw new McpException($"Stream metadata for session '{SessionId}' and stream '{StreamId}' could not be deserialized.");

lastSequence = currentMetadata.LastSequence;
isCompleted = currentMetadata.IsCompleted;
mode = currentMetadata.Mode;
}
}
}
}
Loading
Loading