Skip to content
Merged
10 changes: 5 additions & 5 deletions src/client/streamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ describe("StreamableHTTPClientTransport", () => {
// We expect the 405 error to be caught and handled gracefully
// This should not throw an error that breaks the transport
await transport.start();
await expect(transport["_startOrAuthStandaloneSSE"]()).resolves.not.toThrow("Failed to open SSE stream: Method Not Allowed");
await expect(transport["_startOrAuthSse"]()).resolves.not.toThrow("Failed to open SSE stream: Method Not Allowed");
// Check that GET was attempted
expect(global.fetch).toHaveBeenCalledWith(
expect.anything(),
Expand Down Expand Up @@ -208,7 +208,7 @@ describe("StreamableHTTPClientTransport", () => {
transport.onmessage = messageSpy;

await transport.start();
await transport["_startOrAuthStandaloneSSE"]();
await transport["_startOrAuthSse"]();

// Give time for the SSE event to be processed
await new Promise(resolve => setTimeout(resolve, 50));
Expand Down Expand Up @@ -313,9 +313,9 @@ describe("StreamableHTTPClientTransport", () => {
await transport.start();
// Type assertion to access private method
const transportWithPrivateMethods = transport as unknown as {
_startOrAuthStandaloneSSE: (lastEventId?: string) => Promise<void>
_startOrAuthSse: (lastEventId?: string) => Promise<void>
};
await transportWithPrivateMethods._startOrAuthStandaloneSSE("test-event-id");
await transportWithPrivateMethods._startOrAuthSse("test-event-id");

// Verify fetch was called with the lastEventId header
expect(fetchSpy).toHaveBeenCalled();
Expand Down Expand Up @@ -382,7 +382,7 @@ describe("StreamableHTTPClientTransport", () => {

await transport.start();

await transport["_startOrAuthStandaloneSSE"]();
await transport["_startOrAuthSse"]();
expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("CustomValue");

requestInit.headers["X-Custom-Header"] = "SecondCustomValue";
Expand Down
33 changes: 26 additions & 7 deletions src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ export type StreamableHTTPClientTransportOptions = {
* Options to configure the reconnection behavior.
*/
reconnectionOptions?: StreamableHTTPReconnectionOptions;
/**
* Session ID for the connection. This is used to identify the session on the server.
* When not provided and connecting to a server that supports session IDs, the server will generate a new session ID.
*/
sessionId?: string;
};

/**
Expand Down Expand Up @@ -98,6 +103,7 @@ export class StreamableHTTPClientTransport implements Transport {
this._requestInit = opts?.requestInit;
this._authProvider = opts?.authProvider;
this._reconnectionOptions = opts?.reconnectionOptions || this._defaultReconnectionOptions;
this._sessionId = opts?.sessionId;
}

private async _authThenStart(): Promise<void> {
Expand All @@ -117,7 +123,7 @@ export class StreamableHTTPClientTransport implements Transport {
throw new UnauthorizedError();
}

return await this._startOrAuthStandaloneSSE();
return await this._startOrAuthSse();
}

private async _commonHeaders(): Promise<Headers> {
Expand All @@ -138,7 +144,7 @@ export class StreamableHTTPClientTransport implements Transport {
);
}

private async _startOrAuthStandaloneSSE(lastEventId?: string): Promise<void> {
private async _startOrAuthSse(lastEventId?: string): Promise<void> {
try {
// Try to open an initial SSE stream with GET to listen for server messages
// This is optional according to the spec - server may not support it
Expand Down Expand Up @@ -232,15 +238,15 @@ export class StreamableHTTPClientTransport implements Transport {
// Schedule the reconnection
setTimeout(() => {
// Use the last event ID to resume where we left off
this._startOrAuthStandaloneSSE(lastEventId).catch(error => {
this._startOrAuthSse(lastEventId).catch(error => {
this.onerror?.(new Error(`Failed to reconnect SSE stream: ${error instanceof Error ? error.message : String(error)}`));
// Schedule another attempt if this one failed, incrementing the attempt counter
this._scheduleReconnection(lastEventId, attemptCount + 1);
});
}, delay);
}

private _handleSseStream(stream: ReadableStream<Uint8Array> | null): void {
private _handleSseStream(stream: ReadableStream<Uint8Array> | null, onLastEventIdUpdate?: (event: string) => void): void {
if (!stream) {
return;
}
Expand All @@ -266,6 +272,7 @@ export class StreamableHTTPClientTransport implements Transport {
// Update last event ID if provided
if (event.id) {
lastEventId = event.id;
onLastEventIdUpdate?.(lastEventId);
}

if (!event.event || event.event === "message") {
Expand Down Expand Up @@ -330,8 +337,16 @@ export class StreamableHTTPClientTransport implements Transport {
this.onclose?.();
}

async send(message: JSONRPCMessage | JSONRPCMessage[]): Promise<void> {
async send(message: JSONRPCMessage | JSONRPCMessage[], options?: { lastEventId?: string, onLastEventIdUpdate?: (event: string) => void }): Promise<void> {
try {
// If client passes in a lastEventId in the request options, we need to reconnect the SSE stream
const { lastEventId, onLastEventIdUpdate } = options ?? {};
if (lastEventId) {
// If we have at last event ID, we need to reconnect the SSE stream
this._startOrAuthSse(lastEventId).catch(err => this.onerror?.(err));
return;
}

const headers = await this._commonHeaders();
headers.set("content-type", "application/json");
headers.set("accept", "application/json, text/event-stream");
Expand Down Expand Up @@ -375,7 +390,7 @@ export class StreamableHTTPClientTransport implements Transport {
// if it's supported by the server
if (isJSONRPCNotification(message) && message.method === "notifications/initialized") {
// Start without a lastEventId since this is a fresh connection
this._startOrAuthStandaloneSSE().catch(err => this.onerror?.(err));
this._startOrAuthSse().catch(err => this.onerror?.(err));
}
return;
}
Expand All @@ -393,7 +408,7 @@ export class StreamableHTTPClientTransport implements Transport {
// Handle SSE stream responses for requests
// We use the same handler as standalone streams, which now supports
// reconnection with the last event ID
this._handleSseStream(response.body);
this._handleSseStream(response.body, onLastEventIdUpdate);
} else if (contentType?.includes("application/json")) {
// For non-streaming servers, we might get direct JSON responses
const data = await response.json();
Expand All @@ -416,4 +431,8 @@ export class StreamableHTTPClientTransport implements Transport {
throw error;
}
}

get sessionId(): string | undefined {
return this._sessionId;
}
}
16 changes: 14 additions & 2 deletions src/examples/client/simpleStreamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ let notificationCount = 0;
let client: Client | null = null;
let transport: StreamableHTTPClientTransport | null = null;
let serverUrl = 'http://localhost:3000/mcp';
let notificationsToolLastEventId: string | undefined = undefined;
let sessionId: string | undefined = undefined;

async function main(): Promise<void> {
console.log('MCP Interactive Client');
Expand Down Expand Up @@ -186,7 +188,10 @@ async function connect(url?: string): Promise<void> {
}

transport = new StreamableHTTPClientTransport(
new URL(serverUrl)
new URL(serverUrl),
{
sessionId: sessionId
}
);

// Set up notification handlers
Expand Down Expand Up @@ -218,6 +223,8 @@ async function connect(url?: string): Promise<void> {

// Connect the client
await client.connect(transport);
sessionId = transport.sessionId
console.log('Transport created with session ID:', sessionId);
console.log('Connected to MCP server');
} catch (error) {
console.error('Failed to connect:', error);
Expand Down Expand Up @@ -291,7 +298,12 @@ async function callTool(name: string, args: Record<string, unknown>): Promise<vo
};

console.log(`Calling tool '${name}' with args:`, args);
const result = await client.request(request, CallToolResultSchema);
const onLastEventIdUpdate = (event: string) => {
notificationsToolLastEventId = event;
};
const result = await client.request(request, CallToolResultSchema, {
lastEventId: notificationsToolLastEventId, onLastEventIdUpdate
});

console.log('Tool result:');
result.content.forEach(item => {
Expand Down
20 changes: 12 additions & 8 deletions src/examples/server/simpleStreamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,18 @@ server.tool(

while (count === 0 || counter < count) {
counter++;
await sendNotification({
method: "notifications/message",
params: {
level: "info",
data: `Periodic notification #${counter} at ${new Date().toISOString()}`
}
});

try {
await sendNotification({
method: "notifications/message",
params: {
level: "info",
data: `Periodic notification #${counter} at ${new Date().toISOString()}`
}
});
}
catch (error) {
console.error("Error sending notification:", error);
}
// Wait for the specified interval
await sleep(interval);
}
Expand Down
Loading