Skip to content

Commit 94850e7

Browse files
authored
implement-rfc-8693-token-exchange-in-mcp-sdk
Add OAuth token exchange support
2 parents b46aac4 + 2daea3f commit 94850e7

File tree

9 files changed

+310
-5
lines changed

9 files changed

+310
-5
lines changed

README.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,11 @@ async def main():
814814
The SDK includes [authorization support](https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization) for connecting to protected MCP servers:
815815

816816
```python
817-
from mcp.client.auth import OAuthClientProvider, TokenStorage
817+
from mcp.client.auth import (
818+
OAuthClientProvider,
819+
TokenExchangeProvider,
820+
TokenStorage,
821+
)
818822
from mcp.client.session import ClientSession
819823
from mcp.client.streamable_http import streamablehttp_client
820824
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
@@ -854,6 +858,20 @@ async def main():
854858
# For machine-to-machine scenarios, use ClientCredentialsProvider
855859
# instead of OAuthClientProvider.
856860

861+
# If you already have a user token from another provider,
862+
# you can exchange it for an MCP token using TokenExchangeProvider.
863+
token_exchange_auth = TokenExchangeProvider(
864+
server_url="https://api.example.com",
865+
client_metadata=OAuthClientMetadata(
866+
client_name="My Client",
867+
redirect_uris=["http://localhost:3000/callback"],
868+
grant_types=["urn:ietf:params:oauth:grant-type:token-exchange"],
869+
response_types=["code"],
870+
),
871+
storage=CustomTokenStorage(),
872+
subject_token_supplier=lambda: "user_token",
873+
)
874+
857875
# Use with streamable HTTP client
858876
async with streamablehttp_client(
859877
"https://api.example.com/mcp", auth=oauth_auth

src/mcp/client/auth.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,3 +678,90 @@ async def async_auth_flow(
678678

679679
if response.status_code == 401:
680680
self._current_tokens = None
681+
682+
683+
class TokenExchangeProvider(ClientCredentialsProvider):
684+
"""OAuth2 token exchange based on RFC 8693."""
685+
686+
def __init__(
687+
self,
688+
server_url: str,
689+
client_metadata: OAuthClientMetadata,
690+
storage: TokenStorage,
691+
subject_token_supplier: Callable[[], Awaitable[str]],
692+
subject_token_type: str = "urn:ietf:params:oauth:token-type:access_token",
693+
actor_token_supplier: Callable[[], Awaitable[str]] | None = None,
694+
actor_token_type: str | None = None,
695+
audience: str | None = None,
696+
resource: str | None = None,
697+
timeout: float = 300.0,
698+
):
699+
super().__init__(server_url, client_metadata, storage, timeout)
700+
self.subject_token_supplier = subject_token_supplier
701+
self.subject_token_type = subject_token_type
702+
self.actor_token_supplier = actor_token_supplier
703+
self.actor_token_type = actor_token_type
704+
self.audience = audience
705+
self.resource = resource
706+
707+
async def _request_token(self) -> None:
708+
if not self._metadata:
709+
self._metadata = await _discover_oauth_metadata(self.server_url)
710+
711+
client_info = await self._get_or_register_client()
712+
713+
if self._metadata and self._metadata.token_endpoint:
714+
token_url = str(self._metadata.token_endpoint)
715+
else:
716+
auth_base_url = _get_authorization_base_url(self.server_url)
717+
token_url = urljoin(auth_base_url, "/token")
718+
719+
subject_token = await self.subject_token_supplier()
720+
actor_token = (
721+
await self.actor_token_supplier() if self.actor_token_supplier else None
722+
)
723+
724+
token_data = {
725+
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
726+
"client_id": client_info.client_id,
727+
"subject_token": subject_token,
728+
"subject_token_type": self.subject_token_type,
729+
}
730+
731+
if client_info.client_secret:
732+
token_data["client_secret"] = client_info.client_secret
733+
734+
if actor_token:
735+
token_data["actor_token"] = actor_token
736+
if self.actor_token_type:
737+
token_data["actor_token_type"] = self.actor_token_type
738+
if self.audience:
739+
token_data["audience"] = self.audience
740+
if self.resource:
741+
token_data["resource"] = self.resource
742+
if self.client_metadata.scope:
743+
token_data["scope"] = self.client_metadata.scope
744+
745+
async with httpx.AsyncClient() as client:
746+
response = await client.post(
747+
token_url,
748+
data=token_data,
749+
headers={"Content-Type": "application/x-www-form-urlencoded"},
750+
timeout=30.0,
751+
)
752+
753+
if response.status_code != 200:
754+
raise Exception(
755+
f"Token request failed: {response.status_code} {response.text}"
756+
)
757+
758+
token_response = OAuthToken.model_validate(response.json())
759+
await self._validate_token_scopes(token_response)
760+
761+
if token_response.expires_in:
762+
self._token_expiry_time = time.time() + token_response.expires_in
763+
else:
764+
self._token_expiry_time = None
765+
766+
await self.storage.set_tokens(token_response)
767+
self._current_tokens = token_response

src/mcp/server/auth/handlers/register.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ async def handle(self, request: Request) -> Response:
7878
valid_sets = [
7979
{"authorization_code", "refresh_token"},
8080
{"client_credentials"},
81+
{"urn:ietf:params:oauth:grant-type:token-exchange"},
8182
]
8283

8384
if grant_types_set not in valid_sets:
@@ -86,7 +87,7 @@ async def handle(self, request: Request) -> Response:
8687
error="invalid_client_metadata",
8788
error_description=(
8889
"grant_types must be authorization_code and refresh_token "
89-
"or client_credentials"
90+
"or client_credentials or token exchange"
9091
),
9192
),
9293
status_code=400,

src/mcp/server/auth/handlers/token.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,39 @@ class ClientCredentialsRequest(BaseModel):
5555
client_secret: str | None = None
5656

5757

58+
class TokenExchangeRequest(BaseModel):
59+
"""RFC 8693 token exchange request."""
60+
61+
grant_type: Literal["urn:ietf:params:oauth:grant-type:token-exchange"]
62+
subject_token: str = Field(..., description="Token to exchange")
63+
subject_token_type: str = Field(..., description="Type of the subject token")
64+
actor_token: str | None = Field(None, description="Optional actor token")
65+
actor_token_type: str | None = Field(
66+
None, description="Type of the actor token if provided"
67+
)
68+
resource: str | None = None
69+
audience: str | None = None
70+
scope: str | None = None
71+
client_id: str
72+
client_secret: str | None = None
73+
74+
5875
class TokenRequest(
5976
RootModel[
6077
Annotated[
61-
AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest,
78+
AuthorizationCodeRequest
79+
| RefreshTokenRequest
80+
| ClientCredentialsRequest
81+
| TokenExchangeRequest,
6282
Field(discriminator="grant_type"),
6383
]
6484
]
6585
):
6686
root: Annotated[
67-
AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest,
87+
AuthorizationCodeRequest
88+
| RefreshTokenRequest
89+
| ClientCredentialsRequest
90+
| TokenExchangeRequest,
6891
Field(discriminator="grant_type"),
6992
]
7093

@@ -232,6 +255,27 @@ async def handle(self, request: Request):
232255
)
233256
)
234257

258+
case TokenExchangeRequest():
259+
scopes = token_request.scope.split(" ") if token_request.scope else []
260+
try:
261+
tokens = await self.provider.exchange_token(
262+
client_info,
263+
token_request.subject_token,
264+
token_request.subject_token_type,
265+
token_request.actor_token,
266+
token_request.actor_token_type,
267+
scopes,
268+
token_request.audience,
269+
token_request.resource,
270+
)
271+
except TokenError as e:
272+
return self.response(
273+
TokenErrorResponse(
274+
error=e.error,
275+
error_description=e.error_description,
276+
)
277+
)
278+
235279
case RefreshTokenRequest():
236280
refresh_token = await self.provider.load_refresh_token(
237281
client_info, token_request.refresh_token

src/mcp/server/auth/provider.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class AuthorizeError(Exception):
8080
"unauthorized_client",
8181
"unsupported_grant_type",
8282
"invalid_scope",
83+
"invalid_target",
8384
]
8485

8586

@@ -253,6 +254,20 @@ async def exchange_client_credentials(
253254
"""Exchange client credentials for an access token."""
254255
...
255256

257+
async def exchange_token(
258+
self,
259+
client: OAuthClientInformationFull,
260+
subject_token: str,
261+
subject_token_type: str,
262+
actor_token: str | None,
263+
actor_token_type: str | None,
264+
scope: list[str] | None,
265+
audience: str | None,
266+
resource: str | None,
267+
) -> OAuthToken:
268+
"""Exchange an external token for an MCP access token."""
269+
...
270+
256271
async def load_access_token(self, token: str) -> AccessTokenT | None:
257272
"""
258273
Loads an access token by its token.

src/mcp/server/auth/routes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def build_metadata(
168168
"authorization_code",
169169
"refresh_token",
170170
"client_credentials",
171+
"urn:ietf:params:oauth:grant-type:token-exchange",
171172
],
172173
token_endpoint_auth_methods_supported=["client_secret_post"],
173174
token_endpoint_auth_signing_alg_values_supported=None,

src/mcp/shared/auth.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class OAuthToken(BaseModel):
1313
expires_in: int | None = None
1414
scope: str | None = None
1515
refresh_token: str | None = None
16+
issued_token_type: str | None = None
1617

1718

1819
class InvalidScopeError(Exception):
@@ -41,7 +42,12 @@ class OAuthClientMetadata(BaseModel):
4142
)
4243
# grant_types: support authorization_code, refresh_token, client_credentials
4344
grant_types: list[
44-
Literal["authorization_code", "refresh_token", "client_credentials"]
45+
Literal[
46+
"authorization_code",
47+
"refresh_token",
48+
"client_credentials",
49+
"urn:ietf:params:oauth:grant-type:token-exchange",
50+
]
4551
] = [
4652
"authorization_code",
4753
"refresh_token",
@@ -121,6 +127,7 @@ class OAuthMetadata(BaseModel):
121127
"authorization_code",
122128
"refresh_token",
123129
"client_credentials",
130+
"urn:ietf:params:oauth:grant-type:token-exchange",
124131
]
125132
]
126133
| None

tests/client/test_auth.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Tests for OAuth client authentication implementation.
33
"""
44

5+
import asyncio
56
import base64
67
import hashlib
78
import time
@@ -15,6 +16,7 @@
1516
from mcp.client.auth import (
1617
ClientCredentialsProvider,
1718
OAuthClientProvider,
19+
TokenExchangeProvider,
1820
_discover_oauth_metadata,
1921
_get_authorization_base_url,
2022
)
@@ -144,6 +146,16 @@ async def client_credentials_provider(client_credentials_metadata, mock_storage)
144146
)
145147

146148

149+
@pytest.fixture
150+
async def token_exchange_provider(client_credentials_metadata, mock_storage):
151+
return TokenExchangeProvider(
152+
server_url="https://api.example.com/v1/mcp",
153+
client_metadata=client_credentials_metadata,
154+
storage=mock_storage,
155+
subject_token_supplier=lambda: asyncio.sleep(0, result="user_token"),
156+
)
157+
158+
147159
class TestOAuthClientProvider:
148160
"""Test OAuth client provider functionality."""
149161

@@ -1064,3 +1076,36 @@ async def test_async_auth_flow(self, client_credentials_provider, oauth_token):
10641076
await auth_flow.asend(mock_response)
10651077
except StopAsyncIteration:
10661078
pass
1079+
1080+
1081+
class TestTokenExchangeProvider:
1082+
@pytest.mark.anyio
1083+
async def test_request_token_success(
1084+
self,
1085+
token_exchange_provider,
1086+
oauth_metadata,
1087+
oauth_client_info,
1088+
oauth_token,
1089+
):
1090+
token_exchange_provider._metadata = oauth_metadata
1091+
token_exchange_provider._client_info = oauth_client_info
1092+
1093+
token_json = oauth_token.model_dump(by_alias=True, mode="json")
1094+
token_json.pop("refresh_token", None)
1095+
1096+
with patch("httpx.AsyncClient") as mock_client_class:
1097+
mock_client = AsyncMock()
1098+
mock_client_class.return_value.__aenter__.return_value = mock_client
1099+
1100+
mock_response = Mock()
1101+
mock_response.status_code = 200
1102+
mock_response.json.return_value = token_json
1103+
mock_client.post.return_value = mock_response
1104+
1105+
await token_exchange_provider.ensure_token()
1106+
1107+
mock_client.post.assert_called_once()
1108+
assert (
1109+
token_exchange_provider._current_tokens.access_token
1110+
== oauth_token.access_token
1111+
)

0 commit comments

Comments
 (0)