Skip to content

Commit 1e407ce

Browse files
committed
Client-side polling via Last-Event-ID
1 parent ae17ba0 commit 1e407ce

File tree

2 files changed

+164
-31
lines changed

2 files changed

+164
-31
lines changed

src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,17 @@ public required Uri Endpoint
106106
/// Gets sor sets the authorization provider to use for authentication.
107107
/// </summary>
108108
public ClientOAuthOptions? OAuth { get; set; }
109+
110+
/// <summary>
111+
/// Gets or sets the maximum number of reconnection attempts when an SSE stream is disconnected.
112+
/// </summary>
113+
/// <value>
114+
/// The maximum number of reconnection attempts. The default is 2.
115+
/// </value>
116+
/// <remarks>
117+
/// When an SSE stream is disconnected (e.g., due to a network issue), the client will attempt to
118+
/// reconnect using the Last-Event-ID header to resume from where it left off. This property controls
119+
/// how many reconnection attempts are made before giving up.
120+
/// </remarks>
121+
public int MaxReconnectionAttempts { get; set; } = 2;
109122
}

src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs

Lines changed: 151 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa
1616
private static readonly MediaTypeWithQualityHeaderValue s_applicationJsonMediaType = new("application/json");
1717
private static readonly MediaTypeWithQualityHeaderValue s_textEventStreamMediaType = new("text/event-stream");
1818

19+
private static readonly TimeSpan s_defaultReconnectionDelay = TimeSpan.FromSeconds(1);
20+
1921
private readonly McpHttpClient _httpClient;
2022
private readonly HttpClientTransportOptions _options;
2123
private readonly CancellationTokenSource _connectionCts;
@@ -105,8 +107,18 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
105107
}
106108
else if (response.Content.Headers.ContentType?.MediaType == "text/event-stream")
107109
{
108-
using var responseBodyStream = await response.Content.ReadAsStreamAsync(cancellationToken);
109-
rpcResponseOrError = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, cancellationToken).ConfigureAwait(false);
110+
var sseState = new SseStreamState();
111+
using var responseBodyStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
112+
var sseResponse = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, sseState, cancellationToken).ConfigureAwait(false);
113+
rpcResponseOrError = sseResponse.Response;
114+
115+
// Resumability: If POST SSE stream ended without a response but we have a Last-Event-ID (from priming),
116+
// attempt to resume by sending a GET request with Last-Event-ID header. The server will replay
117+
// events from the event store, allowing us to receive the pending response.
118+
if (rpcResponseOrError is null && rpcRequest is not null && sseState.LastEventId is not null)
119+
{
120+
rpcResponseOrError = await SendGetSseRequestWithRetriesAsync(rpcRequest, sseState, cancellationToken).ConfigureAwait(false);
121+
}
110122
}
111123

112124
if (rpcRequest is null)
@@ -188,56 +200,140 @@ public override async ValueTask DisposeAsync()
188200

189201
private async Task ReceiveUnsolicitedMessagesAsync()
190202
{
191-
// Send a GET request to handle any unsolicited messages not sent over a POST response.
192-
using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint);
193-
request.Headers.Accept.Add(s_textEventStreamMediaType);
194-
CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion);
203+
var state = new SseStreamState();
195204

196-
// Server support for the GET request is optional. If it fails, we don't care. It just means we won't receive unsolicited messages.
197-
HttpResponseMessage response;
198-
try
205+
// Continuously receive unsolicited messages until canceled
206+
while (!_connectionCts.Token.IsCancellationRequested)
199207
{
200-
response = await _httpClient.SendAsync(request, message: null, _connectionCts.Token).ConfigureAwait(false);
201-
}
202-
catch (HttpRequestException)
203-
{
204-
return;
205-
}
208+
await SendGetSseRequestWithRetriesAsync(
209+
relatedRpcRequest: null,
210+
state,
211+
_connectionCts.Token).ConfigureAwait(false);
206212

207-
using (response)
208-
{
209-
if (!response.IsSuccessStatusCode)
213+
// If we exhausted retries without receiving any events, stop trying
214+
if (state.LastEventId is null)
210215
{
211216
return;
212217
}
213-
214-
using var responseStream = await response.Content.ReadAsStreamAsync(_connectionCts.Token).ConfigureAwait(false);
215-
await ProcessSseResponseAsync(responseStream, relatedRpcRequest: null, _connectionCts.Token).ConfigureAwait(false);
216218
}
217219
}
218220

219-
private async Task<JsonRpcMessageWithId?> ProcessSseResponseAsync(Stream responseStream, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken)
221+
/// <summary>
222+
/// Sends a GET request for SSE with retry logic and resumability support.
223+
/// </summary>
224+
private async Task<JsonRpcMessageWithId?> SendGetSseRequestWithRetriesAsync(
225+
JsonRpcRequest? relatedRpcRequest,
226+
SseStreamState state,
227+
CancellationToken cancellationToken)
220228
{
221-
await foreach (SseItem<string> sseEvent in SseParser.Create(responseStream).EnumerateAsync(cancellationToken).ConfigureAwait(false))
229+
int attempt = 0;
230+
231+
// Delay before first attempt if we're reconnecting (have a Last-Event-ID)
232+
bool shouldDelay = state.LastEventId is not null;
233+
234+
while (attempt < _options.MaxReconnectionAttempts)
222235
{
223-
if (sseEvent.EventType != "message")
236+
cancellationToken.ThrowIfCancellationRequested();
237+
238+
if (shouldDelay)
224239
{
225-
continue;
240+
var delay = state.RetryInterval ?? s_defaultReconnectionDelay;
241+
await Task.Delay(delay, cancellationToken).ConfigureAwait(false);
226242
}
243+
shouldDelay = true;
227244

228-
var rpcResponseOrError = await ProcessMessageAsync(sseEvent.Data, relatedRpcRequest, cancellationToken).ConfigureAwait(false);
245+
using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint);
246+
request.Headers.Accept.Add(s_textEventStreamMediaType);
247+
CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion, state.LastEventId);
229248

230-
// The server SHOULD end the HTTP response body here anyway, but we won't leave it to chance. This transport makes
231-
// a GET request for any notifications that might need to be sent after the completion of each POST.
232-
if (rpcResponseOrError is not null)
249+
HttpResponseMessage response;
250+
try
233251
{
234-
return rpcResponseOrError;
252+
response = await _httpClient.SendAsync(request, message: null, cancellationToken).ConfigureAwait(false);
253+
}
254+
catch (HttpRequestException)
255+
{
256+
attempt++;
257+
continue;
258+
}
259+
260+
using (response)
261+
{
262+
if (!response.IsSuccessStatusCode)
263+
{
264+
// If the server could be reached but returned a non-success status code,
265+
// retrying likely won't change that.
266+
return null;
267+
}
268+
269+
using var responseStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
270+
var sseResponse = await ProcessSseResponseAsync(responseStream, relatedRpcRequest, state, cancellationToken).ConfigureAwait(false);
271+
272+
if (sseResponse.Response is { } rpcResponseOrError)
273+
{
274+
return rpcResponseOrError;
275+
}
276+
277+
// If we reach here, then the stream closed without the response.
278+
279+
if (sseResponse.IsNetworkError || state.LastEventId is null)
280+
{
281+
// No event ID means server may not support resumability; don't retry indefinitely.
282+
attempt++;
283+
}
284+
else
285+
{
286+
// We have an event ID, so we continue polling to receive more events.
287+
// The server should eventually send a response or return an error.
288+
attempt = 0;
289+
}
235290
}
236291
}
237292

238293
return null;
239294
}
240295

296+
private async Task<SseResponse> ProcessSseResponseAsync(
297+
Stream responseStream,
298+
JsonRpcRequest? relatedRpcRequest,
299+
SseStreamState state,
300+
CancellationToken cancellationToken)
301+
{
302+
try
303+
{
304+
await foreach (SseItem<string> sseEvent in SseParser.Create(responseStream).EnumerateAsync(cancellationToken).ConfigureAwait(false))
305+
{
306+
// Track event ID and retry interval for resumability
307+
if (!string.IsNullOrEmpty(sseEvent.EventId))
308+
{
309+
state.LastEventId = sseEvent.EventId;
310+
}
311+
if (sseEvent.ReconnectionInterval.HasValue)
312+
{
313+
state.RetryInterval = sseEvent.ReconnectionInterval.Value;
314+
}
315+
316+
// Skip events with empty data
317+
if (string.IsNullOrEmpty(sseEvent.Data))
318+
{
319+
continue;
320+
}
321+
322+
var rpcResponseOrError = await ProcessMessageAsync(sseEvent.Data, relatedRpcRequest, cancellationToken).ConfigureAwait(false);
323+
if (rpcResponseOrError is not null)
324+
{
325+
return new() { Response = rpcResponseOrError };
326+
}
327+
}
328+
}
329+
catch (Exception ex) when (ex is IOException or HttpRequestException)
330+
{
331+
return new() { IsNetworkError = true };
332+
}
333+
334+
return default;
335+
}
336+
241337
private async Task<JsonRpcMessageWithId?> ProcessMessageAsync(string data, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken)
242338
{
243339
LogTransportReceivedMessageSensitive(Name, data);
@@ -292,7 +388,8 @@ internal static void CopyAdditionalHeaders(
292388
HttpRequestHeaders headers,
293389
IDictionary<string, string>? additionalHeaders,
294390
string? sessionId,
295-
string? protocolVersion)
391+
string? protocolVersion,
392+
string? lastEventId = null)
296393
{
297394
if (sessionId is not null)
298395
{
@@ -304,6 +401,11 @@ internal static void CopyAdditionalHeaders(
304401
headers.Add("MCP-Protocol-Version", protocolVersion);
305402
}
306403

404+
if (lastEventId is not null)
405+
{
406+
headers.Add("Last-Event-ID", lastEventId);
407+
}
408+
307409
if (additionalHeaders is null)
308410
{
309411
return;
@@ -317,4 +419,22 @@ internal static void CopyAdditionalHeaders(
317419
}
318420
}
319421
}
422+
423+
/// <summary>
424+
/// Tracks state across SSE stream connections.
425+
/// </summary>
426+
private sealed class SseStreamState
427+
{
428+
public string? LastEventId { get; set; }
429+
public TimeSpan? RetryInterval { get; set; }
430+
}
431+
432+
/// <summary>
433+
/// Represents the result of processing an SSE response.
434+
/// </summary>
435+
private readonly struct SseResponse
436+
{
437+
public JsonRpcMessageWithId? Response { get; init; }
438+
public bool IsNetworkError { get; init; }
439+
}
320440
}

0 commit comments

Comments
 (0)