Skip to content

Commit 032ff50

Browse files
committed
Move RFC7523 support into extension
1 parent aa5d820 commit 032ff50

File tree

5 files changed

+263
-238
lines changed

5 files changed

+263
-238
lines changed

src/mcp/client/auth/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""
2+
OAuth2 Authentication implementation for HTTPX.
3+
4+
Implements authorization code flow with PKCE and automatic token refresh.
5+
"""
6+
7+
from mcp.client.auth.oauth2 import * # noqa: F403
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import time
2+
from collections.abc import Awaitable, Callable
3+
from typing import Any
4+
from uuid import uuid4
5+
6+
import httpx
7+
import jwt
8+
from pydantic import BaseModel, Field
9+
10+
from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
11+
from mcp.shared.auth import OAuthClientMetadata
12+
13+
14+
class JWTParameters(BaseModel):
15+
"""JWT parameters."""
16+
17+
assertion: str | None = Field(
18+
default=None,
19+
description="JWT assertion for JWT authentication. "
20+
"Will be used instead of generating a new assertion if provided.",
21+
)
22+
23+
issuer: str | None = Field(default=None, description="Issuer for JWT assertions.")
24+
subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.")
25+
audience: str | None = Field(default=None, description="Audience for JWT assertions.")
26+
claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.")
27+
jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.")
28+
jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
29+
jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
30+
31+
def to_assertion(self, with_audience_fallback: str | None = None) -> str:
32+
if self.assertion is not None:
33+
# Prebuilt JWT (e.g. acquired out-of-band)
34+
assertion = self.assertion
35+
else:
36+
if not self.jwt_signing_key:
37+
raise OAuthFlowError("Missing signing key for JWT bearer grant")
38+
if not self.issuer:
39+
raise OAuthFlowError("Missing issuer for JWT bearer grant")
40+
if not self.subject:
41+
raise OAuthFlowError("Missing subject for JWT bearer grant")
42+
43+
audience = self.audience if self.audience else with_audience_fallback
44+
if not audience:
45+
raise OAuthFlowError("Missing audience for JWT bearer grant")
46+
47+
now = int(time.time())
48+
claims: dict[str, Any] = {
49+
"iss": self.issuer,
50+
"sub": self.subject,
51+
"aud": audience,
52+
"exp": now + self.jwt_lifetime_seconds,
53+
"iat": now,
54+
"jti": str(uuid4()),
55+
}
56+
claims.update(self.claims or {})
57+
58+
assertion = jwt.encode(
59+
claims,
60+
self.jwt_signing_key,
61+
algorithm=self.jwt_signing_algorithm or "RS256",
62+
)
63+
return assertion
64+
65+
66+
class RFC7523OAuthClientProvider(OAuthClientProvider):
67+
"""OAuth client provider for RFC7532 clients."""
68+
69+
jwt_parameters: JWTParameters | None = None
70+
71+
def __init__(
72+
self,
73+
server_url: str,
74+
client_metadata: OAuthClientMetadata,
75+
storage: TokenStorage,
76+
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
77+
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
78+
timeout: float = 300.0,
79+
jwt_parameters: JWTParameters | None = None,
80+
) -> None:
81+
super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout)
82+
self.jwt_parameters = jwt_parameters
83+
84+
async def _exchange_token_authorization_code(
85+
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
86+
) -> httpx.Request:
87+
"""Build token exchange request for authorization_code flow."""
88+
token_data = token_data or {}
89+
if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
90+
self._add_client_authentication_jwt(token_data=token_data)
91+
return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)
92+
93+
async def _perform_authorization(self) -> httpx.Request:
94+
"""Perform the authorization flow."""
95+
if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
96+
token_request = await self._exchange_token_jwt_bearer()
97+
return token_request
98+
else:
99+
return await super()._perform_authorization()
100+
101+
def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]):
102+
"""Add JWT assertion for client authentication to token endpoint parameters."""
103+
if not self.jwt_parameters:
104+
raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
105+
106+
token_url = self._get_token_endpoint()
107+
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=token_url)
108+
109+
# When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
110+
token_data["client_assertion"] = assertion
111+
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
112+
# We need to set the audience to the token endpoint, the audience is difference from the one in claims
113+
# it represents the resource server that will validate the token
114+
token_data["audience"] = self.context.get_resource_url()
115+
116+
async def _exchange_token_jwt_bearer(self) -> httpx.Request:
117+
"""Build token exchange request for JWT bearer grant."""
118+
if not self.context.client_info:
119+
raise OAuthFlowError("Missing client info")
120+
if not self.jwt_parameters:
121+
raise OAuthFlowError("Missing JWT parameters")
122+
123+
token_url = self._get_token_endpoint()
124+
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=token_url)
125+
126+
token_data = {
127+
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
128+
"assertion": assertion,
129+
}
130+
131+
if self.context.should_include_resource_param(self.context.protocol_version):
132+
token_data["resource"] = self.context.get_resource_url()
133+
134+
if self.context.client_metadata.scope:
135+
token_data["scope"] = self.context.client_metadata.scope
136+
137+
return httpx.Request(
138+
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
139+
)
Lines changed: 0 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,9 @@
1515
from dataclasses import dataclass, field
1616
from typing import Any, Protocol
1717
from urllib.parse import urlencode, urljoin, urlparse
18-
from uuid import uuid4
1918

2019
import anyio
2120
import httpx
22-
import jwt
2321
from pydantic import BaseModel, Field, ValidationError
2422

2523
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
@@ -63,58 +61,6 @@ def generate(cls) -> "PKCEParameters":
6361
return cls(code_verifier=code_verifier, code_challenge=code_challenge)
6462

6563

66-
class JWTParameters(BaseModel):
67-
"""JWT parameters."""
68-
69-
assertion: str | None = Field(
70-
default=None,
71-
description="JWT assertion for JWT authentication. "
72-
"Will be used instead of generating a new assertion if provided.",
73-
)
74-
75-
issuer: str | None = Field(default=None, description="Issuer for JWT assertions.")
76-
subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.")
77-
audience: str | None = Field(default=None, description="Audience for JWT assertions.")
78-
claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.")
79-
jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.")
80-
jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
81-
jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
82-
83-
def to_assertion(self, with_audience_fallback: str | None = None) -> str:
84-
if self.assertion is not None:
85-
# Prebuilt JWT (e.g. acquired out-of-band)
86-
assertion = self.assertion
87-
else:
88-
if not self.jwt_signing_key:
89-
raise OAuthFlowError("Missing signing key for JWT bearer grant")
90-
if not self.issuer:
91-
raise OAuthFlowError("Missing issuer for JWT bearer grant")
92-
if not self.subject:
93-
raise OAuthFlowError("Missing subject for JWT bearer grant")
94-
95-
audience = self.audience if self.audience else with_audience_fallback
96-
if not audience:
97-
raise OAuthFlowError("Missing audience for JWT bearer grant")
98-
99-
now = int(time.time())
100-
claims: dict[str, Any] = {
101-
"iss": self.issuer,
102-
"sub": self.subject,
103-
"aud": audience,
104-
"exp": now + self.jwt_lifetime_seconds,
105-
"iat": now,
106-
"jti": str(uuid4()),
107-
}
108-
claims.update(self.claims or {})
109-
110-
assertion = jwt.encode(
111-
claims,
112-
self.jwt_signing_key,
113-
algorithm=self.jwt_signing_algorithm or "RS256",
114-
)
115-
return assertion
116-
117-
11864
class TokenStorage(Protocol):
11965
"""Protocol for token storage implementations."""
12066

@@ -686,79 +632,3 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
686632
# Retry with new tokens
687633
self._add_auth_header(request)
688634
yield request
689-
690-
691-
class RFC7523OAuthClientProvider(OAuthClientProvider):
692-
"""OAuth client provider for RFC7532 clients."""
693-
694-
jwt_parameters: JWTParameters | None = None
695-
696-
def __init__(
697-
self,
698-
server_url: str,
699-
client_metadata: OAuthClientMetadata,
700-
storage: TokenStorage,
701-
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
702-
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
703-
timeout: float = 300.0,
704-
jwt_parameters: JWTParameters | None = None,
705-
) -> None:
706-
super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout)
707-
self.jwt_parameters = jwt_parameters
708-
709-
async def _exchange_token_authorization_code(
710-
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
711-
) -> httpx.Request:
712-
"""Build token exchange request for authorization_code flow."""
713-
token_data = token_data or {}
714-
if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
715-
self._add_client_authentication_jwt(token_data=token_data)
716-
return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)
717-
718-
async def _perform_authorization(self) -> httpx.Request:
719-
"""Perform the authorization flow."""
720-
if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
721-
token_request = await self._exchange_token_jwt_bearer()
722-
return token_request
723-
else:
724-
return await super()._perform_authorization()
725-
726-
def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]):
727-
"""Add JWT assertion for client authentication to token endpoint parameters."""
728-
if not self.jwt_parameters:
729-
raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
730-
731-
token_url = self._get_token_endpoint()
732-
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=token_url)
733-
734-
# When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
735-
token_data["client_assertion"] = assertion
736-
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
737-
# We need to set the audience to the token endpoint, the audience is difference from the one in claims
738-
# it represents the resource server that will validate the token
739-
token_data["audience"] = self.context.get_resource_url()
740-
741-
async def _exchange_token_jwt_bearer(self) -> httpx.Request:
742-
"""Build token exchange request for JWT bearer grant."""
743-
if not self.context.client_info:
744-
raise OAuthFlowError("Missing client info")
745-
if not self.jwt_parameters:
746-
raise OAuthFlowError("Missing JWT parameters")
747-
748-
token_url = self._get_token_endpoint()
749-
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=token_url)
750-
751-
token_data = {
752-
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
753-
"assertion": assertion,
754-
}
755-
756-
if self.context.should_include_resource_param(self.context.protocol_version):
757-
token_data["resource"] = self.context.get_resource_url()
758-
759-
if self.context.client_metadata.scope:
760-
token_data["scope"] = self.context.client_metadata.scope
761-
762-
return httpx.Request(
763-
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
764-
)

0 commit comments

Comments
 (0)