Skip to content
Merged
6 changes: 5 additions & 1 deletion src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,11 @@ export class Client<

override async connect(transport: Transport, options?: RequestOptions): Promise<void> {
await super.connect(transport);

// When transport sessionId is already set this means we are trying to reconnect.
// In this case we don't need to initialize again.
if (transport.sessionId !== undefined) {
return;
}
try {
const result = await this.request(
{
Expand Down
10 changes: 5 additions & 5 deletions src/client/streamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ describe("StreamableHTTPClientTransport", () => {
// We expect the 405 error to be caught and handled gracefully
// This should not throw an error that breaks the transport
await transport.start();
await expect(transport["_startOrAuthSse"]()).resolves.not.toThrow("Failed to open SSE stream: Method Not Allowed");
await expect(transport["_startOrAuthSse"]({})).resolves.not.toThrow("Failed to open SSE stream: Method Not Allowed");
// Check that GET was attempted
expect(global.fetch).toHaveBeenCalledWith(
expect.anything(),
Expand Down Expand Up @@ -208,7 +208,7 @@ describe("StreamableHTTPClientTransport", () => {
transport.onmessage = messageSpy;

await transport.start();
await transport["_startOrAuthSse"]();
await transport["_startOrAuthSse"]({});

// Give time for the SSE event to be processed
await new Promise(resolve => setTimeout(resolve, 50));
Expand Down Expand Up @@ -313,9 +313,9 @@ describe("StreamableHTTPClientTransport", () => {
await transport.start();
// Type assertion to access private method
const transportWithPrivateMethods = transport as unknown as {
_startOrAuthSse: (lastEventId?: string) => Promise<void>
_startOrAuthSse: (options: { lastEventId?: string }) => Promise<void>
};
await transportWithPrivateMethods._startOrAuthSse("test-event-id");
await transportWithPrivateMethods._startOrAuthSse({ lastEventId: "test-event-id" });

// Verify fetch was called with the lastEventId header
expect(fetchSpy).toHaveBeenCalled();
Expand Down Expand Up @@ -382,7 +382,7 @@ describe("StreamableHTTPClientTransport", () => {

await transport.start();

await transport["_startOrAuthSse"]();
await transport["_startOrAuthSse"]({});
expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("CustomValue");

requestInit.headers["X-Custom-Header"] = "SecondCustomValue";
Expand Down
47 changes: 32 additions & 15 deletions src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Transport } from "../shared/transport.js";
import { isJSONRPCNotification, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
import { isJSONRPCNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from "./auth.js";
import { EventSourceParserStream } from "eventsource-parser/stream";

Expand Down Expand Up @@ -28,6 +28,14 @@ export interface StartSSEOptions {
* The ID of the last received event, used for resuming a disconnected stream
*/
lastEventId?: string;
/**
* The callback function that is invoked when the last event ID changes
*/
onLastEventIdUpdate?: (event: string) => void
/**
* When reconnecting to a long-running SSE stream, we need to make sure that message id matches
*/
replayMessageId?: string | number;
}

/**
Expand Down Expand Up @@ -88,6 +96,7 @@ export type StreamableHTTPClientTransportOptions = {
* Options to configure the reconnection behavior.
*/
reconnectionOptions?: StreamableHTTPReconnectionOptions;

/**
* Session ID for the connection. This is used to identify the session on the server.
* When not provided and connecting to a server that supports session IDs, the server will generate a new session ID.
Expand Down Expand Up @@ -119,8 +128,8 @@ export class StreamableHTTPClientTransport implements Transport {
this._url = url;
this._requestInit = opts?.requestInit;
this._authProvider = opts?.authProvider;
this._reconnectionOptions = opts?.reconnectionOptions || this._defaultReconnectionOptions;
this._sessionId = opts?.sessionId;
this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS;
}

private async _authThenStart(): Promise<void> {
Expand All @@ -140,7 +149,7 @@ export class StreamableHTTPClientTransport implements Transport {
throw new UnauthorizedError();
}

return await this._startOrAuthSse();
return await this._startOrAuthSse({ lastEventId: undefined });
}

private async _commonHeaders(): Promise<Headers> {
Expand All @@ -161,7 +170,9 @@ export class StreamableHTTPClientTransport implements Transport {
);
}

private async _startOrAuthSse(lastEventId?: string): Promise<void> {

private async _startOrAuthSse(options: StartSSEOptions): Promise<void> {
const { lastEventId } = options;
try {
// Try to open an initial SSE stream with GET to listen for server messages
// This is optional according to the spec - server may not support it
Expand Down Expand Up @@ -197,7 +208,7 @@ export class StreamableHTTPClientTransport implements Transport {
);
}

this._handleSseStream(response.body);
this._handleSseStream(response.body, options);
} catch (error) {
this.onerror?.(error as Error);
throw error;
Expand Down Expand Up @@ -228,7 +239,7 @@ export class StreamableHTTPClientTransport implements Transport {
* @param lastEventId The ID of the last received event for resumability
* @param attemptCount Current reconnection attempt count for this specific stream
*/
private _scheduleReconnection(lastEventId: string, attemptCount = 0): void {
private _scheduleReconnection(options: StartSSEOptions, attemptCount = 0): void {
// Use provided options or default options
const maxRetries = this._reconnectionOptions.maxRetries;

Expand All @@ -244,18 +255,19 @@ export class StreamableHTTPClientTransport implements Transport {
// Schedule the reconnection
setTimeout(() => {
// Use the last event ID to resume where we left off
this._startOrAuthSse(lastEventId).catch(error => {
this._startOrAuthSse(options).catch(error => {
this.onerror?.(new Error(`Failed to reconnect SSE stream: ${error instanceof Error ? error.message : String(error)}`));
// Schedule another attempt if this one failed, incrementing the attempt counter
this._scheduleReconnection(lastEventId, attemptCount + 1);
this._scheduleReconnection(options, attemptCount + 1);
});
}, delay);
}

private _handleSseStream(stream: ReadableStream<Uint8Array> | null, onLastEventIdUpdate?: (event: string) => void): void {
private _handleSseStream(stream: ReadableStream<Uint8Array> | null, options: StartSSEOptions): void {
if (!stream) {
return;
}
const { onLastEventIdUpdate, replayMessageId } = options;

let lastEventId: string | undefined;
const processStream = async () => {
Expand Down Expand Up @@ -284,6 +296,9 @@ export class StreamableHTTPClientTransport implements Transport {
if (!event.event || event.event === "message") {
try {
const message = JSONRPCMessageSchema.parse(JSON.parse(event.data));
if (replayMessageId !== undefined && isJSONRPCResponse(message)) {
message.id = replayMessageId;
}
this.onmessage?.(message);
} catch (error) {
this.onerror?.(error as Error);
Expand All @@ -299,7 +314,7 @@ export class StreamableHTTPClientTransport implements Transport {
// Use the exponential backoff reconnection strategy
if (lastEventId !== undefined) {
try {
this._scheduleReconnection(lastEventId, 0);
this._scheduleReconnection(options, 0);
}
catch (error) {
this.onerror?.(new Error(`Failed to reconnect: ${error instanceof Error ? error.message : String(error)}`));
Expand Down Expand Up @@ -343,13 +358,15 @@ export class StreamableHTTPClientTransport implements Transport {
this.onclose?.();
}

async send(message: JSONRPCMessage | JSONRPCMessage[], options?: { lastEventId?: string, onLastEventIdUpdate?: (event: string) => void }): Promise<void> {
async send(message: JSONRPCMessage | JSONRPCMessage[], options?: { resumptionToken?: string, onresumptiontoken?: (event: string) => void }): Promise<void> {
try {
// If client passes in a lastEventId in the request options, we need to reconnect the SSE stream
const { lastEventId, onLastEventIdUpdate } = options ?? {};
const lastEventId = options?.resumptionToken
const onLastEventIdUpdate = options?.onresumptiontoken;
if (lastEventId) {

// If we have at last event ID, we need to reconnect the SSE stream
this._startOrAuthSse(lastEventId).catch(err => this.onerror?.(err));
this._startOrAuthSse({ lastEventId, replayMessageId: isJSONRPCRequest(message) ? message.id : undefined }).catch(err => this.onerror?.(err));
return;
}

Expand Down Expand Up @@ -396,7 +413,7 @@ export class StreamableHTTPClientTransport implements Transport {
// if it's supported by the server
if (isJSONRPCNotification(message) && message.method === "notifications/initialized") {
// Start without a lastEventId since this is a fresh connection
this._startOrAuthSse().catch(err => this.onerror?.(err));
this._startOrAuthSse({ lastEventId: undefined }).catch(err => this.onerror?.(err));
}
return;
}
Expand All @@ -414,7 +431,7 @@ export class StreamableHTTPClientTransport implements Transport {
// Handle SSE stream responses for requests
// We use the same handler as standalone streams, which now supports
// reconnection with the last event ID
this._handleSseStream(response.body, onLastEventIdUpdate);
this._handleSseStream(response.body, { onLastEventIdUpdate });
} else if (contentType?.includes("application/json")) {
// For non-streaming servers, we might get direct JSON responses
const data = await response.json();
Expand Down
4 changes: 2 additions & 2 deletions src/examples/client/simpleStreamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ function commandLoop(): void {

case 'start-notifications': {
const interval = args[1] ? parseInt(args[1], 10) : 2000;
const count = args[2] ? parseInt(args[2], 10) : 0;
const count = args[2] ? parseInt(args[2], 10) : 10;
await startNotifications(interval, count);
break;
}
Expand Down Expand Up @@ -302,7 +302,7 @@ async function callTool(name: string, args: Record<string, unknown>): Promise<vo
notificationsToolLastEventId = event;
};
const result = await client.request(request, CallToolResultSchema, {
lastEventId: notificationsToolLastEventId, onLastEventIdUpdate
resumptionToken: notificationsToolLastEventId, onresumptiontoken: onLastEventIdUpdate
});

console.log('Tool result:');
Expand Down
56 changes: 30 additions & 26 deletions src/integration-tests/taskResumability.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ import { z } from 'zod';
class InMemoryEventStore implements EventStore {
private events: Map<string, { streamId: string, message: JSONRPCMessage }> = new Map();

generateEventId(streamId: string): string {
private generateEventId(streamId: string): string {
return `${streamId}_${Date.now()}_${Math.random().toString(36).substring(2, 10)}`;
}

getStreamIdFromEventId(eventId: string): string {
private getStreamIdFromEventId(eventId: string): string {
const parts = eventId.split('_');
return parts.length > 0 ? parts[0] : '';
}
Expand All @@ -29,14 +29,18 @@ class InMemoryEventStore implements EventStore {
return eventId;
}

async getEventsAfter(lastEventId: string): Promise<Array<{ eventId: string, message: JSONRPCMessage }>> {
async replayEventsAfter(lastEventId: string,
{ send }: { send: (eventId: string, message: JSONRPCMessage) => Promise<void> }
): Promise<string> {
if (!lastEventId || !this.events.has(lastEventId)) {
return [];
return '';
}

// Extract the stream ID from the event ID
const streamId = this.getStreamIdFromEventId(lastEventId);
const result: Array<{ eventId: string, message: JSONRPCMessage }> = [];
if (!streamId) {
return '';
}
let foundLastEvent = false;

// Sort events by eventId for chronological ordering
Expand All @@ -55,11 +59,11 @@ class InMemoryEventStore implements EventStore {
}

if (foundLastEvent) {
result.push({ eventId, message });
await send(eventId, message);
}
}

return result;
return streamId;
}
}

Expand Down Expand Up @@ -238,43 +242,46 @@ describe('Transport resumability', () => {
params: {
name: 'run-notifications',
arguments: {
count: 5,
count: 3,
interval: 10
}
}
}, CallToolResultSchema, {
lastEventId,
onLastEventIdUpdate
resumptionToken: lastEventId,
onresumptiontoken: onLastEventIdUpdate
});

// Wait for some notifications to arrive (not all)
// Wait for some notifications to arrive (not all) - shorter wait time
await new Promise(resolve => setTimeout(resolve, 20));

// Verify we received some notifications and lastEventId was updated
expect(notifications.length).toBeGreaterThan(0);
expect(notifications.length).toBeLessThan(5);
expect(notifications.length).toBeLessThan(4);
expect(onLastEventIdUpdate).toHaveBeenCalled();
expect(lastEventId).toBeDefined();

// Store original notification count for later comparison
const firstClientNotificationCount = notifications.length;

// Disconnect first client without waiting for completion
// When we close the connection, it will cause a ConnectionClosed error for
// any in-progress requests, which is expected behavior
// We need to catch the error since closing the transport will
// cause the pending toolPromise to reject with a ConnectionClosed error
await transport1.close();

// Try to cancel the promise, but ignore errors since it's already being handled
toolPromise.catch(err => {
// Save the promise so we can catch it after closing
const catchPromise = toolPromise.catch(err => {
// This error is expected - the connection was intentionally closed
if (err?.code !== -32000) { // ConnectionClosed error code
console.error("Unexpected error type during transport close:", err);
}
});



// Add a short delay to ensure clean disconnect before reconnecting
await new Promise(resolve => setTimeout(resolve, 10));

// Wait for the rejection to be handled
await catchPromise;


// Create second client with same client ID
const client2 = new Client({
id: clientId,
Expand Down Expand Up @@ -303,22 +310,19 @@ describe('Transport resumability', () => {
name: 'run-notifications',
arguments: {
count: 1,
interval: 50
interval: 5
}
}
}, CallToolResultSchema, {
lastEventId, // Pass the lastEventId from the previous session
onLastEventIdUpdate
resumptionToken: lastEventId, // Pass the lastEventId from the previous session
onresumptiontoken: onLastEventIdUpdate
});

// Verify we eventually received at leaset a few motifications
expect(notifications.length).toBeGreaterThan(2);
expect(notifications.length).toBeGreaterThan(1);

// Verify the second client received notifications that the first client didn't
expect(notifications.length).toBeGreaterThan(firstClientNotificationCount);

// Clean up

await transport2.close();

});
Expand Down
5 changes: 2 additions & 3 deletions src/server/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,10 @@ export class StreamableHTTPServerTransport implements Transport {
const isInitializationRequest = messages.some(
msg => 'method' in msg && msg.method === 'initialize'
);
const mcpSessionId = req.headers["mcp-session-id"] as string | undefined;
if (isInitializationRequest) {
// If it's a server with session management and the session ID is already set we should reject the request
// to avoid re-initialization.
if (this._initialized && this.sessionId !== undefined && mcpSessionId !== this.sessionId) {
if (this._initialized) {
res.writeHead(400).end(JSON.stringify({
jsonrpc: "2.0",
error: {
Expand All @@ -357,7 +356,7 @@ export class StreamableHTTPServerTransport implements Transport {
}));
return;
}
this.sessionId = mcpSessionId ?? this.sessionIdGenerator();
this.sessionId = this.sessionIdGenerator();
this._initialized = true;

}
Expand Down
Loading