Skip to content

Commit 5fb2c0f

Browse files
committed
- Resolved pyright checks error.
1 parent db2f02c commit 5fb2c0f

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

src/mcp/client/auth/extensions/enterprise_managed_auth.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
import logging
9-
from typing import Any
9+
from collections.abc import Awaitable, Callable
1010

1111
import httpx
1212
from pydantic import BaseModel, Field
@@ -166,8 +166,8 @@ def __init__(
166166
storage: TokenStorage,
167167
idp_token_endpoint: str,
168168
token_exchange_params: TokenExchangeParameters,
169-
redirect_handler: Any = None,
170-
callback_handler: Any = None,
169+
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
170+
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
171171
timeout: float = 300.0,
172172
) -> None:
173173
"""
@@ -228,7 +228,8 @@ async def exchange_token_for_id_jag(
228228

229229
# Add client authentication if needed
230230
if self.context.client_info:
231-
token_data["client_id"] = self.context.client_info.client_id
231+
if self.context.client_info.client_id is not None:
232+
token_data["client_id"] = self.context.client_info.client_id
232233
if self.context.client_info.client_secret is not None:
233234
token_data["client_secret"] = self.context.client_info.client_secret
234235

@@ -240,11 +241,11 @@ async def exchange_token_for_id_jag(
240241
)
241242

242243
if response.status_code != 200:
243-
error_data: dict[str, str] = (
244+
error_data: dict[str, object] = (
244245
response.json() if response.headers.get("content-type", "").startswith("application/json") else {}
245246
)
246-
error: str = error_data.get("error", "unknown_error")
247-
error_description: str = error_data.get("error_description", "Token exchange failed")
247+
error = str(error_data.get("error", "unknown_error"))
248+
error_description = str(error_data.get("error_description", "Token exchange failed"))
248249
raise OAuthTokenError(f"Token exchange failed: {error} - {error_description}")
249250

250251
# Parse response
@@ -298,7 +299,8 @@ async def exchange_id_jag_for_access_token(
298299

299300
# Add client authentication
300301
if self.context.client_info:
301-
token_data["client_id"] = self.context.client_info.client_id
302+
if self.context.client_info.client_id is not None:
303+
token_data["client_id"] = self.context.client_info.client_id
302304
if self.context.client_info.client_secret is not None:
303305
token_data["client_secret"] = self.context.client_info.client_secret
304306

@@ -310,11 +312,11 @@ async def exchange_id_jag_for_access_token(
310312
)
311313

312314
if response.status_code != 200:
313-
error_data: dict[str, str] = (
315+
error_data: dict[str, object] = (
314316
response.json() if response.headers.get("content-type", "").startswith("application/json") else {}
315317
)
316-
error: str = error_data.get("error", "unknown_error")
317-
error_description: str = error_data.get("error_description", "JWT bearer grant failed")
318+
error = str(error_data.get("error", "unknown_error"))
319+
error_description = str(error_data.get("error_description", "JWT bearer grant failed"))
318320
raise OAuthTokenError(f"JWT bearer grant failed: {error} - {error_description}")
319321

320322
# Parse OAuth token response

0 commit comments

Comments
 (0)