@@ -1293,4 +1293,265 @@ describe("StreamableHTTPServerTransport in stateless mode", () => {
12931293 } ) ;
12941294 expect ( stream2 . status ) . toBe ( 409 ) ; // Conflict - only one stream allowed
12951295 } ) ;
1296- } ) ;
1296+ } ) ;
1297+
1298+ // Test DNS rebinding protection
1299+ describe ( "StreamableHTTPServerTransport DNS rebinding protection" , ( ) => {
1300+ let server : Server ;
1301+ let transport : StreamableHTTPServerTransport ;
1302+ let baseUrl : URL ;
1303+
1304+ afterEach ( async ( ) => {
1305+ if ( server && transport ) {
1306+ await stopTestServer ( { server, transport } ) ;
1307+ }
1308+ } ) ;
1309+
1310+ describe ( "Host header validation" , ( ) => {
1311+ it ( "should accept requests with allowed host headers" , async ( ) => {
1312+ const result = await createTestServerWithDnsProtection ( {
1313+ sessionIdGenerator : undefined ,
1314+ allowedHosts : [ 'localhost:3001' ] ,
1315+ disableDnsRebindingProtection : false ,
1316+ } ) ;
1317+ server = result . server ;
1318+ transport = result . transport ;
1319+ baseUrl = result . baseUrl ;
1320+
1321+ // Note: fetch() automatically sets Host header to match the URL
1322+ // Since we're connecting to localhost:3001 and that's in allowedHosts, this should work
1323+ const response = await fetch ( baseUrl , {
1324+ method : "POST" ,
1325+ headers : {
1326+ "Content-Type" : "application/json" ,
1327+ Accept : "application/json, text/event-stream" ,
1328+ } ,
1329+ body : JSON . stringify ( TEST_MESSAGES . initialize ) ,
1330+ } ) ;
1331+
1332+ expect ( response . status ) . toBe ( 200 ) ;
1333+ } ) ;
1334+
1335+ it ( "should reject requests with disallowed host headers" , async ( ) => {
1336+ // Test DNS rebinding protection by creating a server that only allows example.com
1337+ // but we're connecting via localhost, so it should be rejected
1338+ const result = await createTestServerWithDnsProtection ( {
1339+ sessionIdGenerator : undefined ,
1340+ allowedHosts : [ 'example.com:3001' ] ,
1341+ disableDnsRebindingProtection : false ,
1342+ } ) ;
1343+ server = result . server ;
1344+ transport = result . transport ;
1345+ baseUrl = result . baseUrl ;
1346+
1347+ const response = await fetch ( baseUrl , {
1348+ method : "POST" ,
1349+ headers : {
1350+ "Content-Type" : "application/json" ,
1351+ Accept : "application/json, text/event-stream" ,
1352+ } ,
1353+ body : JSON . stringify ( TEST_MESSAGES . initialize ) ,
1354+ } ) ;
1355+
1356+ expect ( response . status ) . toBe ( 403 ) ;
1357+ const body = await response . json ( ) ;
1358+ expect ( body . error . message ) . toContain ( "Invalid Host header:" ) ;
1359+ } ) ;
1360+
1361+ it ( "should reject GET requests with disallowed host headers" , async ( ) => {
1362+ const result = await createTestServerWithDnsProtection ( {
1363+ sessionIdGenerator : undefined ,
1364+ allowedHosts : [ 'example.com:3001' ] ,
1365+ disableDnsRebindingProtection : false ,
1366+ } ) ;
1367+ server = result . server ;
1368+ transport = result . transport ;
1369+ baseUrl = result . baseUrl ;
1370+
1371+ const response = await fetch ( baseUrl , {
1372+ method : "GET" ,
1373+ headers : {
1374+ Accept : "text/event-stream" ,
1375+ } ,
1376+ } ) ;
1377+
1378+ expect ( response . status ) . toBe ( 403 ) ;
1379+ } ) ;
1380+ } ) ;
1381+
1382+ describe ( "Origin header validation" , ( ) => {
1383+ it ( "should accept requests with allowed origin headers" , async ( ) => {
1384+ const result = await createTestServerWithDnsProtection ( {
1385+ sessionIdGenerator : undefined ,
1386+ allowedOrigins : [ 'http://localhost:3000' , 'https://example.com' ] ,
1387+ disableDnsRebindingProtection : false ,
1388+ } ) ;
1389+ server = result . server ;
1390+ transport = result . transport ;
1391+ baseUrl = result . baseUrl ;
1392+
1393+ const response = await fetch ( baseUrl , {
1394+ method : "POST" ,
1395+ headers : {
1396+ "Content-Type" : "application/json" ,
1397+ Accept : "application/json, text/event-stream" ,
1398+ Origin : "http://localhost:3000" ,
1399+ } ,
1400+ body : JSON . stringify ( TEST_MESSAGES . initialize ) ,
1401+ } ) ;
1402+
1403+ expect ( response . status ) . toBe ( 200 ) ;
1404+ } ) ;
1405+
1406+ it ( "should reject requests with disallowed origin headers" , async ( ) => {
1407+ const result = await createTestServerWithDnsProtection ( {
1408+ sessionIdGenerator : undefined ,
1409+ allowedOrigins : [ 'http://localhost:3000' ] ,
1410+ disableDnsRebindingProtection : false ,
1411+ } ) ;
1412+ server = result . server ;
1413+ transport = result . transport ;
1414+ baseUrl = result . baseUrl ;
1415+
1416+ const response = await fetch ( baseUrl , {
1417+ method : "POST" ,
1418+ headers : {
1419+ "Content-Type" : "application/json" ,
1420+ Accept : "application/json, text/event-stream" ,
1421+ Origin : "http://evil.com" ,
1422+ } ,
1423+ body : JSON . stringify ( TEST_MESSAGES . initialize ) ,
1424+ } ) ;
1425+
1426+ expect ( response . status ) . toBe ( 403 ) ;
1427+ const body = await response . json ( ) ;
1428+ expect ( body . error . message ) . toBe ( "Invalid Origin header: http://evil.com" ) ;
1429+ } ) ;
1430+ } ) ;
1431+
1432+ describe ( "disableDnsRebindingProtection option" , ( ) => {
1433+ it ( "should skip all validations when disableDnsRebindingProtection is true" , async ( ) => {
1434+ const result = await createTestServerWithDnsProtection ( {
1435+ sessionIdGenerator : undefined ,
1436+ allowedHosts : [ 'localhost:3001' ] ,
1437+ allowedOrigins : [ 'http://localhost:3000' ] ,
1438+ disableDnsRebindingProtection : true ,
1439+ } ) ;
1440+ server = result . server ;
1441+ transport = result . transport ;
1442+ baseUrl = result . baseUrl ;
1443+
1444+ const response = await fetch ( baseUrl , {
1445+ method : "POST" ,
1446+ headers : {
1447+ "Content-Type" : "application/json" ,
1448+ Accept : "application/json, text/event-stream" ,
1449+ Host : "evil.com" ,
1450+ Origin : "http://evil.com" ,
1451+ } ,
1452+ body : JSON . stringify ( TEST_MESSAGES . initialize ) ,
1453+ } ) ;
1454+
1455+ // Should pass even with invalid headers because protection is disabled
1456+ expect ( response . status ) . toBe ( 200 ) ;
1457+ } ) ;
1458+ } ) ;
1459+
1460+ describe ( "Combined validations" , ( ) => {
1461+ it ( "should validate both host and origin when both are configured" , async ( ) => {
1462+ const result = await createTestServerWithDnsProtection ( {
1463+ sessionIdGenerator : undefined ,
1464+ allowedHosts : [ 'localhost:3001' ] ,
1465+ allowedOrigins : [ 'http://localhost:3001' ] ,
1466+ disableDnsRebindingProtection : false ,
1467+ } ) ;
1468+ server = result . server ;
1469+ transport = result . transport ;
1470+ baseUrl = result . baseUrl ;
1471+
1472+ // Test with invalid origin (host will be automatically correct via fetch)
1473+ const response1 = await fetch ( baseUrl , {
1474+ method : "POST" ,
1475+ headers : {
1476+ "Content-Type" : "application/json" ,
1477+ Accept : "application/json, text/event-stream" ,
1478+ Origin : "http://evil.com" ,
1479+ } ,
1480+ body : JSON . stringify ( TEST_MESSAGES . initialize ) ,
1481+ } ) ;
1482+
1483+ expect ( response1 . status ) . toBe ( 403 ) ;
1484+ const body1 = await response1 . json ( ) ;
1485+ expect ( body1 . error . message ) . toBe ( "Invalid Origin header: http://evil.com" ) ;
1486+
1487+ // Test with valid origin
1488+ const response2 = await fetch ( baseUrl , {
1489+ method : "POST" ,
1490+ headers : {
1491+ "Content-Type" : "application/json" ,
1492+ Accept : "application/json, text/event-stream" ,
1493+ Origin : "http://localhost:3001" ,
1494+ } ,
1495+ body : JSON . stringify ( TEST_MESSAGES . initialize ) ,
1496+ } ) ;
1497+
1498+ expect ( response2 . status ) . toBe ( 200 ) ;
1499+ } ) ;
1500+ } ) ;
1501+ } ) ;
1502+
1503+ /**
1504+ * Helper to create test server with DNS rebinding protection options
1505+ */
1506+ async function createTestServerWithDnsProtection ( config : {
1507+ sessionIdGenerator : ( ( ) => string ) | undefined ;
1508+ allowedHosts ?: string [ ] ;
1509+ allowedOrigins ?: string [ ] ;
1510+ disableDnsRebindingProtection ?: boolean ;
1511+ } ) : Promise < {
1512+ server : Server ;
1513+ transport : StreamableHTTPServerTransport ;
1514+ mcpServer : McpServer ;
1515+ baseUrl : URL ;
1516+ } > {
1517+ const mcpServer = new McpServer (
1518+ { name : "test-server" , version : "1.0.0" } ,
1519+ { capabilities : { logging : { } } }
1520+ ) ;
1521+
1522+ const transport = new StreamableHTTPServerTransport ( {
1523+ sessionIdGenerator : config . sessionIdGenerator ,
1524+ allowedHosts : config . allowedHosts ,
1525+ allowedOrigins : config . allowedOrigins ,
1526+ disableDnsRebindingProtection : config . disableDnsRebindingProtection ,
1527+ } ) ;
1528+
1529+ await mcpServer . connect ( transport ) ;
1530+
1531+ const httpServer = createServer ( async ( req , res ) => {
1532+ if ( req . method === "POST" ) {
1533+ let body = "" ;
1534+ req . on ( "data" , ( chunk ) => ( body += chunk ) ) ;
1535+ req . on ( "end" , async ( ) => {
1536+ const parsedBody = JSON . parse ( body ) ;
1537+ await transport . handleRequest ( req as IncomingMessage & { auth ?: AuthInfo } , res , parsedBody ) ;
1538+ } ) ;
1539+ } else {
1540+ await transport . handleRequest ( req as IncomingMessage & { auth ?: AuthInfo } , res ) ;
1541+ }
1542+ } ) ;
1543+
1544+ await new Promise < void > ( ( resolve ) => {
1545+ httpServer . listen ( 3001 , ( ) => resolve ( ) ) ;
1546+ } ) ;
1547+
1548+ const port = ( httpServer . address ( ) as AddressInfo ) . port ;
1549+ const serverUrl = new URL ( `http://localhost:${ port } /` ) ;
1550+
1551+ return {
1552+ server : httpServer ,
1553+ transport,
1554+ mcpServer,
1555+ baseUrl : serverUrl ,
1556+ } ;
1557+ }
0 commit comments