diff --git a/client/src/App.tsx b/client/src/App.tsx index 39fc2812a..421e75d07 100644 --- a/client/src/App.tsx +++ b/client/src/App.tsx @@ -32,6 +32,7 @@ import { } from "@/utils/metaUtils"; import { AuthDebuggerState, EMPTY_DEBUGGER_STATE } from "./lib/auth-types"; import { OAuthStateMachine } from "./lib/oauth-state-machine"; +import { createProxyFetch } from "./lib/proxyFetch"; import { cacheToolOutputSchemas } from "./utils/schemaUtils"; import { cleanParams } from "./utils/paramUtils"; import type { JsonSchemaType } from "./utils/jsonUtils"; @@ -581,9 +582,17 @@ const App = () => { }; try { - const stateMachine = new OAuthStateMachine(sseUrl, (updates) => { - currentState = { ...currentState, ...updates }; - }); + const fetchFn = + connectionType === "proxy" && config + ? createProxyFetch(config) + : undefined; + const stateMachine = new OAuthStateMachine( + sseUrl, + (updates) => { + currentState = { ...currentState, ...updates }; + }, + fetchFn, + ); while ( currentState.oauthStep !== "complete" && @@ -621,7 +630,7 @@ const App = () => { }); } }, - [sseUrl], + [sseUrl, connectionType, config], ); useEffect(() => { @@ -1184,6 +1193,8 @@ const App = () => { onBack={() => setIsAuthDebuggerVisible(false)} authState={authState} updateAuthState={updateAuthState} + config={config} + connectionType={connectionType} /> ); diff --git a/client/src/__tests__/proxyFetchEndpoint.test.ts b/client/src/__tests__/proxyFetchEndpoint.test.ts new file mode 100644 index 000000000..9deec8f65 --- /dev/null +++ b/client/src/__tests__/proxyFetchEndpoint.test.ts @@ -0,0 +1,92 @@ +/** + * Tests for the proxy server's POST /fetch endpoint. + * Spawns the server and hits it like any other HTTP client would. + */ +import { spawn, type ChildProcess } from "child_process"; +import { resolve } from "path"; + +const TEST_PORT = 16321; +const TEST_TOKEN = "test-proxy-token-12345"; +const SERVER_PATH = resolve(__dirname, "../../../server/build/index.js"); + +async function waitForServer(baseUrl: string, maxWaitMs = 5000): Promise { + const start = Date.now(); + while (Date.now() - start < maxWaitMs) { + try { + const res = await fetch(`${baseUrl}/health`); + if (res.ok) return; + } catch { + await new Promise((r) => setTimeout(r, 50)); + } + } + throw new Error("Server did not become ready"); +} + +describe("POST /fetch endpoint", () => { + let server: ChildProcess; + const baseUrl = `http://localhost:${TEST_PORT}`; + + beforeAll(async () => { + server = spawn("node", [SERVER_PATH], { + env: { + ...process.env, + SERVER_PORT: String(TEST_PORT), + MCP_PROXY_AUTH_TOKEN: TEST_TOKEN, + }, + stdio: "ignore", + }); + await waitForServer(baseUrl); + }, 10000); + + afterAll(() => { + server.kill(); + }); + + it("returns 401 when no auth header", async () => { + const res = await fetch(`${baseUrl}/fetch`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + url: "https://example.com/", + init: { method: "GET" }, + }), + }); + expect(res.status).toBe(401); + const body = await res.json(); + expect(body.error).toBe("Unauthorized"); + }); + + it("returns 401 when auth token is invalid", async () => { + const res = await fetch(`${baseUrl}/fetch`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-MCP-Proxy-Auth": "Bearer wrong-token", + }, + body: JSON.stringify({ + url: "https://example.com/", + init: { method: "GET" }, + }), + }); + expect(res.status).toBe(401); + }); + + it("forwards request when auth token is valid", async () => { + const res = await fetch(`${baseUrl}/fetch`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-MCP-Proxy-Auth": `Bearer ${TEST_TOKEN}`, + }, + body: JSON.stringify({ + url: "https://example.com/", + init: { method: "GET" }, + }), + }); + expect(res.status).toBe(200); + const body = await res.json(); + expect(body.ok).toBe(true); + expect(body.status).toBe(200); + expect(body.body).toBeDefined(); + }); +}); diff --git a/client/src/components/AuthDebugger.tsx b/client/src/components/AuthDebugger.tsx index 6252c1161..cd1c7e222 100644 --- a/client/src/components/AuthDebugger.tsx +++ b/client/src/components/AuthDebugger.tsx @@ -5,14 +5,18 @@ import { AlertCircle } from "lucide-react"; import { AuthDebuggerState, EMPTY_DEBUGGER_STATE } from "../lib/auth-types"; import { OAuthFlowProgress } from "./OAuthFlowProgress"; import { OAuthStateMachine } from "../lib/oauth-state-machine"; +import { createProxyFetch } from "../lib/proxyFetch"; import { SESSION_KEYS } from "../lib/constants"; import { validateRedirectUrl } from "@/utils/urlValidation"; +import type { InspectorConfig } from "../lib/configurationTypes"; export interface AuthDebuggerProps { serverUrl: string; onBack: () => void; authState: AuthDebuggerState; updateAuthState: (updates: Partial) => void; + config?: InspectorConfig; + connectionType?: "direct" | "proxy"; } interface StatusMessageProps { @@ -60,6 +64,8 @@ const AuthDebugger = ({ onBack, authState, updateAuthState, + config, + connectionType, }: AuthDebuggerProps) => { // Check for existing tokens on mount useEffect(() => { @@ -102,9 +108,12 @@ const AuthDebugger = ({ }); }, [serverUrl, updateAuthState]); + const fetchFn = + connectionType === "proxy" && config ? createProxyFetch(config) : undefined; + const stateMachine = useMemo( - () => new OAuthStateMachine(serverUrl, updateAuthState), - [serverUrl, updateAuthState], + () => new OAuthStateMachine(serverUrl, updateAuthState, fetchFn), + [serverUrl, updateAuthState, fetchFn], ); const proceedToNextStep = useCallback(async () => { @@ -150,11 +159,15 @@ const AuthDebugger = ({ latestError: null, }; - const oauthMachine = new OAuthStateMachine(serverUrl, (updates) => { - // Update our temporary state during the process - currentState = { ...currentState, ...updates }; - // But don't call updateAuthState yet - }); + const oauthMachine = new OAuthStateMachine( + serverUrl, + (updates) => { + // Update our temporary state during the process + currentState = { ...currentState, ...updates }; + // But don't call updateAuthState yet + }, + fetchFn, + ); // Manually step through each stage of the OAuth flow while (currentState.oauthStep !== "complete") { @@ -214,7 +227,7 @@ const AuthDebugger = ({ } finally { updateAuthState({ isInitiatingAuth: false }); } - }, [serverUrl, updateAuthState, authState]); + }, [serverUrl, updateAuthState, authState, fetchFn]); const handleClearOAuth = useCallback(() => { if (serverUrl) { diff --git a/client/src/components/__tests__/AuthDebugger.test.tsx b/client/src/components/__tests__/AuthDebugger.test.tsx index 5d5042ea5..a8723476a 100644 --- a/client/src/components/__tests__/AuthDebugger.test.tsx +++ b/client/src/components/__tests__/AuthDebugger.test.tsx @@ -1,3 +1,4 @@ +import React from "react"; import { render, screen, @@ -8,8 +9,8 @@ import { import "@testing-library/jest-dom"; import { describe, it, beforeEach, jest } from "@jest/globals"; import AuthDebugger, { AuthDebuggerProps } from "../AuthDebugger"; -import { TooltipProvider } from "@/components/ui/tooltip"; -import { SESSION_KEYS } from "@/lib/constants"; +import { TooltipProvider } from "../ui/tooltip"; +import { SESSION_KEYS, DEFAULT_INSPECTOR_CONFIG } from "../../lib/constants"; const mockOAuthTokens = { access_token: "test_access_token", @@ -55,10 +56,10 @@ import { discoverOAuthProtectedResourceMetadata, } from "@modelcontextprotocol/sdk/client/auth.js"; import { OAuthMetadata } from "@modelcontextprotocol/sdk/shared/auth.js"; -import { EMPTY_DEBUGGER_STATE } from "@/lib/auth-types"; +import { EMPTY_DEBUGGER_STATE } from "../../lib/auth-types"; // Mock local auth module -jest.mock("@/lib/auth", () => ({ +jest.mock("../../lib/auth", () => ({ DebugInspectorOAuthClientProvider: jest.fn().mockImplementation(() => ({ tokens: jest.fn().mockImplementation(() => Promise.resolve(undefined)), clear: jest.fn().mockImplementation(() => { @@ -106,7 +107,7 @@ jest.mock("@/lib/auth", () => ({ discoverScopes: jest.fn().mockResolvedValue("read write" as never), })); -import { discoverScopes } from "@/lib/auth"; +import { discoverScopes } from "../../lib/auth"; // Type the mocked functions properly const mockDiscoverAuthorizationServerMetadata = @@ -269,6 +270,7 @@ describe("AuthDebugger", () => { // Should first discover and save OAuth metadata expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledWith( new URL("https://example.com/"), + { fetchFn: undefined }, ); // Check that updateAuthState was called with the right info message @@ -404,6 +406,65 @@ describe("AuthDebugger", () => { }); }); + describe("Proxy Fetch integration", () => { + it("passes fetchFn to SDK when connectionType is proxy", async () => { + const configWithProxy = { + ...DEFAULT_INSPECTOR_CONFIG, + MCP_PROXY_FULL_ADDRESS: { + ...DEFAULT_INSPECTOR_CONFIG.MCP_PROXY_FULL_ADDRESS, + value: "http://localhost:6277", + }, + MCP_PROXY_AUTH_TOKEN: { + ...DEFAULT_INSPECTOR_CONFIG.MCP_PROXY_AUTH_TOKEN, + value: "test-proxy-token", + }, + }; + + await act(async () => { + renderAuthDebugger({ + config: configWithProxy, + connectionType: "proxy", + authState: { + ...defaultAuthState, + isInitiatingAuth: false, + oauthStep: "metadata_discovery", + }, + }); + }); + + await act(async () => { + fireEvent.click(screen.getByText("Continue")); + }); + + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledWith( + new URL("https://example.com/"), + { fetchFn: expect.any(Function) }, + ); + }); + + it("passes undefined fetchFn when connectionType is direct", async () => { + await act(async () => { + renderAuthDebugger({ + connectionType: "direct", + authState: { + ...defaultAuthState, + isInitiatingAuth: false, + oauthStep: "metadata_discovery", + }, + }); + }); + + await act(async () => { + fireEvent.click(screen.getByText("Continue")); + }); + + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledWith( + new URL("https://example.com/"), + { fetchFn: undefined }, + ); + }); + }); + describe("OAuth Flow Steps", () => { it("should handle OAuth flow step progression", async () => { const updateAuthState = jest.fn(); @@ -428,6 +489,7 @@ describe("AuthDebugger", () => { expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledWith( new URL("https://example.com/"), + { fetchFn: undefined }, ); }); @@ -725,6 +787,8 @@ describe("AuthDebugger", () => { await waitFor(() => { expect(mockDiscoverOAuthProtectedResourceMetadata).toHaveBeenCalledWith( "https://example.com/mcp", + {}, + undefined, ); }); @@ -773,6 +837,8 @@ describe("AuthDebugger", () => { await waitFor(() => { expect(mockDiscoverOAuthProtectedResourceMetadata).toHaveBeenCalledWith( "https://example.com/mcp", + {}, + undefined, ); }); @@ -791,6 +857,7 @@ describe("AuthDebugger", () => { // Verify that regular OAuth metadata discovery was still called expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledWith( new URL("https://example.com/"), + { fetchFn: undefined }, ); }); }); diff --git a/client/src/lib/__tests__/auth.test.ts b/client/src/lib/__tests__/auth.test.ts index 329b7f027..03c503d81 100644 --- a/client/src/lib/__tests__/auth.test.ts +++ b/client/src/lib/__tests__/auth.test.ts @@ -133,7 +133,10 @@ describe("discoverScopes", () => { expect(result).toBe(expected); if (expectedCallUrl) { - expect(mockDiscoverAuth).toHaveBeenCalledWith(new URL(expectedCallUrl)); + expect(mockDiscoverAuth).toHaveBeenCalledWith( + new URL(expectedCallUrl), + { fetchFn: undefined }, + ); } }, ); diff --git a/client/src/lib/__tests__/proxyFetch.test.ts b/client/src/lib/__tests__/proxyFetch.test.ts new file mode 100644 index 000000000..c006039f1 --- /dev/null +++ b/client/src/lib/__tests__/proxyFetch.test.ts @@ -0,0 +1,146 @@ +import { createProxyFetch } from "../proxyFetch"; +import { DEFAULT_INSPECTOR_CONFIG } from "../constants"; +import type { InspectorConfig } from "../configurationTypes"; + +describe("createProxyFetch", () => { + const mockFetch = jest.fn(); + const proxyAddress = "http://localhost:6277"; + + const configWithProxy: InspectorConfig = { + ...DEFAULT_INSPECTOR_CONFIG, + MCP_PROXY_FULL_ADDRESS: { + ...DEFAULT_INSPECTOR_CONFIG.MCP_PROXY_FULL_ADDRESS, + value: proxyAddress, + }, + MCP_PROXY_AUTH_TOKEN: { + ...DEFAULT_INSPECTOR_CONFIG.MCP_PROXY_AUTH_TOKEN, + value: "test-proxy-token", + }, + }; + + beforeEach(() => { + jest.clearAllMocks(); + global.fetch = mockFetch; + }); + + it("returns a function", () => { + const fetchFn = createProxyFetch(configWithProxy); + expect(typeof fetchFn).toBe("function"); + }); + + it("sends POST to proxy /fetch endpoint with correct headers", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + ok: true, + status: 200, + statusText: "OK", + headers: {}, + body: "response body", + }), + }); + + const fetchFn = createProxyFetch(configWithProxy); + await fetchFn("https://example.com/.well-known/oauth-authorization-server"); + + expect(mockFetch).toHaveBeenCalledTimes(1); + expect(mockFetch).toHaveBeenCalledWith(`${proxyAddress}/fetch`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-MCP-Proxy-Auth": "Bearer test-proxy-token", + }, + body: expect.any(String), + }); + }); + + it("includes target url and init in request body", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + ok: true, + status: 200, + statusText: "OK", + headers: { "content-type": "application/json" }, + body: '{"issuer":"https://example.com"}', + }), + }); + + const fetchFn = createProxyFetch(configWithProxy); + await fetchFn("https://example.com/oauth/token", { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded" }, + body: "grant_type=authorization_code&code=abc", + }); + + const callBody = JSON.parse(mockFetch.mock.calls[0][1].body); + expect(callBody).toEqual({ + url: "https://example.com/oauth/token", + init: { + method: "POST", + headers: { "content-type": "application/x-www-form-urlencoded" }, + body: "grant_type=authorization_code&code=abc", + }, + }); + }); + + it("reconstructs Response from proxy response", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + ok: true, + status: 200, + statusText: "OK", + headers: { "content-type": "application/json" }, + body: '{"issuer":"https://example.com"}', + }), + }); + + const fetchFn = createProxyFetch(configWithProxy); + const response = await fetchFn( + "https://example.com/.well-known/oauth-authorization-server", + ); + + expect(response.ok).toBe(true); + expect(response.status).toBe(200); + expect(response.statusText).toBe("OK"); + expect(response.headers.get("content-type")).toBe("application/json"); + const body = await response.text(); + expect(body).toBe('{"issuer":"https://example.com"}'); + }); + + it("throws when proxy returns non-ok response", async () => { + mockFetch.mockResolvedValue({ + ok: false, + statusText: "Unauthorized", + }); + + const fetchFn = createProxyFetch(configWithProxy); + await expect(fetchFn("https://example.com/")).rejects.toThrow( + "Proxy fetch failed: Unauthorized", + ); + }); + + it("uses URL object as input", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + ok: true, + status: 200, + statusText: "OK", + headers: {}, + body: "", + }), + }); + + const fetchFn = createProxyFetch(configWithProxy); + await fetchFn(new URL("https://example.com/discovery")); + + const callBody = JSON.parse(mockFetch.mock.calls[0][1].body); + expect(callBody.url).toBe("https://example.com/discovery"); + }); +}); diff --git a/client/src/lib/auth.ts b/client/src/lib/auth.ts index 879936104..0aafa62d9 100644 --- a/client/src/lib/auth.ts +++ b/client/src/lib/auth.ts @@ -22,10 +22,12 @@ import { validateRedirectUrl } from "@/utils/urlValidation"; export const discoverScopes = async ( serverUrl: string, resourceMetadata?: OAuthProtectedResourceMetadata, + fetchFn?: typeof fetch, ): Promise => { try { const metadata = await discoverAuthorizationServerMetadata( new URL("/", serverUrl), + { fetchFn }, ); // Prefer resource metadata scopes, but fall back to OAuth metadata if empty diff --git a/client/src/lib/hooks/__tests__/useConnection.test.tsx b/client/src/lib/hooks/__tests__/useConnection.test.tsx index 4907a085b..1e5291a0f 100644 --- a/client/src/lib/hooks/__tests__/useConnection.test.tsx +++ b/client/src/lib/hooks/__tests__/useConnection.test.tsx @@ -322,7 +322,7 @@ describe("useConnection", () => { const [, samplingHandler] = samplingHandlerCall; // Invoke handler; should return a CreateTaskResult immediately - let createTaskResult: SchemaOutput; + let createTaskResult!: SchemaOutput; await act(async () => { createTaskResult = await samplingHandler(samplingRequest); }); @@ -449,7 +449,7 @@ describe("useConnection", () => { }); expect(elicitRequestHandlerCall).toBeDefined(); - const [, handler] = elicitRequestHandlerCall; + const [, handler] = elicitRequestHandlerCall!; mockOnElicitationRequest.mockImplementation((_request, resolve) => { resolve({ action: "accept", content: { name: "test" } }); @@ -640,7 +640,7 @@ describe("useConnection", () => { }); expect(elicitRequestHandlerCall).toBeDefined(); - const [, handler] = elicitRequestHandlerCall; + const [, handler] = elicitRequestHandlerCall!; const mockElicitationRequest: ElicitRequest = { method: "elicitation/create", @@ -707,7 +707,8 @@ describe("useConnection", () => { } }); - const [, handler] = elicitRequestHandlerCall; + expect(elicitRequestHandlerCall).toBeDefined(); + const [, handler] = elicitRequestHandlerCall!; const mockElicitationRequest: ElicitRequest = { method: "elicitation/create", @@ -732,7 +733,7 @@ describe("useConnection", () => { resolve(mockResponse); }); - let handlerResult; + let handlerResult!: ElicitResult; await act(async () => { handlerResult = await handler(mockElicitationRequest); }); @@ -1514,15 +1515,19 @@ describe("useConnection", () => { expect(mockDiscoverScopes).toHaveBeenCalledWith( defaultProps.sseUrl, undefined, + expect.any(Function), // fetchFn when connectionType is proxy ); } else { expect(mockDiscoverScopes).not.toHaveBeenCalled(); } - expect(mockAuth).toHaveBeenCalledWith(expect.any(Object), { - serverUrl: defaultProps.sseUrl, - scope: expectedAuthScope, - }); + expect(mockAuth).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + serverUrl: defaultProps.sseUrl, + scope: expectedAuthScope, + }), + ); }, ); @@ -1538,11 +1543,36 @@ describe("useConnection", () => { expect(mockDiscoverScopes).toHaveBeenCalledWith( defaultProps.sseUrl, undefined, + expect.any(Function), // fetchFn when connectionType is proxy + ); + expect(mockAuth).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + serverUrl: defaultProps.sseUrl, + scope: undefined, + }), + ); + }); + + it("passes undefined fetchFn when connectionType is direct", async () => { + mockDiscoverScopes.mockResolvedValue("read write"); + setup401Error(); + + const directProps = { + ...defaultProps, + connectionType: "direct" as const, + }; + await attemptConnection(directProps); + + expect(mockDiscoverScopes).toHaveBeenCalledWith( + defaultProps.sseUrl, + undefined, + undefined, // fetchFn is undefined for direct + ); + expect(mockAuth).toHaveBeenCalledWith( + expect.any(Object), + expect.not.objectContaining({ fetchFn: expect.anything() }), ); - expect(mockAuth).toHaveBeenCalledWith(expect.any(Object), { - serverUrl: defaultProps.sseUrl, - scope: undefined, - }); }); }); diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index e14d1037f..3e8e9a19d 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -64,6 +64,7 @@ import { clearScopeFromSessionStorage, discoverScopes, } from "../auth"; +import { createProxyFetch } from "../proxyFetch"; import { getMCPProxyAddress, getMCPTaskTtl, @@ -400,17 +401,22 @@ export function useConnection({ const handleAuthError = async (error: unknown) => { if (is401Error(error)) { let scope = oauthScope?.trim(); + const fetchFn = + connectionType === "proxy" ? createProxyFetch(config) : undefined; + if (!scope) { // Only discover resource metadata when we need to discover scopes let resourceMetadata; try { resourceMetadata = await discoverOAuthProtectedResourceMetadata( new URL("/", sseUrl), + {}, + fetchFn, ); } catch { // Resource metadata is optional, continue without it } - scope = await discoverScopes(sseUrl, resourceMetadata); + scope = await discoverScopes(sseUrl, resourceMetadata, fetchFn); } saveScopeToSessionStorage(sseUrl, scope); @@ -420,6 +426,7 @@ export function useConnection({ const result = await auth(serverAuthProvider, { serverUrl: sseUrl, scope, + ...(fetchFn && { fetchFn }), }); return result === "AUTHORIZED"; } catch (authError) { diff --git a/client/src/lib/oauth-state-machine.ts b/client/src/lib/oauth-state-machine.ts index 8dc9da8f9..6628b9ad5 100644 --- a/client/src/lib/oauth-state-machine.ts +++ b/client/src/lib/oauth-state-machine.ts @@ -19,6 +19,7 @@ export interface StateMachineContext { serverUrl: string; provider: DebugInspectorOAuthClientProvider; updateState: (updates: Partial) => void; + fetchFn?: typeof fetch; } export interface StateTransition { @@ -38,6 +39,8 @@ export const oauthTransitions: Record = { try { resourceMetadata = await discoverOAuthProtectedResourceMetadata( context.serverUrl, + {}, + context.fetchFn, ); if (resourceMetadata?.authorization_servers?.length) { authServerUrl = new URL(resourceMetadata.authorization_servers[0]); @@ -57,7 +60,10 @@ export const oauthTransitions: Record = { resourceMetadata ?? undefined, ); - const metadata = await discoverAuthorizationServerMetadata(authServerUrl); + const metadata = await discoverAuthorizationServerMetadata( + authServerUrl, + { fetchFn: context.fetchFn }, + ); if (!metadata) { throw new Error("Failed to discover OAuth metadata"); } @@ -98,6 +104,7 @@ export const oauthTransitions: Record = { fullInformation = await registerClient(context.serverUrl, { metadata, clientMetadata, + fetchFn: context.fetchFn, }); context.provider.saveClientInformation(fullInformation); } @@ -122,6 +129,7 @@ export const oauthTransitions: Record = { scope = await discoverScopes( context.serverUrl, context.state.resourceMetadata ?? undefined, + context.fetchFn, ); } @@ -189,6 +197,7 @@ export const oauthTransitions: Record = { ? context.state.resource : new URL(context.state.resource) : undefined, + fetchFn: context.fetchFn, }); context.provider.saveTokens(tokens); @@ -211,6 +220,7 @@ export class OAuthStateMachine { constructor( private serverUrl: string, private updateState: (updates: Partial) => void, + private fetchFn?: typeof fetch, ) {} async executeStep(state: AuthDebuggerState): Promise { @@ -220,6 +230,7 @@ export class OAuthStateMachine { serverUrl: this.serverUrl, provider, updateState: this.updateState, + fetchFn: this.fetchFn, }; const transition = oauthTransitions[state.oauthStep]; diff --git a/client/src/lib/proxyFetch.ts b/client/src/lib/proxyFetch.ts new file mode 100644 index 000000000..069ebcb07 --- /dev/null +++ b/client/src/lib/proxyFetch.ts @@ -0,0 +1,69 @@ +import { getMCPProxyAddress, getMCPProxyAuthToken } from "@/utils/configUtils"; +import type { InspectorConfig } from "./configurationTypes"; + +interface ProxyFetchResponse { + ok: boolean; + status: number; + statusText: string; + headers: Record; + body: string; +} + +/** + * Creates a fetch function that routes requests through the proxy server + * to avoid CORS restrictions on OAuth discovery and token endpoints. + */ +export function createProxyFetch(config: InspectorConfig): typeof fetch { + const proxyAddress = getMCPProxyAddress(config); + const { token, header } = getMCPProxyAuthToken(config); + + return async ( + input: RequestInfo | URL, + init?: RequestInit, + ): Promise => { + const url = typeof input === "string" ? input : input.toString(); + + // Serialize body for JSON transport. URLSearchParams and similar don't + // JSON-serialize (they become {}), so we must convert to string first. + let serializedBody: string | undefined; + if (init?.body != null) { + if (typeof init.body === "string") { + serializedBody = init.body; + } else if (init.body instanceof URLSearchParams) { + serializedBody = init.body.toString(); + } else { + serializedBody = String(init.body); + } + } + + const proxyResponse = await fetch(`${proxyAddress}/fetch`, { + method: "POST", + headers: { + "Content-Type": "application/json", + [header]: `Bearer ${token}`, + }, + body: JSON.stringify({ + url, + init: { + method: init?.method, + headers: init?.headers + ? Object.fromEntries(new Headers(init.headers)) + : undefined, + body: serializedBody, + }, + }), + }); + + if (!proxyResponse.ok) { + throw new Error(`Proxy fetch failed: ${proxyResponse.statusText}`); + } + + const data: ProxyFetchResponse = await proxyResponse.json(); + + return new Response(data.body, { + status: data.status, + statusText: data.statusText, + headers: new Headers(data.headers), + }); + }; +} diff --git a/server/src/index.ts b/server/src/index.ts index 388fdaca7..78db68ae8 100644 --- a/server/src/index.ts +++ b/server/src/index.ts @@ -772,6 +772,42 @@ app.get("/health", (req, res) => { }); }); +app.post( + "/fetch", + express.json(), + originValidationMiddleware, + authMiddleware, + async (req, res) => { + try { + const { url, init } = req.body as { url: string; init?: RequestInit }; + + const response = await fetch(url, { + method: init?.method ?? "GET", + headers: (init?.headers as Record) ?? {}, + body: init?.body as string | undefined, + }); + + const responseBody = await response.text(); + const headers: Record = {}; + response.headers.forEach((value, key) => { + headers[key] = value; + }); + + res.status(response.status).json({ + ok: response.ok, + status: response.status, + statusText: response.statusText, + headers, + body: responseBody, + }); + } catch (error) { + res.status(500).json({ + error: error instanceof Error ? error.message : String(error), + }); + } + }, +); + app.get("/config", originValidationMiddleware, authMiddleware, (req, res) => { try { res.json({