Skip to content

Commit b39ed3c

Browse files
committed
Add tests for proxy builder and OAuth endpoints
Signed-off-by: Jesse Sanford <108698+jessesanford@users.noreply.github.com>
1 parent cdbdeb3 commit b39ed3c

File tree

3 files changed

+286
-0
lines changed

3 files changed

+286
-0
lines changed

examples/servers/proxy-auth/tests/__init__.py

Whitespace-only changes.
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
# pyright: reportMissingImports=false
2+
# pytest test suite for proxy_auth/combo_server.py
3+
# These tests spin up the FastMCP Starlette application in-process and
4+
# exercise the custom HTTP routes as well as the `user_info` tool.
5+
6+
from __future__ import annotations
7+
8+
import base64
9+
import json
10+
import sys
11+
import os
12+
import urllib.parse
13+
from collections.abc import AsyncGenerator
14+
from typing import Any
15+
16+
import httpx # type: ignore
17+
import pytest # type: ignore
18+
19+
20+
@pytest.fixture
21+
def proxy_server(monkeypatch):
22+
"""Import the proxy OAuth demo server with safe environment + stubs."""
23+
import os
24+
25+
# Avoid real outbound calls by pretending the upstream endpoints were
26+
# supplied explicitly via env vars – this makes `fetch_upstream_metadata`
27+
# construct metadata locally instead of performing an HTTP GET.
28+
os.environ.setdefault("UPSTREAM_AUTHORIZATION_ENDPOINT", "https://upstream.example.com/authorize")
29+
os.environ.setdefault("UPSTREAM_TOKEN_ENDPOINT", "https://upstream.example.com/token")
30+
os.environ.setdefault("UPSTREAM_JWKS_URI", "https://upstream.example.com/jwks")
31+
os.environ.setdefault("UPSTREAM_CLIENT_ID", "client123")
32+
os.environ.setdefault("UPSTREAM_CLIENT_SECRET", "secret123")
33+
34+
# Deferred import so the env vars above are in effect.
35+
from proxy_auth import combo_server as proxy_server_module
36+
37+
# Stub library-level fetch_upstream_metadata to avoid network I/O.
38+
from mcp.server.auth.proxy import routes as proxy_routes
39+
40+
<<<<<<< Updated upstream
41+
=======
42+
# Handle imports whether running from root or project directory
43+
try:
44+
# Try direct import first (when running from project directory)
45+
from proxy_auth import combo_server as proxy_server_module
46+
except ImportError:
47+
# If that fails, try to adjust path for running from root directory
48+
current_dir = os.path.dirname(os.path.abspath(__file__))
49+
project_dir = os.path.dirname(current_dir)
50+
if project_dir not in sys.path:
51+
sys.path.insert(0, project_dir)
52+
from proxy_auth import combo_server as proxy_server_module
53+
54+
>>>>>>> Stashed changes
55+
async def _fake_metadata() -> dict[str, Any]: # noqa: D401
56+
return {
57+
"issuer": proxy_server_module.UPSTREAM_BASE,
58+
"authorization_endpoint": proxy_server_module.UPSTREAM_AUTHORIZE,
59+
"token_endpoint": proxy_server_module.UPSTREAM_TOKEN,
60+
"registration_endpoint": "/register",
61+
"jwks_uri": "",
62+
}
63+
64+
monkeypatch.setattr(proxy_routes, "fetch_upstream_metadata", _fake_metadata, raising=True)
65+
return proxy_server_module
66+
67+
68+
@pytest.fixture
69+
def app(proxy_server):
70+
"""Return the Starlette ASGI app for tests."""
71+
return proxy_server.mcp.streamable_http_app()
72+
73+
74+
@pytest.fixture
75+
async def client(app) -> AsyncGenerator[httpx.AsyncClient, None]:
76+
"""Async HTTP client bound to the in-memory ASGI application."""
77+
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as c:
78+
yield c
79+
80+
81+
# ---------------------------------------------------------------------------
82+
# HTTP endpoint tests
83+
# ---------------------------------------------------------------------------
84+
85+
86+
@pytest.mark.anyio
87+
async def test_metadata_endpoint(client):
88+
r = await client.get("/.well-known/oauth-authorization-server")
89+
assert r.status_code == 200
90+
data = r.json()
91+
assert "issuer" in data
92+
assert data["authorization_endpoint"].endswith("/authorize")
93+
assert data["token_endpoint"].endswith("/token")
94+
assert data["registration_endpoint"].endswith("/register")
95+
96+
97+
@pytest.mark.anyio
98+
async def test_registration_endpoint(client, proxy_server):
99+
payload = {"redirect_uris": ["https://client.example.com/callback"]}
100+
r = await client.post("/register", json=payload)
101+
assert r.status_code == 201
102+
body = r.json()
103+
assert body["client_id"] == proxy_server.CLIENT_ID
104+
assert body["redirect_uris"] == payload["redirect_uris"]
105+
# client_secret may be None, but the field should exist (masked or real)
106+
assert "client_secret" in body
107+
108+
109+
@pytest.mark.anyio
110+
async def test_authorize_redirect(client, proxy_server):
111+
params = {
112+
"response_type": "code",
113+
"state": "xyz",
114+
"redirect_uri": "https://client.example.com/callback",
115+
"client_id": proxy_server.CLIENT_ID,
116+
"code_challenge": "testchallenge",
117+
"code_challenge_method": "S256",
118+
}
119+
r = await client.get("/authorize", params=params, follow_redirects=False)
120+
assert r.status_code in {302, 307}
121+
122+
location = r.headers["location"]
123+
parsed = urllib.parse.urlparse(location)
124+
assert parsed.scheme.startswith("http")
125+
assert parsed.netloc == urllib.parse.urlparse(proxy_server.UPSTREAM_AUTHORIZE).netloc
126+
127+
qs = urllib.parse.parse_qs(parsed.query)
128+
# Proxy should inject client_id & default scope
129+
assert qs["client_id"][0] == proxy_server.CLIENT_ID
130+
assert "scope" in qs
131+
# Original params preserved
132+
assert qs["state"][0] == "xyz"
133+
134+
135+
@pytest.mark.anyio
136+
async def test_revoke_proxy(client, monkeypatch, proxy_server):
137+
original_post = httpx.AsyncClient.post
138+
139+
async def _mock_post(self, url, data=None, timeout=10, **kwargs): # noqa: D401
140+
if url.endswith("/revoke"):
141+
return httpx.Response(200, json={"revoked": True})
142+
# For the test client's own request to /revoke, delegate to original implementation
143+
return await original_post(self, url, data=data, timeout=timeout, **kwargs)
144+
145+
monkeypatch.setattr(httpx.AsyncClient, "post", _mock_post, raising=True)
146+
147+
r = await client.post("/revoke", data={"token": "dummy"})
148+
assert r.status_code == 200
149+
assert r.json() == {"revoked": True}
150+
151+
152+
@pytest.mark.anyio
153+
async def test_token_passthrough(client, monkeypatch, proxy_server):
154+
"""Ensure /token is proxied unchanged and response is returned verbatim."""
155+
156+
# Capture outgoing POSTs made by ProxyTokenHandler
157+
captured: dict[str, Any] = {}
158+
159+
original_post = httpx.AsyncClient.post
160+
161+
async def _mock_post(self, url, *args, **kwargs): # noqa: D401
162+
if str(url).startswith(proxy_server.UPSTREAM_TOKEN):
163+
# Record exactly what was sent upstream
164+
captured["url"] = str(url)
165+
captured["data"] = kwargs.get("data")
166+
# Return a dummy upstream response
167+
return httpx.Response(
168+
200,
169+
json={
170+
"access_token": "xyz",
171+
"token_type": "bearer",
172+
"expires_in": 3600,
173+
},
174+
)
175+
# Delegate any other POSTs to the real implementation
176+
return await original_post(self, url, *args, **kwargs)
177+
178+
monkeypatch.setattr(httpx.AsyncClient, "post", _mock_post, raising=True)
179+
180+
# ---------------- Act ----------------
181+
form = {
182+
"grant_type": "authorization_code",
183+
"code": "dummy-code",
184+
"client_id": proxy_server.CLIENT_ID,
185+
}
186+
r = await client.post("/token", data=form)
187+
188+
# ---------------- Assert -------------
189+
assert r.status_code == 200
190+
assert r.json()["access_token"] == "xyz"
191+
192+
# Verify the request payload was forwarded without modification
193+
assert captured["data"] == form
194+
195+
196+
# ---------------------------------------------------------------------------
197+
# Tool invocation – user_info
198+
# ---------------------------------------------------------------------------
199+
200+
201+
@pytest.mark.anyio
202+
async def test_user_info_tool(monkeypatch, proxy_server):
203+
"""Call the `user_info` tool directly with a mocked access token."""
204+
# Craft a dummy JWT with useful claims (header/payload/signature parts)
205+
payload = (
206+
base64.urlsafe_b64encode(
207+
json.dumps(
208+
{
209+
"sub": "test-user",
210+
"preferred_username": "tester",
211+
}
212+
).encode()
213+
)
214+
.decode()
215+
.rstrip("=")
216+
)
217+
dummy_token = f"header.{payload}.signature"
218+
219+
from mcp.server.auth.middleware import auth_context
220+
from mcp.server.auth.provider import AccessToken # local import to avoid cycles
221+
222+
def _fake_get_access_token(): # noqa: D401
223+
return AccessToken(token=dummy_token, client_id="client123", scopes=["openid"], expires_at=None)
224+
225+
monkeypatch.setattr(auth_context, "get_access_token", _fake_get_access_token, raising=True)
226+
227+
result = await proxy_server.mcp.call_tool("user_info", {})
228+
229+
# call_tool returns (content_blocks, raw_result)
230+
if isinstance(result, tuple):
231+
_, raw = result
232+
else:
233+
raw = result # fallback
234+
235+
assert raw["authenticated"] is True
236+
assert ("userid" in raw and raw["userid"] == "test-user") or ("user_id" in raw and raw["user_id"] == "test-user")
237+
assert raw["username"] == "tester"

tests/test_proxy_builder.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# pyright: reportMissingImports=false, reportGeneralTypeIssues=false
2+
"""Tests for the build_proxy_server convenience helper."""
3+
4+
from __future__ import annotations
5+
6+
from typing import cast
7+
8+
import httpx # type: ignore
9+
import pytest # type: ignore
10+
from pydantic import AnyHttpUrl
11+
12+
from mcp.server.auth.providers.transparent_proxy import _Settings as ProxySettings
13+
from mcp.server.auth.proxy import routes as proxy_routes
14+
from mcp.server.auth.proxy.server import build_proxy_server
15+
16+
17+
@pytest.mark.anyio
18+
async def test_build_proxy_server_metadata(monkeypatch):
19+
"""Ensure the server starts and serves metadata without touching network."""
20+
21+
# Patch metadata fetcher so no real HTTP traffic occurs
22+
async def _fake_metadata(): # noqa: D401
23+
return {
24+
"issuer": "https://proxy.test",
25+
"authorization_endpoint": "https://proxy.test/authorize",
26+
"token_endpoint": "https://proxy.test/token",
27+
"registration_endpoint": "/register",
28+
}
29+
30+
monkeypatch.setattr(proxy_routes, "fetch_upstream_metadata", _fake_metadata, raising=True)
31+
32+
# Provide required upstream endpoints via settings object
33+
settings = ProxySettings( # type: ignore[call-arg]
34+
UPSTREAM_AUTHORIZATION_ENDPOINT=cast(AnyHttpUrl, "https://upstream.example.com/authorize"),
35+
UPSTREAM_TOKEN_ENDPOINT=cast(AnyHttpUrl, "https://upstream.example.com/token"),
36+
UPSTREAM_CLIENT_ID="demo-client-id",
37+
UPSTREAM_CLIENT_SECRET=None,
38+
UPSTREAM_JWKS_URI=None,
39+
)
40+
41+
mcp = build_proxy_server(port=0, settings=settings)
42+
43+
app = mcp.streamable_http_app()
44+
45+
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as c:
46+
r = await c.get("/.well-known/oauth-authorization-server")
47+
assert r.status_code == 200
48+
data = r.json()
49+
assert data["authorization_endpoint"].endswith("/authorize")

0 commit comments

Comments
 (0)