Skip to content

Commit aa2d1b9

Browse files
committed
refresh
1 parent d54ba93 commit aa2d1b9

File tree

2 files changed

+18
-115
lines changed

2 files changed

+18
-115
lines changed

tests/unit/test_token_federation.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,21 @@ def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_
124124
mock_parse_jwt.return_value = {"iss": "https://login.microsoftonline.com/tenant"}
125125
mock_is_same_host.return_value = False
126126

127-
# Create a simple credentials provider that returns a fixed token
128-
external_token = "test_token"
129-
creds_provider = SimpleCredentialsProvider(external_token)
127+
# Create a mock credentials provider that can return different tokens
128+
mock_creds_provider = MagicMock()
129+
# Initial token factory
130+
initial_header_factory = MagicMock()
131+
initial_header_factory.return_value = {"Authorization": "Bearer initial_token"}
132+
# Fresh token factory for refresh
133+
fresh_header_factory = MagicMock()
134+
fresh_header_factory.return_value = {"Authorization": "Bearer fresh_token"}
135+
136+
# Configure the mock to return different header factories on consecutive calls
137+
mock_creds_provider.side_effect = [initial_header_factory, fresh_header_factory]
130138

131139
# Set up the token federation provider
132140
federation_provider = DatabricksTokenFederationProvider(
133-
creds_provider, "example.com", "client_id"
141+
mock_creds_provider, "example.com", "client_id"
134142
)
135143

136144
# Mock the token exchange to return a known token
@@ -143,8 +151,8 @@ def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_
143151
headers_factory = federation_provider()
144152
headers = headers_factory()
145153

146-
# Verify the exchange happened
147-
mock_exchange_token.assert_called_with(external_token, "azure")
154+
# Verify the exchange happened with the initial token
155+
mock_exchange_token.assert_called_with("initial_token", "azure")
148156
self.assertEqual(headers["Authorization"], "Bearer exchanged_token_1")
149157

150158
# Reset the mocks to track the next call
@@ -155,7 +163,7 @@ def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_
155163
federation_provider.last_exchanged_token = Token(
156164
"exchanged_token_1", "Bearer", expiry=near_expiry
157165
)
158-
federation_provider.last_external_token = external_token
166+
federation_provider.last_external_token = "initial_token"
159167

160168
# Set up the mock to return a different token for the refresh
161169
mock_exchange_token.return_value = Token(
@@ -165,9 +173,9 @@ def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_
165173
# Make a second call which should trigger refresh
166174
headers = headers_factory()
167175

168-
# Verify the token was exchanged with the SAME external token (current implementation)
169-
# This is different from the JDBC driver approach which gets a fresh token
170-
mock_exchange_token.assert_called_once_with(external_token, "azure")
176+
# Verify a fresh token was requested from the credentials provider
177+
# and the exchange was performed with the fresh token
178+
mock_exchange_token.assert_called_once_with("fresh_token", "azure")
171179

172180
# Verify the headers contain the new token
173181
self.assertEqual(headers["Authorization"], "Bearer exchanged_token_2")

tests/unit/test_token_federation_jdbc.py

Lines changed: 0 additions & 105 deletions
This file was deleted.

0 commit comments

Comments
 (0)