@@ -98,7 +98,7 @@ describe("StreamableHTTPServerTransport", () => {
9898
9999 await transport . handleRequest ( req , mockResponse ) ;
100100
101- expect ( mockResponse . writeHead ) . toHaveBeenCalledWith ( 404 ) ;
101+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith ( 404 , { } ) ;
102102 // check if the error response is a valid JSON-RPC error format
103103 expect ( mockResponse . end ) . toHaveBeenCalledWith ( expect . stringContaining ( '"jsonrpc":"2.0"' ) ) ;
104104 expect ( mockResponse . end ) . toHaveBeenCalledWith ( expect . stringContaining ( '"error"' ) ) ;
@@ -115,7 +115,7 @@ describe("StreamableHTTPServerTransport", () => {
115115
116116 await transport . handleRequest ( req , mockResponse ) ;
117117
118- expect ( mockResponse . writeHead ) . toHaveBeenCalledWith ( 400 ) ;
118+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith ( 400 , { } ) ;
119119 expect ( mockResponse . end ) . toHaveBeenCalledWith ( expect . stringContaining ( '"jsonrpc":"2.0"' ) ) ;
120120 expect ( mockResponse . end ) . toHaveBeenCalledWith ( expect . stringContaining ( '"message":"Bad Request: Mcp-Session-Id header is required"' ) ) ;
121121 } ) ;
@@ -342,7 +342,7 @@ describe("StreamableHTTPServerTransport", () => {
342342
343343 await transport . handleRequest ( req , mockResponse ) ;
344344
345- expect ( mockResponse . writeHead ) . toHaveBeenCalledWith ( 406 ) ;
345+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith ( 406 , { } ) ;
346346 expect ( mockResponse . end ) . toHaveBeenCalledWith ( expect . stringContaining ( '"jsonrpc":"2.0"' ) ) ;
347347 } ) ;
348348
@@ -788,4 +788,141 @@ describe("StreamableHTTPServerTransport", () => {
788788 expect ( onMessageMock ) . not . toHaveBeenCalledWith ( requestBodyMessage ) ;
789789 } ) ;
790790 } ) ;
791+
792+ describe ( "Custom Headers" , ( ) => {
793+ const customHeaders = {
794+ "X-Custom-Header" : "custom-value" ,
795+ "X-API-Version" : "1.0" ,
796+ "Access-Control-Allow-Origin" : "*"
797+ } ;
798+
799+ let transportWithHeaders : StreamableHTTPServerTransport ;
800+ let mockResponse : jest . Mocked < ServerResponse > ;
801+
802+ beforeEach ( ( ) => {
803+ transportWithHeaders = new StreamableHTTPServerTransport ( endpoint , { customHeaders } ) ;
804+ mockResponse = createMockResponse ( ) ;
805+ } ) ;
806+
807+ it ( "should include custom headers in SSE response" , async ( ) => {
808+ const req = createMockRequest ( {
809+ method : "GET" ,
810+ headers : {
811+ accept : "text/event-stream" ,
812+ "mcp-session-id" : transportWithHeaders . sessionId
813+ } ,
814+ } ) ;
815+
816+ await transportWithHeaders . handleRequest ( req , mockResponse ) ;
817+
818+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith (
819+ 200 ,
820+ expect . objectContaining ( {
821+ ...customHeaders ,
822+ "Content-Type" : "text/event-stream" ,
823+ "Cache-Control" : "no-cache" ,
824+ "Connection" : "keep-alive" ,
825+ "mcp-session-id" : transportWithHeaders . sessionId
826+ } )
827+ ) ;
828+ } ) ;
829+
830+ it ( "should include custom headers in JSON response" , async ( ) => {
831+ const message : JSONRPCMessage = {
832+ jsonrpc : "2.0" ,
833+ method : "test" ,
834+ params : { } ,
835+ id : 1 ,
836+ } ;
837+
838+ const req = createMockRequest ( {
839+ method : "POST" ,
840+ headers : {
841+ "content-type" : "application/json" ,
842+ "accept" : "application/json" ,
843+ "mcp-session-id" : transportWithHeaders . sessionId
844+ } ,
845+ body : JSON . stringify ( message ) ,
846+ } ) ;
847+
848+ await transportWithHeaders . handleRequest ( req , mockResponse ) ;
849+
850+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith (
851+ 200 ,
852+ expect . objectContaining ( {
853+ ...customHeaders ,
854+ "Content-Type" : "application/json" ,
855+ "mcp-session-id" : transportWithHeaders . sessionId
856+ } )
857+ ) ;
858+ } ) ;
859+
860+ it ( "should include custom headers in error responses" , async ( ) => {
861+ const req = createMockRequest ( {
862+ method : "GET" ,
863+ headers : {
864+ accept : "text/event-stream" ,
865+ "mcp-session-id" : "invalid-session-id"
866+ } ,
867+ } ) ;
868+
869+ await transportWithHeaders . handleRequest ( req , mockResponse ) ;
870+
871+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith (
872+ 404 ,
873+ expect . objectContaining ( customHeaders )
874+ ) ;
875+ } ) ;
876+
877+ it ( "should not override essential headers with custom headers" , async ( ) => {
878+ const transportWithConflictingHeaders = new StreamableHTTPServerTransport ( endpoint , {
879+ customHeaders : {
880+ "Content-Type" : "text/plain" , // 尝试覆盖必要的 Content-Type 头
881+ "X-Custom-Header" : "custom-value"
882+ }
883+ } ) ;
884+
885+ const req = createMockRequest ( {
886+ method : "GET" ,
887+ headers : {
888+ accept : "text/event-stream" ,
889+ "mcp-session-id" : transportWithConflictingHeaders . sessionId
890+ } ,
891+ } ) ;
892+
893+ await transportWithConflictingHeaders . handleRequest ( req , mockResponse ) ;
894+
895+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith (
896+ 200 ,
897+ expect . objectContaining ( {
898+ "Content-Type" : "text/event-stream" , // 应该保持原有的 Content-Type
899+ "X-Custom-Header" : "custom-value"
900+ } )
901+ ) ;
902+ } ) ;
903+
904+ it ( "should work with empty custom headers" , async ( ) => {
905+ const transportWithoutHeaders = new StreamableHTTPServerTransport ( endpoint ) ;
906+
907+ const req = createMockRequest ( {
908+ method : "GET" ,
909+ headers : {
910+ accept : "text/event-stream" ,
911+ "mcp-session-id" : transportWithoutHeaders . sessionId
912+ } ,
913+ } ) ;
914+
915+ await transportWithoutHeaders . handleRequest ( req , mockResponse ) ;
916+
917+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith (
918+ 200 ,
919+ expect . objectContaining ( {
920+ "Content-Type" : "text/event-stream" ,
921+ "Cache-Control" : "no-cache" ,
922+ "Connection" : "keep-alive" ,
923+ "mcp-session-id" : transportWithoutHeaders . sessionId
924+ } )
925+ ) ;
926+ } ) ;
927+ } ) ;
791928} ) ;
0 commit comments