Skip to content

Commit 58928e1

Browse files
authored
Merge pull request #28 from sacha-development-stuff/codex/fix-failing-test-for-oauth-token-response
Improve OAuth token response handling
2 parents d9e0243 + a9a64d0 commit 58928e1

File tree

1 file changed

+38
-6
lines changed

1 file changed

+38
-6
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response) -> Non
222222
self.client_metadata.scope = " ".join(metadata.scopes_supported)
223223

224224
def _create_registration_request(self, metadata: OAuthMetadata | None = None) -> httpx.Request | None:
225-
if self._client_info:
225+
if self._client_info or self.context.client_info:
226226
return None
227227
if metadata and metadata.registration_endpoint:
228228
registration_url = str(metadata.registration_endpoint)
@@ -534,15 +534,27 @@ async def _exchange_token_authorization_code(
534534

535535
return httpx.Request("POST", token_url, data=token_data, headers=headers)
536536

537+
async def _read_response_content(self, response: httpx.Response) -> bytes:
538+
"""Read response content, handling preloaded or streaming bodies."""
539+
try:
540+
content = response.content
541+
if content:
542+
return content
543+
except RuntimeError:
544+
# Streaming response that hasn't been consumed yet - fall back to async read.
545+
pass
546+
547+
return await response.aread()
548+
537549
async def _handle_token_response(self, response: httpx.Response) -> None:
538550
"""Handle token exchange response."""
539551
if response.status_code != 200: # pragma: no cover
540-
body = await response.aread()
552+
body = await self._read_response_content(response)
541553
body = body.decode("utf-8")
542554
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body}")
543555

544556
try:
545-
content = await response.aread()
557+
content = await self._read_response_content(response)
546558
token_response = OAuthToken.model_validate_json(content)
547559

548560
# Validate scopes
@@ -597,7 +609,7 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: # p
597609
return False
598610

599611
try:
600-
content = await response.aread()
612+
content = await self._read_response_content(response)
601613
token_response = OAuthToken.model_validate_json(content)
602614

603615
self.context.current_tokens = token_response
@@ -614,6 +626,8 @@ async def _initialize(self) -> None: # pragma: no cover
614626
"""Load stored tokens and client info."""
615627
self.context.current_tokens = await self.context.storage.get_tokens()
616628
self.context.client_info = await self.context.storage.get_client_info()
629+
if self.context.client_info:
630+
self._client_info = self.context.client_info
617631
self._initialized = True
618632

619633
def _add_auth_header(self, request: httpx.Request) -> None:
@@ -694,7 +708,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
694708
self.context.client_info = self._client_info
695709

696710
# Step 5: Perform authorization and complete token exchange
697-
token_response = yield await self._perform_authorization()
711+
auth_result = await self._perform_authorization()
712+
if isinstance(auth_result, httpx.Request):
713+
token_request = auth_result
714+
else:
715+
auth_code, code_verifier = auth_result
716+
token_request = await self._exchange_token_authorization_code(
717+
auth_code, code_verifier
718+
)
719+
720+
token_response = yield token_request
698721
await self._handle_token_response(token_response)
699722
except Exception: # pragma: no cover
700723
logger.exception("OAuth flow error")
@@ -715,7 +738,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
715738
self._select_scopes(response)
716739

717740
# Step 2b: Perform (re-)authorization and token exchange
718-
token_response = yield await self._perform_authorization()
741+
auth_result = await self._perform_authorization()
742+
if isinstance(auth_result, httpx.Request):
743+
token_request = auth_result
744+
else:
745+
auth_code, code_verifier = auth_result
746+
token_request = await self._exchange_token_authorization_code(
747+
auth_code, code_verifier
748+
)
749+
750+
token_response = yield token_request
719751
await self._handle_token_response(token_response)
720752
except Exception: # pragma: no cover
721753
logger.exception("OAuth flow error")

0 commit comments

Comments
 (0)