Skip to content

Commit 881f2b8

Browse files
authored
Merge branch 'main' into fix-normalised-elitication-content
2 parents d597cbe + 5ceabfb commit 881f2b8

File tree

4 files changed

+241
-25
lines changed

4 files changed

+241
-25
lines changed

src/client/sse.test.ts

Lines changed: 78 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -308,14 +308,52 @@ describe('SSEClientTransport', () => {
308308

309309
await transport.start();
310310

311-
// Store original fetch
312311
const originalFetch = global.fetch;
312+
try {
313+
global.fetch = vi.fn().mockResolvedValue({ ok: true });
314+
315+
const message: JSONRPCMessage = {
316+
jsonrpc: '2.0',
317+
id: '1',
318+
method: 'test',
319+
params: {}
320+
};
321+
322+
await transport.send(message);
323+
324+
const calledHeaders = (global.fetch as Mock).mock.calls[0][1].headers;
325+
expect(calledHeaders.get('Authorization')).toBe('Bearer test-token');
326+
expect(calledHeaders.get('X-Custom-Header')).toBe('custom-value');
327+
expect(calledHeaders.get('content-type')).toBe('application/json');
328+
329+
customHeaders['X-Custom-Header'] = 'updated-value';
330+
331+
await transport.send(message);
332+
333+
const updatedHeaders = (global.fetch as Mock).mock.calls[1][1].headers;
334+
expect(updatedHeaders.get('X-Custom-Header')).toBe('updated-value');
335+
} finally {
336+
global.fetch = originalFetch;
337+
}
338+
});
339+
340+
it('passes custom headers to fetch requests (Headers class)', async () => {
341+
const customHeaders = new Headers({
342+
Authorization: 'Bearer test-token',
343+
'X-Custom-Header': 'custom-value'
344+
});
345+
346+
transport = new SSEClientTransport(resourceBaseUrl, {
347+
requestInit: {
348+
headers: customHeaders
349+
}
350+
});
351+
352+
await transport.start();
313353

354+
const originalFetch = global.fetch;
314355
try {
315-
// Mock fetch for the message sending test
316-
global.fetch = vi.fn().mockResolvedValue({
317-
ok: true
318-
});
356+
global.fetch = vi.fn().mockResolvedValue({ ok: true });
319357

320358
const message: JSONRPCMessage = {
321359
jsonrpc: '2.0',
@@ -326,20 +364,45 @@ describe('SSEClientTransport', () => {
326364

327365
await transport.send(message);
328366

329-
// Verify fetch was called with correct headers
330-
expect(global.fetch).toHaveBeenCalledWith(
331-
expect.any(URL),
332-
expect.objectContaining({
333-
headers: expect.any(Headers)
334-
})
335-
);
367+
const calledHeaders = (global.fetch as Mock).mock.calls[0][1].headers;
368+
expect(calledHeaders.get('Authorization')).toBe('Bearer test-token');
369+
expect(calledHeaders.get('X-Custom-Header')).toBe('custom-value');
370+
expect(calledHeaders.get('content-type')).toBe('application/json');
371+
372+
customHeaders.set('X-Custom-Header', 'updated-value');
373+
374+
await transport.send(message);
375+
376+
const updatedHeaders = (global.fetch as Mock).mock.calls[1][1].headers;
377+
expect(updatedHeaders.get('X-Custom-Header')).toBe('updated-value');
378+
} finally {
379+
global.fetch = originalFetch;
380+
}
381+
});
382+
383+
it('passes custom headers to fetch requests (array of tuples)', async () => {
384+
transport = new SSEClientTransport(resourceBaseUrl, {
385+
requestInit: {
386+
headers: [
387+
['Authorization', 'Bearer test-token'],
388+
['X-Custom-Header', 'custom-value']
389+
]
390+
}
391+
});
392+
393+
await transport.start();
394+
395+
const originalFetch = global.fetch;
396+
try {
397+
global.fetch = vi.fn().mockResolvedValue({ ok: true });
398+
399+
await transport.send({ jsonrpc: '2.0', id: '1', method: 'test', params: {} });
336400

337401
const calledHeaders = (global.fetch as Mock).mock.calls[0][1].headers;
338-
expect(calledHeaders.get('Authorization')).toBe(customHeaders.Authorization);
339-
expect(calledHeaders.get('X-Custom-Header')).toBe(customHeaders['X-Custom-Header']);
402+
expect(calledHeaders.get('Authorization')).toBe('Bearer test-token');
403+
expect(calledHeaders.get('X-Custom-Header')).toBe('custom-value');
340404
expect(calledHeaders.get('content-type')).toBe('application/json');
341405
} finally {
342-
// Restore original fetch
343406
global.fetch = originalFetch;
344407
}
345408
});

src/client/sse.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { EventSource, type ErrorEvent, type EventSourceInit } from 'eventsource';
2-
import { Transport, FetchLike, createFetchWithInit } from '../shared/transport.js';
2+
import { Transport, FetchLike, createFetchWithInit, normalizeHeaders } from '../shared/transport.js';
33
import { JSONRPCMessage, JSONRPCMessageSchema } from '../types.js';
44
import { auth, AuthResult, extractWWWAuthenticateParams, OAuthClientProvider, UnauthorizedError } from './auth.js';
55

@@ -114,7 +114,7 @@ export class SSEClientTransport implements Transport {
114114
}
115115

116116
private async _commonHeaders(): Promise<Headers> {
117-
const headers: HeadersInit = {};
117+
const headers: HeadersInit & Record<string, string> = {};
118118
if (this._authProvider) {
119119
const tokens = await this._authProvider.tokens();
120120
if (tokens) {
@@ -125,7 +125,12 @@ export class SSEClientTransport implements Transport {
125125
headers['mcp-protocol-version'] = this._protocolVersion;
126126
}
127127

128-
return new Headers({ ...headers, ...this._requestInit?.headers });
128+
const extraHeaders = normalizeHeaders(this._requestInit?.headers);
129+
130+
return new Headers({
131+
...headers,
132+
...extraHeaders
133+
});
129134
}
130135

131136
private _startOrAuth(): Promise<void> {

src/client/streamableHttp.test.ts

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,7 @@ describe('StreamableHTTPClientTransport', () => {
480480
it('should always send specified custom headers', async () => {
481481
const requestInit = {
482482
headers: {
483+
Authorization: 'Bearer test-token',
483484
'X-Custom-Header': 'CustomValue'
484485
}
485486
};
@@ -497,6 +498,7 @@ describe('StreamableHTTPClientTransport', () => {
497498
await transport.start();
498499

499500
await transport['_startOrAuthSse']({});
501+
expect((actualReqInit.headers as Headers).get('authorization')).toBe('Bearer test-token');
500502
expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('CustomValue');
501503

502504
requestInit.headers['X-Custom-Header'] = 'SecondCustomValue';
@@ -510,6 +512,7 @@ describe('StreamableHTTPClientTransport', () => {
510512
it('should always send specified custom headers (Headers class)', async () => {
511513
const requestInit = {
512514
headers: new Headers({
515+
Authorization: 'Bearer test-token',
513516
'X-Custom-Header': 'CustomValue'
514517
})
515518
};
@@ -527,6 +530,7 @@ describe('StreamableHTTPClientTransport', () => {
527530
await transport.start();
528531

529532
await transport['_startOrAuthSse']({});
533+
expect((actualReqInit.headers as Headers).get('authorization')).toBe('Bearer test-token');
530534
expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('CustomValue');
531535

532536
(requestInit.headers as Headers).set('X-Custom-Header', 'SecondCustomValue');
@@ -537,6 +541,30 @@ describe('StreamableHTTPClientTransport', () => {
537541
expect(global.fetch).toHaveBeenCalledTimes(2);
538542
});
539543

544+
it('should always send specified custom headers (array of tuples)', async () => {
545+
transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), {
546+
requestInit: {
547+
headers: [
548+
['Authorization', 'Bearer test-token'],
549+
['X-Custom-Header', 'CustomValue']
550+
]
551+
}
552+
});
553+
554+
let actualReqInit: RequestInit = {};
555+
556+
(global.fetch as Mock).mockImplementation(async (_url, reqInit) => {
557+
actualReqInit = reqInit;
558+
return new Response(null, { status: 200, headers: { 'content-type': 'text/event-stream' } });
559+
});
560+
561+
await transport.start();
562+
563+
await transport['_startOrAuthSse']({});
564+
expect((actualReqInit.headers as Headers).get('authorization')).toBe('Bearer test-token');
565+
expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('CustomValue');
566+
});
567+
540568
it('should have exponential backoff with configurable maxRetries', () => {
541569
// This test verifies the maxRetries and backoff calculation directly
542570

@@ -866,6 +894,112 @@ describe('StreamableHTTPClientTransport', () => {
866894
expect(reconnectHeaders.get('last-event-id')).toBe('event-123');
867895
});
868896

897+
it('should NOT reconnect a POST stream when response was received', async () => {
898+
// ARRANGE
899+
transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), {
900+
reconnectionOptions: {
901+
initialReconnectionDelay: 10,
902+
maxRetries: 1,
903+
maxReconnectionDelay: 1000,
904+
reconnectionDelayGrowFactor: 1
905+
}
906+
});
907+
908+
// Create a stream that sends:
909+
// 1. Priming event with ID (enables potential reconnection)
910+
// 2. The actual response (should prevent reconnection)
911+
// 3. Then closes
912+
const streamWithResponse = new ReadableStream({
913+
start(controller) {
914+
// Priming event with ID
915+
controller.enqueue(new TextEncoder().encode('id: priming-123\ndata: \n\n'));
916+
// The actual response to the request
917+
controller.enqueue(
918+
new TextEncoder().encode('id: response-456\ndata: {"jsonrpc":"2.0","result":{"tools":[]},"id":"request-1"}\n\n')
919+
);
920+
// Stream closes normally
921+
controller.close();
922+
}
923+
});
924+
925+
const fetchMock = global.fetch as Mock;
926+
fetchMock.mockResolvedValueOnce({
927+
ok: true,
928+
status: 200,
929+
headers: new Headers({ 'content-type': 'text/event-stream' }),
930+
body: streamWithResponse
931+
});
932+
933+
const requestMessage: JSONRPCRequest = {
934+
jsonrpc: '2.0',
935+
method: 'tools/list',
936+
id: 'request-1',
937+
params: {}
938+
};
939+
940+
// ACT
941+
await transport.start();
942+
await transport.send(requestMessage);
943+
await vi.advanceTimersByTimeAsync(50);
944+
945+
// ASSERT
946+
// THE KEY ASSERTION: Fetch was called ONCE only - no reconnection!
947+
// The response was received, so no need to reconnect.
948+
expect(fetchMock).toHaveBeenCalledTimes(1);
949+
expect(fetchMock.mock.calls[0][1]?.method).toBe('POST');
950+
});
951+
952+
it('should not attempt reconnection after close() is called', async () => {
953+
// ARRANGE
954+
transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), {
955+
reconnectionOptions: {
956+
initialReconnectionDelay: 100,
957+
maxRetries: 3,
958+
maxReconnectionDelay: 1000,
959+
reconnectionDelayGrowFactor: 1
960+
}
961+
});
962+
963+
// Stream with priming event + notification (no response) that closes
964+
// This triggers reconnection scheduling
965+
const streamWithPriming = new ReadableStream({
966+
start(controller) {
967+
controller.enqueue(
968+
new TextEncoder().encode('id: event-123\ndata: {"jsonrpc":"2.0","method":"notifications/test","params":{}}\n\n')
969+
);
970+
controller.close();
971+
}
972+
});
973+
974+
const fetchMock = global.fetch as Mock;
975+
976+
// POST request returns streaming response
977+
fetchMock.mockResolvedValueOnce({
978+
ok: true,
979+
status: 200,
980+
headers: new Headers({ 'content-type': 'text/event-stream' }),
981+
body: streamWithPriming
982+
});
983+
984+
// ACT
985+
await transport.start();
986+
await transport.send({ jsonrpc: '2.0', method: 'test', id: '1', params: {} });
987+
988+
// Wait a tick to let stream processing complete and schedule reconnection
989+
await vi.advanceTimersByTimeAsync(10);
990+
991+
// Now close() - reconnection timeout is pending (scheduled for 100ms)
992+
await transport.close();
993+
994+
// Advance past reconnection delay
995+
await vi.advanceTimersByTimeAsync(200);
996+
997+
// ASSERT
998+
// Only 1 call: the initial POST. No reconnection attempts after close().
999+
expect(fetchMock).toHaveBeenCalledTimes(1);
1000+
expect(fetchMock.mock.calls[0][1]?.method).toBe('POST');
1001+
});
1002+
8691003
it('should not throw JSON parse error on priming events with empty data', async () => {
8701004
transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'));
8711005

src/client/streamableHttp.ts

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ export class StreamableHTTPClientTransport implements Transport {
136136
private _hasCompletedAuthFlow = false; // Circuit breaker: detect auth success followed by immediate 401
137137
private _lastUpscopingHeader?: string; // Track last upscoping header to prevent infinite upscoping.
138138
private _serverRetryMs?: number; // Server-provided retry delay from SSE retry field
139+
private _reconnectionTimeout?: ReturnType<typeof setTimeout>;
139140

140141
onclose?: () => void;
141142
onerror?: (error: Error) => void;
@@ -287,7 +288,7 @@ export class StreamableHTTPClientTransport implements Transport {
287288
const delay = this._getNextReconnectionDelay(attemptCount);
288289

289290
// Schedule the reconnection
290-
setTimeout(() => {
291+
this._reconnectionTimeout = setTimeout(() => {
291292
// Use the last event ID to resume where we left off
292293
this._startOrAuthSse(options).catch(error => {
293294
this.onerror?.(new Error(`Failed to reconnect SSE stream: ${error instanceof Error ? error.message : String(error)}`));
@@ -307,6 +308,9 @@ export class StreamableHTTPClientTransport implements Transport {
307308
// Track whether we've received a priming event (event with ID)
308309
// Per spec, server SHOULD send a priming event with ID before closing
309310
let hasPrimingEvent = false;
311+
// Track whether we've received a response - if so, no need to reconnect
312+
// Reconnection is for when server disconnects BEFORE sending response
313+
let receivedResponse = false;
310314
const processStream = async () => {
311315
// this is the closest we can get to trying to catch network errors
312316
// if something happens reader will throw
@@ -346,8 +350,12 @@ export class StreamableHTTPClientTransport implements Transport {
346350
if (!event.event || event.event === 'message') {
347351
try {
348352
const message = JSONRPCMessageSchema.parse(JSON.parse(event.data));
349-
if (replayMessageId !== undefined && isJSONRPCResponse(message)) {
350-
message.id = replayMessageId;
353+
if (isJSONRPCResponse(message)) {
354+
// Mark that we received a response - no need to reconnect for this request
355+
receivedResponse = true;
356+
if (replayMessageId !== undefined) {
357+
message.id = replayMessageId;
358+
}
351359
}
352360
this.onmessage?.(message);
353361
} catch (error) {
@@ -359,8 +367,10 @@ export class StreamableHTTPClientTransport implements Transport {
359367
// Handle graceful server-side disconnect
360368
// Server may close connection after sending event ID and retry field
361369
// Reconnect if: already reconnectable (GET stream) OR received a priming event (POST stream with event ID)
370+
// BUT don't reconnect if we already received a response - the request is complete
362371
const canResume = isReconnectable || hasPrimingEvent;
363-
if (canResume && this._abortController && !this._abortController.signal.aborted) {
372+
const needsReconnect = canResume && !receivedResponse;
373+
if (needsReconnect && this._abortController && !this._abortController.signal.aborted) {
364374
this._scheduleReconnection(
365375
{
366376
resumptionToken: lastEventId,
@@ -376,8 +386,10 @@ export class StreamableHTTPClientTransport implements Transport {
376386

377387
// Attempt to reconnect if the stream disconnects unexpectedly and we aren't closing
378388
// Reconnect if: already reconnectable (GET stream) OR received a priming event (POST stream with event ID)
389+
// BUT don't reconnect if we already received a response - the request is complete
379390
const canResume = isReconnectable || hasPrimingEvent;
380-
if (canResume && this._abortController && !this._abortController.signal.aborted) {
391+
const needsReconnect = canResume && !receivedResponse;
392+
if (needsReconnect && this._abortController && !this._abortController.signal.aborted) {
381393
// Use the exponential backoff reconnection strategy
382394
try {
383395
this._scheduleReconnection(
@@ -428,9 +440,11 @@ export class StreamableHTTPClientTransport implements Transport {
428440
}
429441

430442
async close(): Promise<void> {
431-
// Abort any pending requests
443+
if (this._reconnectionTimeout) {
444+
clearTimeout(this._reconnectionTimeout);
445+
this._reconnectionTimeout = undefined;
446+
}
432447
this._abortController?.abort();
433-
434448
this.onclose?.();
435449
}
436450

0 commit comments

Comments
 (0)