|
7 | 7 | import httpx |
8 | 8 | import jwt |
9 | 9 | import pytest |
10 | | -from pydantic import AnyUrl, AnyHttpUrl |
| 10 | +from pydantic import AnyHttpUrl, AnyUrl |
11 | 11 |
|
12 | 12 | from mcp.client.auth import OAuthTokenError |
13 | 13 | from mcp.client.auth.extensions.enterprise_managed_auth import ( |
@@ -506,7 +506,10 @@ async def test_perform_authorization_with_valid_tokens(mock_token_storage: Any): |
506 | 506 |
|
507 | 507 |
|
508 | 508 | @pytest.mark.anyio |
509 | | -async def test_exchange_token_with_client_authentication(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): |
| 509 | +async def test_exchange_token_with_client_authentication( |
| 510 | + sample_id_token: str, sample_id_jag: str, |
| 511 | + mock_token_storage: Any |
| 512 | +): |
510 | 513 | """Test token exchange with client authentication.""" |
511 | 514 | from mcp.shared.auth import OAuthClientInformationFull |
512 | 515 |
|
@@ -563,7 +566,10 @@ async def test_exchange_token_with_client_authentication(sample_id_token: str, s |
563 | 566 |
|
564 | 567 |
|
565 | 568 | @pytest.mark.anyio |
566 | | -async def test_exchange_token_with_client_id_only(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): |
| 569 | +async def test_exchange_token_with_client_id_only( |
| 570 | + sample_id_token: str, sample_id_jag: str, |
| 571 | + mock_token_storage: Any |
| 572 | +): |
567 | 573 | """Test token exchange with client_id but no client_secret (covers branch 232->235).""" |
568 | 574 | from mcp.shared.auth import OAuthClientInformationFull |
569 | 575 |
|
@@ -681,7 +687,9 @@ async def test_exchange_token_non_json_error_response(sample_id_token: str, mock |
681 | 687 |
|
682 | 688 |
|
683 | 689 | @pytest.mark.anyio |
684 | | -async def test_exchange_token_warning_for_non_na_token_type(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): |
| 690 | +async def test_exchange_token_warning_for_non_na_token_type( |
| 691 | + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any |
| 692 | +): |
685 | 693 | """Test token exchange logs warning for non-N_A token type.""" |
686 | 694 | token_exchange_params = TokenExchangeParameters.from_id_token( |
687 | 695 | id_token=sample_id_token, |
@@ -718,7 +726,7 @@ async def test_exchange_token_warning_for_non_na_token_type(sample_id_token: str |
718 | 726 | import logging |
719 | 727 |
|
720 | 728 | with patch.object( |
721 | | - logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning" |
| 729 | + logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning" |
722 | 730 | ) as mock_warning: |
723 | 731 | id_jag = await provider.exchange_token_for_id_jag(mock_client) |
724 | 732 | assert id_jag == sample_id_jag |
@@ -972,6 +980,132 @@ async def test_exchange_id_jag_http_error(sample_id_jag: str, mock_token_storage |
972 | 980 | await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) |
973 | 981 |
|
974 | 982 |
|
| 983 | +@pytest.mark.anyio |
| 984 | +async def test_exchange_token_with_client_info_but_no_client_id( |
| 985 | + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any |
| 986 | +): |
| 987 | + """Test token exchange when client_info exists but client_id is None (covers line 231).""" |
| 988 | + from mcp.shared.auth import OAuthClientInformationFull |
| 989 | + |
| 990 | + token_exchange_params = TokenExchangeParameters.from_id_token( |
| 991 | + id_token=sample_id_token, |
| 992 | + mcp_server_auth_issuer="https://auth.mcp-server.example/", |
| 993 | + mcp_server_resource_id="https://mcp-server.example/", |
| 994 | + scope="read write", |
| 995 | + ) |
| 996 | + |
| 997 | + provider = EnterpriseAuthOAuthClientProvider( |
| 998 | + server_url="https://mcp-server.example/", |
| 999 | + client_metadata=OAuthClientMetadata( |
| 1000 | + redirect_uris=[AnyUrl("http://localhost:8080/callback")], |
| 1001 | + client_name="Test Client", |
| 1002 | + ), |
| 1003 | + storage=mock_token_storage, |
| 1004 | + idp_token_endpoint="https://idp.example.com/oauth2/token", |
| 1005 | + token_exchange_params=token_exchange_params, |
| 1006 | + ) |
| 1007 | + |
| 1008 | + # Set client info with client_id=None |
| 1009 | + provider.context.client_info = OAuthClientInformationFull( |
| 1010 | + client_id=None, # This should skip the client_id assignment |
| 1011 | + client_secret="test-secret", |
| 1012 | + redirect_uris=[AnyUrl("http://localhost:8080/callback")], |
| 1013 | + ) |
| 1014 | + |
| 1015 | + # Mock HTTP response |
| 1016 | + mock_response = httpx.Response( |
| 1017 | + status_code=200, |
| 1018 | + json={ |
| 1019 | + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", |
| 1020 | + "access_token": sample_id_jag, |
| 1021 | + "token_type": "N_A", |
| 1022 | + "scope": "read write", |
| 1023 | + "expires_in": 300, |
| 1024 | + }, |
| 1025 | + ) |
| 1026 | + |
| 1027 | + mock_client = Mock(spec=httpx.AsyncClient) |
| 1028 | + mock_client.post = AsyncMock(return_value=mock_response) |
| 1029 | + |
| 1030 | + # Perform token exchange |
| 1031 | + id_jag = await provider.exchange_token_for_id_jag(mock_client) |
| 1032 | + |
| 1033 | + # Verify the ID-JAG was returned |
| 1034 | + assert id_jag == sample_id_jag |
| 1035 | + |
| 1036 | + # Verify client_id was not included (None), but client_secret was included |
| 1037 | + call_args = mock_client.post.call_args |
| 1038 | + assert "client_id" not in call_args[1]["data"] |
| 1039 | + assert call_args[1]["data"]["client_secret"] == "test-secret" |
| 1040 | + |
| 1041 | + |
| 1042 | +@pytest.mark.anyio |
| 1043 | +async def test_exchange_id_jag_with_client_info_but_no_client_id( |
| 1044 | + sample_id_jag: str, mock_token_storage: Any |
| 1045 | +): |
| 1046 | + """Test ID-JAG exchange when client_info exists but client_id is None (covers line 302).""" |
| 1047 | + from pydantic import AnyHttpUrl |
| 1048 | + |
| 1049 | + from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata |
| 1050 | + |
| 1051 | + token_exchange_params = TokenExchangeParameters.from_id_token( |
| 1052 | + id_token="dummy-token", |
| 1053 | + mcp_server_auth_issuer="https://auth.mcp-server.example/", |
| 1054 | + mcp_server_resource_id="https://mcp-server.example/", |
| 1055 | + ) |
| 1056 | + |
| 1057 | + provider = EnterpriseAuthOAuthClientProvider( |
| 1058 | + server_url="https://mcp-server.example/", |
| 1059 | + client_metadata=OAuthClientMetadata( |
| 1060 | + redirect_uris=[AnyUrl("http://localhost:8080/callback")], |
| 1061 | + ), |
| 1062 | + storage=mock_token_storage, |
| 1063 | + idp_token_endpoint="https://idp.example.com/oauth2/token", |
| 1064 | + token_exchange_params=token_exchange_params, |
| 1065 | + ) |
| 1066 | + |
| 1067 | + # Set up OAuth metadata |
| 1068 | + provider.context.oauth_metadata = OAuthMetadata( |
| 1069 | + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), |
| 1070 | + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), |
| 1071 | + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), |
| 1072 | + ) |
| 1073 | + |
| 1074 | + # Set client info with client_id=None |
| 1075 | + provider.context.client_info = OAuthClientInformationFull( |
| 1076 | + client_id=None, # This should skip the client_id assignment |
| 1077 | + client_secret="test-secret", |
| 1078 | + redirect_uris=[AnyUrl("http://localhost:8080/callback")], |
| 1079 | + ) |
| 1080 | + |
| 1081 | + # Mock HTTP response |
| 1082 | + mock_response = httpx.Response( |
| 1083 | + status_code=200, |
| 1084 | + json={ |
| 1085 | + "token_type": "Bearer", |
| 1086 | + "access_token": "mcp-access-token-12345", |
| 1087 | + "expires_in": 3600, |
| 1088 | + "scope": "read write", |
| 1089 | + }, |
| 1090 | + ) |
| 1091 | + |
| 1092 | + mock_client = Mock(spec=httpx.AsyncClient) |
| 1093 | + mock_client.post = AsyncMock(return_value=mock_response) |
| 1094 | + |
| 1095 | + # Perform JWT bearer grant |
| 1096 | + token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) |
| 1097 | + |
| 1098 | + # Verify |
| 1099 | + assert token.access_token == "mcp-access-token-12345" |
| 1100 | + assert token.token_type == "Bearer" |
| 1101 | + assert token.expires_in == 3600 |
| 1102 | + |
| 1103 | + # Verify client_id was not included (None), but client_secret was included |
| 1104 | + call_args = mock_client.post.call_args |
| 1105 | + assert "client_id" not in call_args[1]["data"] |
| 1106 | + assert call_args[1]["data"]["client_secret"] == "test-secret" |
| 1107 | + |
| 1108 | + |
975 | 1109 | def test_validate_token_exchange_params_missing_audience(): |
976 | 1110 | """Test validation fails for missing audience.""" |
977 | 1111 | params = TokenExchangeParameters( |
|
0 commit comments