@@ -124,13 +124,21 @@ def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_
124124 mock_parse_jwt .return_value = {"iss" : "https://login.microsoftonline.com/tenant" }
125125 mock_is_same_host .return_value = False
126126
127- # Create a simple credentials provider that returns a fixed token
128- external_token = "test_token"
129- creds_provider = SimpleCredentialsProvider (external_token )
127+ # Create a mock credentials provider that can return different tokens
128+ mock_creds_provider = MagicMock ()
129+ # Initial token factory
130+ initial_header_factory = MagicMock ()
131+ initial_header_factory .return_value = {"Authorization" : "Bearer initial_token" }
132+ # Fresh token factory for refresh
133+ fresh_header_factory = MagicMock ()
134+ fresh_header_factory .return_value = {"Authorization" : "Bearer fresh_token" }
135+
136+ # Configure the mock to return different header factories on consecutive calls
137+ mock_creds_provider .side_effect = [initial_header_factory , fresh_header_factory ]
130138
131139 # Set up the token federation provider
132140 federation_provider = DatabricksTokenFederationProvider (
133- creds_provider , "example.com" , "client_id"
141+ mock_creds_provider , "example.com" , "client_id"
134142 )
135143
136144 # Mock the token exchange to return a known token
@@ -143,8 +151,8 @@ def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_
143151 headers_factory = federation_provider ()
144152 headers = headers_factory ()
145153
146- # Verify the exchange happened
147- mock_exchange_token .assert_called_with (external_token , "azure" )
154+ # Verify the exchange happened with the initial token
155+ mock_exchange_token .assert_called_with ("initial_token" , "azure" )
148156 self .assertEqual (headers ["Authorization" ], "Bearer exchanged_token_1" )
149157
150158 # Reset the mocks to track the next call
@@ -155,7 +163,7 @@ def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_
155163 federation_provider .last_exchanged_token = Token (
156164 "exchanged_token_1" , "Bearer" , expiry = near_expiry
157165 )
158- federation_provider .last_external_token = external_token
166+ federation_provider .last_external_token = "initial_token"
159167
160168 # Set up the mock to return a different token for the refresh
161169 mock_exchange_token .return_value = Token (
@@ -165,9 +173,9 @@ def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_
165173 # Make a second call which should trigger refresh
166174 headers = headers_factory ()
167175
168- # Verify the token was exchanged with the SAME external token (current implementation)
169- # This is different from the JDBC driver approach which gets a fresh token
170- mock_exchange_token .assert_called_once_with (external_token , "azure" )
176+ # Verify a fresh token was requested from the credentials provider
177+ # and the exchange was performed with the fresh token
178+ mock_exchange_token .assert_called_once_with ("fresh_token" , "azure" )
171179
172180 # Verify the headers contain the new token
173181 self .assertEqual (headers ["Authorization" ], "Bearer exchanged_token_2" )
0 commit comments