1+ # pyright: reportMissingImports=false
2+ # pytest test suite for proxy_auth/combo_server.py
3+ # These tests spin up the FastMCP Starlette application in-process and
4+ # exercise the custom HTTP routes as well as the `user_info` tool.
5+
6+ from __future__ import annotations
7+
8+ import base64
9+ import json
10+ import sys
11+ import os
12+ import urllib .parse
13+ from collections .abc import AsyncGenerator
14+ from typing import Any
15+
16+ import httpx # type: ignore
17+ import pytest # type: ignore
18+
19+
20+ @pytest .fixture
21+ def proxy_server (monkeypatch ):
22+ """Import the proxy OAuth demo server with safe environment + stubs."""
23+ import os
24+
25+ # Avoid real outbound calls by pretending the upstream endpoints were
26+ # supplied explicitly via env vars – this makes `fetch_upstream_metadata`
27+ # construct metadata locally instead of performing an HTTP GET.
28+ os .environ .setdefault ("UPSTREAM_AUTHORIZATION_ENDPOINT" , "https://upstream.example.com/authorize" )
29+ os .environ .setdefault ("UPSTREAM_TOKEN_ENDPOINT" , "https://upstream.example.com/token" )
30+ os .environ .setdefault ("UPSTREAM_JWKS_URI" , "https://upstream.example.com/jwks" )
31+ os .environ .setdefault ("UPSTREAM_CLIENT_ID" , "client123" )
32+ os .environ .setdefault ("UPSTREAM_CLIENT_SECRET" , "secret123" )
33+
34+ # Deferred import so the env vars above are in effect.
35+ from proxy_auth import combo_server as proxy_server_module
36+
37+ # Stub library-level fetch_upstream_metadata to avoid network I/O.
38+ from mcp .server .auth .proxy import routes as proxy_routes
39+
40+ < << << << Updated upstream
41+ == == == =
42+ # Handle imports whether running from root or project directory
43+ try :
44+ # Try direct import first (when running from project directory)
45+ from proxy_auth import combo_server as proxy_server_module
46+ except ImportError :
47+ # If that fails, try to adjust path for running from root directory
48+ current_dir = os .path .dirname (os .path .abspath (__file__ ))
49+ project_dir = os .path .dirname (current_dir )
50+ if project_dir not in sys .path :
51+ sys .path .insert (0 , project_dir )
52+ from proxy_auth import combo_server as proxy_server_module
53+
54+ >> >> >> > Stashed changes
55+ async def _fake_metadata () -> dict [str , Any ]: # noqa: D401
56+ return {
57+ "issuer ": proxy_server_module .UPSTREAM_BASE ,
58+ "authorization_endpoint ": proxy_server_module .UPSTREAM_AUTHORIZE ,
59+ "token_endpoint ": proxy_server_module .UPSTREAM_TOKEN ,
60+ "registration_endpoint ": "/ register ",
61+ "jwks_uri ": "",
62+ }
63+
64+ monkeypatch .setattr (proxy_routes , "fetch_upstream_metadata ", _fake_metadata , raising = True )
65+ return proxy_server_module
66+
67+
68+ @pytest .fixture
69+ def app (proxy_server ):
70+ """Return the Starlette ASGI app for tests ."""
71+ return proxy_server .mcp .streamable_http_app ()
72+
73+
74+ @pytest .fixture
75+ async def client (app ) -> AsyncGenerator [httpx .AsyncClient , None ]:
76+ """Async HTTP client bound to the in - memory ASGI application ."""
77+ async with httpx .AsyncClient (transport = httpx .ASGITransport (app = app ), base_url = "http :// testserver ") as c :
78+ yield c
79+
80+
81+ # ---------------------------------------------------------------------------
82+ # HTTP endpoint tests
83+ # ---------------------------------------------------------------------------
84+
85+
86+ @pytest .mark .anyio
87+ async def test_metadata_endpoint (client ):
88+ r = await client .get ("/ .well - known / oauth - authorization - server ")
89+ assert r .status_code == 200
90+ data = r .json ()
91+ assert "issuer " in data
92+ assert data ["authorization_endpoint "].endswith ("/ authorize ")
93+ assert data ["token_endpoint "].endswith ("/ token ")
94+ assert data ["registration_endpoint "].endswith ("/ register ")
95+
96+
97+ @pytest .mark .anyio
98+ async def test_registration_endpoint (client , proxy_server ):
99+ payload = {"redirect_uris ": ["https :// client .example .com / callback "]}
100+ r = await client .post ("/ register ", json = payload )
101+ assert r .status_code == 201
102+ body = r .json ()
103+ assert body ["client_id "] == proxy_server .CLIENT_ID
104+ assert body ["redirect_uris "] == payload ["redirect_uris "]
105+ # client_secret may be None, but the field should exist (masked or real)
106+ assert "client_secret " in body
107+
108+
109+ @pytest .mark .anyio
110+ async def test_authorize_redirect (client , proxy_server ):
111+ params = {
112+ "response_type ": "code ",
113+ "state ": "xyz ",
114+ "redirect_uri ": "https :// client .example .com / callback ",
115+ "client_id ": proxy_server .CLIENT_ID ,
116+ "code_challenge ": "testchallenge ",
117+ "code_challenge_method ": "S256 ",
118+ }
119+ r = await client .get ("/ authorize ", params = params , follow_redirects = False )
120+ assert r .status_code in {302 , 307 }
121+
122+ location = r .headers ["location "]
123+ parsed = urllib .parse .urlparse (location )
124+ assert parsed .scheme .startswith ("http ")
125+ assert parsed .netloc == urllib .parse .urlparse (proxy_server .UPSTREAM_AUTHORIZE ).netloc
126+
127+ qs = urllib .parse .parse_qs (parsed .query )
128+ # Proxy should inject client_id & default scope
129+ assert qs ["client_id "][0 ] == proxy_server .CLIENT_ID
130+ assert "scope " in qs
131+ # Original params preserved
132+ assert qs ["state "][0 ] == "xyz "
133+
134+
135+ @pytest .mark .anyio
136+ async def test_revoke_proxy (client , monkeypatch , proxy_server ):
137+ original_post = httpx .AsyncClient .post
138+
139+ async def _mock_post (self , url , data = None , timeout = 10 , ** kwargs ): # noqa: D401
140+ if url .endswith ("/ revoke "):
141+ return httpx .Response (200 , json = {"revoked ": True })
142+ # For the test client's own request to /revoke, delegate to original implementation
143+ return await original_post (self , url , data = data , timeout = timeout , ** kwargs )
144+
145+ monkeypatch .setattr (httpx .AsyncClient , "post ", _mock_post , raising = True )
146+
147+ r = await client .post ("/ revoke ", data = {"token ": "dummy "})
148+ assert r .status_code == 200
149+ assert r .json () == {"revoked ": True }
150+
151+
152+ @pytest .mark .anyio
153+ async def test_token_passthrough (client , monkeypatch , proxy_server ):
154+ """Ensure / token is proxied unchanged and response is returned verbatim ."""
155+
156+ # Capture outgoing POSTs made by ProxyTokenHandler
157+ captured : dict [str , Any ] = {}
158+
159+ original_post = httpx .AsyncClient .post
160+
161+ async def _mock_post (self , url , * args , ** kwargs ): # noqa: D401
162+ if str (url ).startswith (proxy_server .UPSTREAM_TOKEN ):
163+ # Record exactly what was sent upstream
164+ captured ["url "] = str (url )
165+ captured ["data "] = kwargs .get ("data ")
166+ # Return a dummy upstream response
167+ return httpx .Response (
168+ 200 ,
169+ json = {
170+ "access_token ": "xyz ",
171+ "token_type ": "bearer ",
172+ "expires_in ": 3600 ,
173+ },
174+ )
175+ # Delegate any other POSTs to the real implementation
176+ return await original_post (self , url , * args , ** kwargs )
177+
178+ monkeypatch .setattr (httpx .AsyncClient , "post ", _mock_post , raising = True )
179+
180+ # ---------------- Act ----------------
181+ form = {
182+ "grant_type ": "authorization_code ",
183+ "code ": "dummy - code ",
184+ "client_id ": proxy_server .CLIENT_ID ,
185+ }
186+ r = await client .post ("/ token ", data = form )
187+
188+ # ---------------- Assert -------------
189+ assert r .status_code == 200
190+ assert r .json ()["access_token "] == "xyz "
191+
192+ # Verify the request payload was forwarded without modification
193+ assert captured ["data "] == form
194+
195+
196+ # ---------------------------------------------------------------------------
197+ # Tool invocation – user_info
198+ # ---------------------------------------------------------------------------
199+
200+
201+ @pytest .mark .anyio
202+ async def test_user_info_tool (monkeypatch , proxy_server ):
203+ """Call the `user_info ` tool directly with a mocked access token ."""
204+ # Craft a dummy JWT with useful claims (header/payload/signature parts)
205+ payload = (
206+ base64 .urlsafe_b64encode (
207+ json .dumps (
208+ {
209+ "sub ": "test - user ",
210+ "preferred_username ": "tester ",
211+ }
212+ ).encode ()
213+ )
214+ .decode ()
215+ .rstrip ("= ")
216+ )
217+ dummy_token = f"header .{payload }.signature "
218+
219+ from mcp .server .auth .middleware import auth_context
220+ from mcp .server .auth .provider import AccessToken # local import to avoid cycles
221+
222+ def _fake_get_access_token (): # noqa: D401
223+ return AccessToken (token = dummy_token , client_id = "client123 ", scopes = ["openid "], expires_at = None )
224+
225+ monkeypatch .setattr (auth_context , "get_access_token ", _fake_get_access_token , raising = True )
226+
227+ result = await proxy_server .mcp .call_tool ("user_info ", {})
228+
229+ # call_tool returns (content_blocks, raw_result)
230+ if isinstance (result , tuple ):
231+ _, raw = result
232+ else :
233+ raw = result # fallback
234+
235+ assert raw ["authenticated "] is True
236+ assert ("userid " in raw and raw ["userid "] == "test - user ") or ("user_id " in raw and raw ["user_id "] == "test - user ")
237+ assert raw ["username "] == "tester "
0 commit comments