From 23746118951e37b4ddad7bf971c252acc4051319 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Wed, 28 May 2025 16:58:06 +0000 Subject: [PATCH 01/49] First draft --- .../core/oauth/RefreshableTokenSource.java | 96 ++++++++++++++- .../com/databricks/sdk/core/oauth/Token.java | 9 ++ .../oauth/RefreshableTokenSourceTest.java | 114 ++++++++++++++++++ 3 files changed, 216 insertions(+), 3 deletions(-) create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index e93f91ae5..efdc29469 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -5,10 +5,12 @@ import com.databricks.sdk.core.http.FormRequest; import com.databricks.sdk.core.http.HttpClient; import com.databricks.sdk.core.http.Request; +import java.time.Duration; import java.time.LocalDateTime; import java.time.temporal.ChronoUnit; import java.util.Base64; import java.util.Map; +import java.util.concurrent.CompletableFuture; import org.apache.http.HttpHeaders; /** @@ -18,7 +20,20 @@ * at least 10 seconds until expiry). If not, refresh() is called first to refresh the token. */ public abstract class RefreshableTokenSource implements TokenSource { + + private enum TokenState { + FRESH, // The token is valid. + STALE, // The token is valid but will expire soon. + EXPIRED // The token has expired and cannot be used without refreshing. + } + + private static final Duration DEFAULT_STALE_DURATION = Duration.ofMinutes(3); + protected Token token; + private boolean asyncEnabled = false; + private Duration staleDuration = DEFAULT_STALE_DURATION; + private boolean refreshInProgress = false; + private boolean lastRefreshSucceeded = true; public RefreshableTokenSource() {} @@ -26,6 +41,11 @@ public RefreshableTokenSource(Token token) { this.token = token; } + public RefreshableTokenSource enableAsyncRefresh(boolean enabled) { + this.asyncEnabled = enabled; + return this; + } + /** * Helper method implementing OAuth token refresh. * @@ -80,9 +100,79 @@ protected static Token retrieveToken( protected abstract Token refresh(); public synchronized Token getToken() { - if (token == null || !token.isValid()) { - token = refresh(); + if (!asyncEnabled) { + return getTokenBlocking(); + } + return getTokenAsync(); + } + + protected TokenState getTokenState() { + if (token == null) { + return TokenState.EXPIRED; + } + Duration lifeTime = token.getLifetime(); + if (lifeTime.compareTo(Duration.ZERO) <= 0) { + return TokenState.EXPIRED; + } + if (lifeTime.compareTo(staleDuration) <= 0) { + return TokenState.STALE; + } + return TokenState.FRESH; + } + + protected synchronized Token getTokenBlocking() { + refreshInProgress = false; + TokenState state = getTokenState(); + if (state != TokenState.EXPIRED) { + return token; + } + try { + Token newToken = refresh(); + token = newToken; + return newToken; + } catch (Exception e) { + lastRefreshSucceeded = false; + throw e; + } + } + + protected Token getTokenAsync() { + TokenState state; + Token currentToken; + synchronized (this) { + state = getTokenState(); + currentToken = token; + } + if (state == TokenState.FRESH) { + return currentToken; + } + if (state == TokenState.STALE) { + triggerAsyncRefresh(); + return token; + } else { + return getTokenBlocking(); + } + } + + protected synchronized void triggerAsyncRefresh() { + if (!refreshInProgress && lastRefreshSucceeded) { + refreshInProgress = true; + CompletableFuture.runAsync( + () -> { + try { + Token newToken = refresh(); + synchronized (this) { + token = newToken; + lastRefreshSucceeded = true; + refreshInProgress = false; + } + } catch (Exception e) { + synchronized (this) { + lastRefreshSucceeded = false; + refreshInProgress = false; + } + } + }); } - return token; } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java index f0fd72f68..5aaa012da 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java @@ -4,6 +4,7 @@ import com.databricks.sdk.core.utils.SystemClockSupplier; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import java.time.Duration; import java.time.LocalDateTime; import java.time.temporal.ChronoUnit; import java.util.Objects; @@ -91,4 +92,12 @@ public String getRefreshToken() { public String getAccessToken() { return accessToken; } + + public LocalDateTime getExpiry() { + return this.expiry; + } + + public Duration getLifetime() { + return Duration.between(LocalDateTime.now(clockSupplier.getClock()), this.expiry); + } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java new file mode 100644 index 000000000..e543117c9 --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java @@ -0,0 +1,114 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; + +import com.databricks.sdk.core.utils.FakeClockSupplier; +import java.time.Instant; +import java.time.LocalDateTime; +import java.time.ZoneId; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.Test; + +public class RefreshableTokenSourceTest { + @Test + void testAsyncRefresh() throws Exception { + // Set up a fake clock and initial token that is about to become stale + Instant now = Instant.parse("2023-10-18T12:00:00.00Z"); + ZoneId zoneId = ZoneId.of("UTC"); + FakeClockSupplier fakeClock = new FakeClockSupplier(now, zoneId); + LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); + + // Token expires in 2 minutes (less than the default stale duration of 3 minutes) + Token initialToken = + new Token("initial-token", "Bearer", null, currentTime.plusMinutes(2), fakeClock); + + CountDownLatch refreshCalled = new CountDownLatch(1); + Token refreshedToken = + new Token("refreshed-token", "Bearer", null, currentTime.plusMinutes(10), fakeClock); + + // Subclass with a refresh() that signals when called + RefreshableTokenSource source = + new RefreshableTokenSource(initialToken) { + @Override + protected Token refresh() { + refreshCalled.countDown(); + return refreshedToken; + } + }.enableAsyncRefresh(true); + + // First call should return the stale token and trigger async refresh + Token token1 = source.getToken(); + assertEquals( + "initial-token", token1.getAccessToken(), "Should return the stale token immediately"); + + // Wait for async refresh to complete (with timeout) + boolean refreshed = refreshCalled.await(2, TimeUnit.SECONDS); + assertTrue(refreshed, "Async refresh should have been triggered"); + + // After refresh, getToken should return the refreshed token + // (may need to wait a bit for the token to be set) + for (int i = 0; i < 10; i++) { + Token token2 = source.getToken(); + if ("refreshed-token".equals(token2.getAccessToken())) { + return; // Success + } + Thread.sleep(100); + } + fail("Token was not refreshed asynchronously"); + } + + @Test + void testSyncRefreshWhenExpired() { + // Set up a fake clock and an expired token + Instant now = Instant.parse("2023-10-18T12:00:00.00Z"); + ZoneId zoneId = ZoneId.of("UTC"); + FakeClockSupplier fakeClock = new FakeClockSupplier(now, zoneId); + LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); + + Token expiredToken = + new Token("expired-token", "Bearer", null, currentTime.minusMinutes(1), fakeClock); + Token refreshedToken = + new Token("refreshed-token", "Bearer", null, currentTime.plusMinutes(10), fakeClock); + + final boolean[] refreshCalled = {false}; + RefreshableTokenSource source = + new RefreshableTokenSource(expiredToken) { + @Override + protected Token refresh() { + refreshCalled[0] = true; + return refreshedToken; + } + }.enableAsyncRefresh(true); + + // Should call refresh synchronously and return the refreshed token + Token token = source.getToken(); + assertTrue(refreshCalled[0], "refresh() should be called synchronously for expired token"); + assertEquals("refreshed-token", token.getAccessToken(), "Should return the refreshed token"); + } + + @Test + void testNoRefreshWhenTokenIsFresh() { + // Set up a fake clock and a fresh token + Instant now = Instant.parse("2023-10-18T12:00:00.00Z"); + ZoneId zoneId = ZoneId.of("UTC"); + FakeClockSupplier fakeClock = new FakeClockSupplier(now, zoneId); + LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); + + Token freshToken = + new Token("fresh-token", "Bearer", null, currentTime.plusMinutes(10), fakeClock); + + RefreshableTokenSource source = + new RefreshableTokenSource(freshToken) { + @Override + protected Token refresh() { + fail("refresh() should not be called when token is fresh"); + return null; + } + }; + + // Should return the fresh token and never call refresh + Token token = source.getToken(); + assertEquals("fresh-token", token.getAccessToken(), "Should return the fresh token"); + } +} From 5148136c960d522333dcf1456343af993d3670f7 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Wed, 28 May 2025 17:24:32 +0000 Subject: [PATCH 02/49] Update test --- .../oauth/RefreshableTokenSourceTest.java | 121 ++++++++++++++---- 1 file changed, 97 insertions(+), 24 deletions(-) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java index e543117c9..15c7768dd 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java @@ -11,21 +11,29 @@ import org.junit.jupiter.api.Test; public class RefreshableTokenSourceTest { + private static class MutableBoolean { + boolean value = false; + } + private static final String TOKEN_TYPE = "Bearer"; + private static final String INITIAL_TOKEN = "initial-token"; + private static final String REFRESHED_TOKEN = "refreshed-token"; + private static final String EXPIRED_TOKEN = "expired-token"; + private static final String FRESH_TOKEN = "fresh-token"; + private static final Instant FIXED_INSTANT = Instant.parse("2023-10-18T12:00:00.00Z"); + private static final ZoneId ZONE_ID = ZoneId.of("UTC"); + @Test void testAsyncRefresh() throws Exception { - // Set up a fake clock and initial token that is about to become stale - Instant now = Instant.parse("2023-10-18T12:00:00.00Z"); - ZoneId zoneId = ZoneId.of("UTC"); - FakeClockSupplier fakeClock = new FakeClockSupplier(now, zoneId); + FakeClockSupplier fakeClock = new FakeClockSupplier(FIXED_INSTANT, ZONE_ID); LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); // Token expires in 2 minutes (less than the default stale duration of 3 minutes) Token initialToken = - new Token("initial-token", "Bearer", null, currentTime.plusMinutes(2), fakeClock); + new Token(INITIAL_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(2), fakeClock); CountDownLatch refreshCalled = new CountDownLatch(1); Token refreshedToken = - new Token("refreshed-token", "Bearer", null, currentTime.plusMinutes(10), fakeClock); + new Token(REFRESHED_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(10), fakeClock); // Subclass with a refresh() that signals when called RefreshableTokenSource source = @@ -40,7 +48,7 @@ protected Token refresh() { // First call should return the stale token and trigger async refresh Token token1 = source.getToken(); assertEquals( - "initial-token", token1.getAccessToken(), "Should return the stale token immediately"); + INITIAL_TOKEN, token1.getAccessToken(), "Should return the stale token immediately"); // Wait for async refresh to complete (with timeout) boolean refreshed = refreshCalled.await(2, TimeUnit.SECONDS); @@ -50,7 +58,7 @@ protected Token refresh() { // (may need to wait a bit for the token to be set) for (int i = 0; i < 10; i++) { Token token2 = source.getToken(); - if ("refreshed-token".equals(token2.getAccessToken())) { + if (REFRESHED_TOKEN.equals(token2.getAccessToken())) { return; // Success } Thread.sleep(100); @@ -60,43 +68,37 @@ protected Token refresh() { @Test void testSyncRefreshWhenExpired() { - // Set up a fake clock and an expired token - Instant now = Instant.parse("2023-10-18T12:00:00.00Z"); - ZoneId zoneId = ZoneId.of("UTC"); - FakeClockSupplier fakeClock = new FakeClockSupplier(now, zoneId); + FakeClockSupplier fakeClock = new FakeClockSupplier(FIXED_INSTANT, ZONE_ID); LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); Token expiredToken = - new Token("expired-token", "Bearer", null, currentTime.minusMinutes(1), fakeClock); + new Token(EXPIRED_TOKEN, TOKEN_TYPE, null, currentTime.minusMinutes(1), fakeClock); Token refreshedToken = - new Token("refreshed-token", "Bearer", null, currentTime.plusMinutes(10), fakeClock); + new Token(REFRESHED_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(10), fakeClock); - final boolean[] refreshCalled = {false}; + MutableBoolean refreshCalled = new MutableBoolean(); RefreshableTokenSource source = new RefreshableTokenSource(expiredToken) { @Override protected Token refresh() { - refreshCalled[0] = true; + refreshCalled.value = true; return refreshedToken; } }.enableAsyncRefresh(true); // Should call refresh synchronously and return the refreshed token Token token = source.getToken(); - assertTrue(refreshCalled[0], "refresh() should be called synchronously for expired token"); - assertEquals("refreshed-token", token.getAccessToken(), "Should return the refreshed token"); + assertTrue(refreshCalled.value, "refresh() should be called synchronously for expired token"); + assertEquals(REFRESHED_TOKEN, token.getAccessToken(), "Should return the refreshed token"); } @Test void testNoRefreshWhenTokenIsFresh() { - // Set up a fake clock and a fresh token - Instant now = Instant.parse("2023-10-18T12:00:00.00Z"); - ZoneId zoneId = ZoneId.of("UTC"); - FakeClockSupplier fakeClock = new FakeClockSupplier(now, zoneId); + FakeClockSupplier fakeClock = new FakeClockSupplier(FIXED_INSTANT, ZONE_ID); LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); Token freshToken = - new Token("fresh-token", "Bearer", null, currentTime.plusMinutes(10), fakeClock); + new Token(FRESH_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(10), fakeClock); RefreshableTokenSource source = new RefreshableTokenSource(freshToken) { @@ -109,6 +111,77 @@ protected Token refresh() { // Should return the fresh token and never call refresh Token token = source.getToken(); - assertEquals("fresh-token", token.getAccessToken(), "Should return the fresh token"); + assertEquals(FRESH_TOKEN, token.getAccessToken(), "Should return the fresh token"); + } + + @Test + void testRefreshThrowsException() { + FakeClockSupplier fakeClock = new FakeClockSupplier(FIXED_INSTANT, ZONE_ID); + LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); + + Token expiredToken = + new Token(EXPIRED_TOKEN, TOKEN_TYPE, null, currentTime.minusMinutes(1), fakeClock); + + RefreshableTokenSource source = + new RefreshableTokenSource(expiredToken) { + @Override + protected Token refresh() { + throw new RuntimeException("Simulated refresh failure"); + } + }; + + RuntimeException thrown = assertThrows( + RuntimeException.class, + source::getToken, + "getToken() should propagate exception from refresh() when token is expired"); + assertEquals("Simulated refresh failure", thrown.getMessage()); + } + + @Test + void testFailedAsyncRefreshForcesNextRefreshToBeSynchronous() throws Exception { + FakeClockSupplier fakeClock = new FakeClockSupplier(FIXED_INSTANT, ZONE_ID); + LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); + // Token is stale (expires in 2 minutes) + Token staleToken = new Token(INITIAL_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(2), fakeClock); + + class TestSource extends RefreshableTokenSource { + int refreshCallCount = 0; + boolean failFirst = true; + TestSource(Token token) { super(token); } + @Override + protected Token refresh() { + refreshCallCount++; + if (failFirst) { + failFirst = false; + throw new RuntimeException("Simulated async failure"); + } + throw new RuntimeException("Simulated sync failure"); + } + } + + TestSource source = new TestSource(staleToken); + source.enableAsyncRefresh(true); + + // First call triggers async refresh, which fails + source.getToken(); + Thread.sleep(300); // Give time for async refresh to run + assertEquals(1, source.refreshCallCount, "refresh() should have been called once (async, failed)"); + + // Token is still stale, so next call should NOT trigger another refresh + source.getToken(); + Thread.sleep(200); + assertEquals(1, source.refreshCallCount, "refresh() should NOT be called again while stale after async failure"); + + // Advance the clock so the token is now expired + source.token = new Token( + INITIAL_TOKEN, TOKEN_TYPE, null, currentTime.minusMinutes(1), fakeClock); + + // Now getToken() should call refresh synchronously and throw + RuntimeException thrown = assertThrows( + RuntimeException.class, + source::getToken, + "getToken() should call refresh synchronously and propagate exception when expired"); + assertEquals("Simulated sync failure", thrown.getMessage()); + assertEquals(2, source.refreshCallCount, "refresh() should have been called synchronously after expiry"); } } From 78c03a8a9fea6a9ea816ea91fb73988de90a5e07 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 30 May 2025 09:32:17 +0000 Subject: [PATCH 03/49] Clean up unit tests --- .../core/oauth/RefreshableTokenSource.java | 66 +++++- .../oauth/RefreshableTokenSourceTest.java | 220 ++++++++---------- 2 files changed, 159 insertions(+), 127 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index efdc29469..74097906c 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -16,31 +16,54 @@ /** * An OAuth TokenSource which can be refreshed. * - *

Calls to getToken() will first check if the token is still valid (currently defined by having - * at least 10 seconds until expiry). If not, refresh() is called first to refresh the token. + *

This class supports both synchronous and asynchronous token refresh. When async is enabled, + * stale tokens will trigger a background refresh, while expired tokens will block until a new token + * is fetched. */ public abstract class RefreshableTokenSource implements TokenSource { + /** + * Enum representing the state of the token. FRESH: Token is valid and not close to expiry. STALE: + * Token is valid but will expire soon. EXPIRED: Token has expired and must be refreshed. + */ private enum TokenState { - FRESH, // The token is valid. - STALE, // The token is valid but will expire soon. - EXPIRED // The token has expired and cannot be used without refreshing. + FRESH, + STALE, + EXPIRED } + // Default duration before expiry to consider a token as 'stale'. private static final Duration DEFAULT_STALE_DURATION = Duration.ofMinutes(3); + /** The current OAuth token. May be null if not yet fetched. */ protected Token token; + /** Whether asynchronous refresh is enabled. */ private boolean asyncEnabled = false; + /** Duration before expiry to consider a token as 'stale'. */ private Duration staleDuration = DEFAULT_STALE_DURATION; + /** Whether a refresh is currently in progress (for async refresh). */ private boolean refreshInProgress = false; + /** Whether the last refresh attempt succeeded. */ private boolean lastRefreshSucceeded = true; + /** Default constructor. */ public RefreshableTokenSource() {} + /** + * Constructor with initial token. + * + * @param token The initial token to use. + */ public RefreshableTokenSource(Token token) { this.token = token; } + /** + * Enable or disable asynchronous token refresh. + * + * @param enabled true to enable async refresh, false to disable + * @return this instance for chaining + */ public RefreshableTokenSource enableAsyncRefresh(boolean enabled) { this.asyncEnabled = enabled; return this; @@ -49,6 +72,7 @@ public RefreshableTokenSource enableAsyncRefresh(boolean enabled) { /** * Helper method implementing OAuth token refresh. * + * @param hc The HTTP client to use for the request. * @param clientId The client ID to authenticate with. * @param clientSecret The client secret to authenticate with. * @param tokenUrl The authorization URL for fetching tokens. @@ -56,6 +80,7 @@ public RefreshableTokenSource enableAsyncRefresh(boolean enabled) { * @param headers Additional headers. * @param position The position of the authentication parameters in the request. * @return The newly fetched Token. + * @throws DatabricksException if the refresh fails */ protected static Token retrieveToken( HttpClient hc, @@ -99,6 +124,12 @@ protected static Token retrieveToken( protected abstract Token refresh(); + /** + * Get the current token, refreshing if necessary. If async refresh is enabled, may return a stale + * token while a refresh is in progress. + * + * @return The current valid token + */ public synchronized Token getToken() { if (!asyncEnabled) { return getTokenBlocking(); @@ -106,6 +137,11 @@ public synchronized Token getToken() { return getTokenAsync(); } + /** + * Determine the state of the current token (fresh, stale, or expired). + * + * @return The token state + */ protected TokenState getTokenState() { if (token == null) { return TokenState.EXPIRED; @@ -120,6 +156,11 @@ protected TokenState getTokenState() { return TokenState.FRESH; } + /** + * Get the current token, blocking to refresh if expired. + * + * @return The current valid token + */ protected synchronized Token getTokenBlocking() { refreshInProgress = false; TokenState state = getTokenState(); @@ -129,6 +170,7 @@ protected synchronized Token getTokenBlocking() { try { Token newToken = refresh(); token = newToken; + lastRefreshSucceeded = true; return newToken; } catch (Exception e) { lastRefreshSucceeded = false; @@ -136,6 +178,12 @@ protected synchronized Token getTokenBlocking() { } } + /** + * Get the current token, possibly triggering an async refresh if stale. If the token is expired, + * blocks to refresh. + * + * @return The current valid or stale token + */ protected Token getTokenAsync() { TokenState state; Token currentToken; @@ -147,23 +195,29 @@ protected Token getTokenAsync() { return currentToken; } if (state == TokenState.STALE) { + // Trigger background refresh, return current token triggerAsyncRefresh(); return token; } else { + // Token is expired, block to refresh return getTokenBlocking(); } } + /** + * Trigger an asynchronous refresh of the token if not already in progress and last refresh + * succeeded. + */ protected synchronized void triggerAsyncRefresh() { if (!refreshInProgress && lastRefreshSucceeded) { refreshInProgress = true; CompletableFuture.runAsync( () -> { try { + // Attempt to refresh the token in the background Token newToken = refresh(); synchronized (this) { token = newToken; - lastRefreshSucceeded = true; refreshInProgress = false; } } catch (Exception e) { diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java index 15c7768dd..be177cc03 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java @@ -8,34 +8,41 @@ import java.time.ZoneId; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.stream.Stream; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; public class RefreshableTokenSourceTest { - private static class MutableBoolean { - boolean value = false; - } private static final String TOKEN_TYPE = "Bearer"; private static final String INITIAL_TOKEN = "initial-token"; private static final String REFRESHED_TOKEN = "refreshed-token"; - private static final String EXPIRED_TOKEN = "expired-token"; - private static final String FRESH_TOKEN = "fresh-token"; private static final Instant FIXED_INSTANT = Instant.parse("2023-10-18T12:00:00.00Z"); private static final ZoneId ZONE_ID = ZoneId.of("UTC"); - @Test - void testAsyncRefresh() throws Exception { + @ParameterizedTest(name = "{0}") + @MethodSource("provideAsyncRefreshScenarios") + void testAsyncRefreshParametrized( + String testName, + long minutesUntilExpiry, + boolean asyncEnabled, + boolean expectRefresh, + boolean expectRefreshedToken) + throws Exception { FakeClockSupplier fakeClock = new FakeClockSupplier(FIXED_INSTANT, ZONE_ID); LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); - - // Token expires in 2 minutes (less than the default stale duration of 3 minutes) Token initialToken = - new Token(INITIAL_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(2), fakeClock); - - CountDownLatch refreshCalled = new CountDownLatch(1); + new Token( + INITIAL_TOKEN, + TOKEN_TYPE, + null, + currentTime.plusMinutes(minutesUntilExpiry), + fakeClock); Token refreshedToken = new Token(REFRESHED_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(10), fakeClock); + CountDownLatch refreshCalled = new CountDownLatch(1); - // Subclass with a refresh() that signals when called RefreshableTokenSource source = new RefreshableTokenSource(initialToken) { @Override @@ -43,119 +50,74 @@ protected Token refresh() { refreshCalled.countDown(); return refreshedToken; } - }.enableAsyncRefresh(true); + }.enableAsyncRefresh(asyncEnabled); - // First call should return the stale token and trigger async refresh Token token1 = source.getToken(); - assertEquals( - INITIAL_TOKEN, token1.getAccessToken(), "Should return the stale token immediately"); - - // Wait for async refresh to complete (with timeout) - boolean refreshed = refreshCalled.await(2, TimeUnit.SECONDS); - assertTrue(refreshed, "Async refresh should have been triggered"); - - // After refresh, getToken should return the refreshed token - // (may need to wait a bit for the token to be set) - for (int i = 0; i < 10; i++) { - Token token2 = source.getToken(); - if (REFRESHED_TOKEN.equals(token2.getAccessToken())) { - return; // Success + if (expectRefresh) { + // Wait for async refresh if enabled, otherwise refresh is sync + boolean refreshed = refreshCalled.await(2, TimeUnit.SECONDS); + assertTrue(refreshed, "Refresh should have been triggered"); + } else { + assertEquals(1, refreshCalled.getCount(), "Refresh should NOT have been triggered"); + } + if (expectRefreshedToken) { + // Wait for async to complete if needed + for (int i = 0; i < 10; i++) { + Token token2 = source.getToken(); + if (REFRESHED_TOKEN.equals(token2.getAccessToken())) { + return; // Success + } + Thread.sleep(100); } - Thread.sleep(100); + fail("Token was not refreshed as expected"); + } else { + assertEquals(INITIAL_TOKEN, token1.getAccessToken(), "Should return the initial token"); } - fail("Token was not refreshed asynchronously"); } - @Test - void testSyncRefreshWhenExpired() { - FakeClockSupplier fakeClock = new FakeClockSupplier(FIXED_INSTANT, ZONE_ID); - LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); - - Token expiredToken = - new Token(EXPIRED_TOKEN, TOKEN_TYPE, null, currentTime.minusMinutes(1), fakeClock); - Token refreshedToken = - new Token(REFRESHED_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(10), fakeClock); - - MutableBoolean refreshCalled = new MutableBoolean(); - RefreshableTokenSource source = - new RefreshableTokenSource(expiredToken) { - @Override - protected Token refresh() { - refreshCalled.value = true; - return refreshedToken; - } - }.enableAsyncRefresh(true); - - // Should call refresh synchronously and return the refreshed token - Token token = source.getToken(); - assertTrue(refreshCalled.value, "refresh() should be called synchronously for expired token"); - assertEquals(REFRESHED_TOKEN, token.getAccessToken(), "Should return the refreshed token"); + private static Stream provideAsyncRefreshScenarios() { + return Stream.of( + Arguments.of("Fresh token, async enabled", 10, true, false, false), + Arguments.of("Stale token, async enabled", 1, true, true, true), + Arguments.of("Expired token, async enabled", -1, true, true, true), + Arguments.of("Fresh token, async disabled", 10, false, false, false), + Arguments.of("Stale token, async disabled", 1, false, false, false), + Arguments.of("Expired token, async disabled", -1, false, true, true)); } + /** + * This test verifies that if an asynchronous token refresh fails, the next refresh attempt is forced to be synchronous. + * It ensures that after an async failure, the system does not repeatedly attempt async refreshes while the token is stale, + * and only performs a synchronous refresh when the token is expired. After a successful sync refresh, async refreshes resume as normal. + */ @Test - void testNoRefreshWhenTokenIsFresh() { + void testAsyncRefreshFailureFallback() throws Exception { FakeClockSupplier fakeClock = new FakeClockSupplier(FIXED_INSTANT, ZONE_ID); LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); - - Token freshToken = - new Token(FRESH_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(10), fakeClock); - - RefreshableTokenSource source = - new RefreshableTokenSource(freshToken) { - @Override - protected Token refresh() { - fail("refresh() should not be called when token is fresh"); - return null; - } - }; - - // Should return the fresh token and never call refresh - Token token = source.getToken(); - assertEquals(FRESH_TOKEN, token.getAccessToken(), "Should return the fresh token"); - } - - @Test - void testRefreshThrowsException() { - FakeClockSupplier fakeClock = new FakeClockSupplier(FIXED_INSTANT, ZONE_ID); - LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); - - Token expiredToken = - new Token(EXPIRED_TOKEN, TOKEN_TYPE, null, currentTime.minusMinutes(1), fakeClock); - - RefreshableTokenSource source = - new RefreshableTokenSource(expiredToken) { - @Override - protected Token refresh() { - throw new RuntimeException("Simulated refresh failure"); - } - }; - - RuntimeException thrown = assertThrows( - RuntimeException.class, - source::getToken, - "getToken() should propagate exception from refresh() when token is expired"); - assertEquals("Simulated refresh failure", thrown.getMessage()); - } - - @Test - void testFailedAsyncRefreshForcesNextRefreshToBeSynchronous() throws Exception { - FakeClockSupplier fakeClock = new FakeClockSupplier(FIXED_INSTANT, ZONE_ID); - LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); - // Token is stale (expires in 2 minutes) - Token staleToken = new Token(INITIAL_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(2), fakeClock); + Token staleToken = + new Token(INITIAL_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(2), fakeClock); class TestSource extends RefreshableTokenSource { int refreshCallCount = 0; - boolean failFirst = true; - TestSource(Token token) { super(token); } + boolean isFirstRefresh = true; + + TestSource(Token token) { + super(token); + } + @Override protected Token refresh() { refreshCallCount++; - if (failFirst) { - failFirst = false; + if (isFirstRefresh) { + isFirstRefresh = false; throw new RuntimeException("Simulated async failure"); } - throw new RuntimeException("Simulated sync failure"); + return new Token( + REFRESHED_TOKEN, + TOKEN_TYPE, + null, + LocalDateTime.now(fakeClock.getClock()).plusMinutes(10), + fakeClock); } } @@ -164,24 +126,40 @@ protected Token refresh() { // First call triggers async refresh, which fails source.getToken(); - Thread.sleep(300); // Give time for async refresh to run - assertEquals(1, source.refreshCallCount, "refresh() should have been called once (async, failed)"); + Thread.sleep(300); + assertEquals( + 1, source.refreshCallCount, "refresh() should have been called once (async, failed)"); - // Token is still stale, so next call should NOT trigger another refresh + // Token is still stale, so next call should NOT trigger another refresh since the last refresh + // failed source.getToken(); Thread.sleep(200); - assertEquals(1, source.refreshCallCount, "refresh() should NOT be called again while stale after async failure"); + assertEquals( + 1, + source.refreshCallCount, + "refresh() should NOT be called again while stale after async failure"); // Advance the clock so the token is now expired - source.token = new Token( - INITIAL_TOKEN, TOKEN_TYPE, null, currentTime.minusMinutes(1), fakeClock); - - // Now getToken() should call refresh synchronously and throw - RuntimeException thrown = assertThrows( - RuntimeException.class, - source::getToken, - "getToken() should call refresh synchronously and propagate exception when expired"); - assertEquals("Simulated sync failure", thrown.getMessage()); - assertEquals(2, source.refreshCallCount, "refresh() should have been called synchronously after expiry"); + source.token = + new Token(INITIAL_TOKEN, TOKEN_TYPE, null, currentTime.minusMinutes(1), fakeClock); + + // Now getToken() should call refresh synchronously and return the refreshed token + Token token = source.getToken(); + assertEquals( + REFRESHED_TOKEN, + token.getAccessToken(), + "Should return the refreshed token after sync refresh"); + assertEquals( + 2, source.refreshCallCount, "refresh() should have been called synchronously after expiry"); + + // Make the token stale again and trigger async refresh since the last sync refresh succeeded + source.token = + new Token(INITIAL_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(2), fakeClock); + source.getToken(); + Thread.sleep(300); + assertEquals( + 3, + source.refreshCallCount, + "refresh() should have been called again asynchronously after making token stale"); } } From 5a37c59833ed2261251b03b26605ea1fdada55e4 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 30 May 2025 12:43:04 +0000 Subject: [PATCH 04/49] Clean up comments --- .../core/oauth/RefreshableTokenSource.java | 125 ++++++++++-------- .../oauth/RefreshableTokenSourceTest.java | 8 +- 2 files changed, 77 insertions(+), 56 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index 74097906c..e1168222f 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -24,7 +24,8 @@ public abstract class RefreshableTokenSource implements TokenSource { /** * Enum representing the state of the token. FRESH: Token is valid and not close to expiry. STALE: - * Token is valid but will expire soon. EXPIRED: Token has expired and must be refreshed. + * Token is valid but will expire soon - an async refresh will be triggered if enabled. EXPIRED: + * Token has expired and must be refreshed using a blocking call. */ private enum TokenState { FRESH, @@ -70,63 +71,21 @@ public RefreshableTokenSource enableAsyncRefresh(boolean enabled) { } /** - * Helper method implementing OAuth token refresh. + * Refresh the OAuth token. Subclasses must implement this to define how the token is refreshed. * - * @param hc The HTTP client to use for the request. - * @param clientId The client ID to authenticate with. - * @param clientSecret The client secret to authenticate with. - * @param tokenUrl The authorization URL for fetching tokens. - * @param params Additional request parameters. - * @param headers Additional headers. - * @param position The position of the authentication parameters in the request. - * @return The newly fetched Token. - * @throws DatabricksException if the refresh fails + *

This method may throw an exception if the token cannot be refreshed. The specific exception + * type depends on the implementation. + * + * @return The newly refreshed Token. */ - protected static Token retrieveToken( - HttpClient hc, - String clientId, - String clientSecret, - String tokenUrl, - Map params, - Map headers, - AuthParameterPosition position) { - switch (position) { - case BODY: - if (clientId != null) { - params.put("client_id", clientId); - } - if (clientSecret != null) { - params.put("client_secret", clientSecret); - } - break; - case HEADER: - String authHeaderValue = - "Basic " - + Base64.getEncoder().encodeToString((clientId + ":" + clientSecret).getBytes()); - headers.put(HttpHeaders.AUTHORIZATION, authHeaderValue); - break; - } - headers.put("Content-Type", "application/x-www-form-urlencoded"); - Request req = new Request("POST", tokenUrl, FormRequest.wrapValuesInList(params)); - req.withHeaders(headers); - try { - ApiClient apiClient = new ApiClient.Builder().withHttpClient(hc).build(); - OAuthResponse resp = apiClient.execute(req, OAuthResponse.class); - if (resp.getErrorCode() != null) { - throw new IllegalArgumentException(resp.getErrorCode() + ": " + resp.getErrorSummary()); - } - LocalDateTime expiry = LocalDateTime.now().plus(resp.getExpiresIn(), ChronoUnit.SECONDS); - return new Token(resp.getAccessToken(), resp.getTokenType(), resp.getRefreshToken(), expiry); - } catch (Exception e) { - throw new DatabricksException("Failed to refresh credentials: " + e.getMessage(), e); - } - } - protected abstract Token refresh(); /** - * Get the current token, refreshing if necessary. If async refresh is enabled, may return a stale - * token while a refresh is in progress. + * Gets the current token, refreshing if necessary. If async refresh is enabled, may return a + * stale token while a refresh is in progress. + * + *

This method may throw an exception if the token cannot be refreshed, depending on the + * implementation of {@link #refresh()}. * * @return The current valid token */ @@ -159,6 +118,9 @@ protected TokenState getTokenState() { /** * Get the current token, blocking to refresh if expired. * + *

This method may throw an exception if the token cannot be refreshed, depending on the + * implementation of {@link #refresh()}. + * * @return The current valid token */ protected synchronized Token getTokenBlocking() { @@ -182,6 +144,9 @@ protected synchronized Token getTokenBlocking() { * Get the current token, possibly triggering an async refresh if stale. If the token is expired, * blocks to refresh. * + *

This method may throw an exception if the token cannot be refreshed, depending on the + * implementation of {@link #refresh()}. + * * @return The current valid or stale token */ protected Token getTokenAsync() { @@ -229,4 +194,58 @@ protected synchronized void triggerAsyncRefresh() { }); } } + + /** + * Helper method implementing OAuth token refresh. + * + * @param hc The HTTP client to use for the request. + * @param clientId The client ID to authenticate with. + * @param clientSecret The client secret to authenticate with. + * @param tokenUrl The authorization URL for fetching tokens. + * @param params Additional request parameters. + * @param headers Additional headers. + * @param position The position of the authentication parameters in the request. + * @return The newly fetched Token. + * @throws DatabricksException if the refresh fails + * @throws IllegalArgumentException if the OAuth response contains an error + */ + protected static Token retrieveToken( + HttpClient hc, + String clientId, + String clientSecret, + String tokenUrl, + Map params, + Map headers, + AuthParameterPosition position) { + switch (position) { + case BODY: + if (clientId != null) { + params.put("client_id", clientId); + } + if (clientSecret != null) { + params.put("client_secret", clientSecret); + } + break; + case HEADER: + String authHeaderValue = + "Basic " + + Base64.getEncoder().encodeToString((clientId + ":" + clientSecret).getBytes()); + headers.put(HttpHeaders.AUTHORIZATION, authHeaderValue); + break; + } + headers.put("Content-Type", "application/x-www-form-urlencoded"); + Request req = new Request("POST", tokenUrl, FormRequest.wrapValuesInList(params)); + req.withHeaders(headers); + try { + ApiClient apiClient = new ApiClient.Builder().withHttpClient(hc).build(); + OAuthResponse resp = apiClient.execute(req, OAuthResponse.class); + if (resp.getErrorCode() != null) { + throw new IllegalArgumentException(resp.getErrorCode() + ": " + resp.getErrorSummary()); + } + LocalDateTime expiry = LocalDateTime.now().plus(resp.getExpiresIn(), ChronoUnit.SECONDS); + return new Token(resp.getAccessToken(), resp.getTokenType(), resp.getRefreshToken(), expiry); + } catch (Exception e) { + throw new DatabricksException("Failed to refresh credentials: " + e.getMessage(), e); + } + } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java index be177cc03..e4da02959 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java @@ -86,9 +86,11 @@ private static Stream provideAsyncRefreshScenarios() { } /** - * This test verifies that if an asynchronous token refresh fails, the next refresh attempt is forced to be synchronous. - * It ensures that after an async failure, the system does not repeatedly attempt async refreshes while the token is stale, - * and only performs a synchronous refresh when the token is expired. After a successful sync refresh, async refreshes resume as normal. + * This test verifies that if an asynchronous token refresh fails, the next refresh attempt is + * forced to be synchronous. It ensures that after an async failure, the system does not + * repeatedly attempt async refreshes while the token is stale, and only performs a synchronous + * refresh when the token is expired. After a successful sync refresh, async refreshes resume as + * normal. */ @Test void testAsyncRefreshFailureFallback() throws Exception { From f5e4f8abb57df8916e7c2ae317f2447c8f2a0a56 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 30 May 2025 12:50:04 +0000 Subject: [PATCH 05/49] Add Javadoc to Token.java --- .../com/databricks/sdk/core/oauth/Token.java | 37 +++++++++++++++++++ .../databricks/sdk/core/oauth/TokenTest.java | 20 ++++++++++ 2 files changed, 57 insertions(+) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java index 5aaa012da..ee5df2f5f 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java @@ -66,6 +66,12 @@ public Token( this.clockSupplier = clockSupplier; } + /** + * Checks if the token is expired. Tokens are considered expired 40 seconds before their actual + * expiry time to account for Azure Databricks rejecting tokens that expire in 30 seconds or less. + * + * @return true if the token is expired or about to expire, false otherwise + */ public boolean isExpired() { if (expiry == null) { return false; @@ -77,26 +83,57 @@ public boolean isExpired() { return potentiallyExpired.isBefore(now); } + /** + * Checks if the token is valid. A token is valid if it has a non-null access token and is not + * expired. + * + * @return true if the token is valid, false otherwise + */ public boolean isValid() { return accessToken != null && !isExpired(); } + /** + * Returns the type of the token (e.g., "Bearer"). + * + * @return the token type + */ public String getTokenType() { return tokenType; } + /** + * Returns the refresh token, if available. May be null for non-refreshable tokens. + * + * @return the refresh token or null + */ public String getRefreshToken() { return refreshToken; } + /** + * Returns the access token string. + * + * @return the access token + */ public String getAccessToken() { return accessToken; } + /** + * Returns the expiry time of the token as a LocalDateTime. + * + * @return the expiry time + */ public LocalDateTime getExpiry() { return this.expiry; } + /** + * Returns the remaining lifetime of the token as a Duration. + * + * @return the duration between now and the token's expiry + */ public Duration getLifetime() { return Duration.between(LocalDateTime.now(clockSupplier.getClock()), this.expiry); } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java index 2d87a32c2..91d687b2f 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java @@ -71,4 +71,24 @@ void expiredToken() { assertTrue(token.isExpired()); assertFalse(token.isValid()); } + + @Test + void tokenLifetimeInFuture() { + Token token = + new Token(accessToken, tokenType, currentLocalDateTime.plusMinutes(10), fakeClockSupplier); + assertEquals(java.time.Duration.ofMinutes(10), token.getLifetime()); + } + + @Test + void tokenLifetimeExpired() { + Token token = + new Token(accessToken, tokenType, currentLocalDateTime.minusMinutes(2), fakeClockSupplier); + assertEquals(java.time.Duration.ofMinutes(-2), token.getLifetime()); + } + + @Test + void tokenLifetimeZero() { + Token token = new Token(accessToken, tokenType, currentLocalDateTime, fakeClockSupplier); + assertEquals(java.time.Duration.ZERO, token.getLifetime()); + } } From 3e651092f882b5f6ca41a2b354248fa42f21dfc2 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 30 May 2025 14:00:19 +0000 Subject: [PATCH 06/49] Add a token expiry buffer field --- .../sdk/core/oauth/RefreshableTokenSource.java | 18 ++++++++++++++++-- .../core/oauth/RefreshableTokenSourceTest.java | 4 ++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index e1168222f..291da9e84 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -42,6 +42,8 @@ private enum TokenState { private boolean asyncEnabled = false; /** Duration before expiry to consider a token as 'stale'. */ private Duration staleDuration = DEFAULT_STALE_DURATION; + /** Additional buffer before expiry to consider a token as expired. */ + private Duration expiryBuffer = Duration.ZERO; /** Whether a refresh is currently in progress (for async refresh). */ private boolean refreshInProgress = false; /** Whether the last refresh attempt succeeded. */ @@ -65,11 +67,23 @@ public RefreshableTokenSource(Token token) { * @param enabled true to enable async refresh, false to disable * @return this instance for chaining */ - public RefreshableTokenSource enableAsyncRefresh(boolean enabled) { + public RefreshableTokenSource withAsyncRefresh(boolean enabled) { this.asyncEnabled = enabled; return this; } + /** + * Set the expiry buffer. If the token's lifetime is less than this buffer, it is considered + * expired. + * + * @param buffer the expiry buffer duration + * @return this instance for chaining + */ + public RefreshableTokenSource withExpiryBuffer(Duration buffer) { + this.expiryBuffer = buffer; + return this; + } + /** * Refresh the OAuth token. Subclasses must implement this to define how the token is refreshed. * @@ -106,7 +120,7 @@ protected TokenState getTokenState() { return TokenState.EXPIRED; } Duration lifeTime = token.getLifetime(); - if (lifeTime.compareTo(Duration.ZERO) <= 0) { + if (lifeTime.compareTo(expiryBuffer) <= 0) { return TokenState.EXPIRED; } if (lifeTime.compareTo(staleDuration) <= 0) { diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java index e4da02959..8d991e003 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java @@ -50,7 +50,7 @@ protected Token refresh() { refreshCalled.countDown(); return refreshedToken; } - }.enableAsyncRefresh(asyncEnabled); + }.withAsyncRefresh(asyncEnabled); Token token1 = source.getToken(); if (expectRefresh) { @@ -124,7 +124,7 @@ protected Token refresh() { } TestSource source = new TestSource(staleToken); - source.enableAsyncRefresh(true); + source.withAsyncRefresh(true); // First call triggers async refresh, which fails source.getToken(); From dfce414f347dbd7bd995012eb94eda1107907099 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Mon, 2 Jun 2025 15:39:23 +0000 Subject: [PATCH 07/49] Fix for comments --- .../core/oauth/RefreshableTokenSource.java | 50 ++++++++----------- .../com/databricks/sdk/core/oauth/Token.java | 28 ----------- .../core/oauth/DataPlaneTokenSourceTest.java | 1 - .../oauth/DatabricksOAuthTokenSourceTest.java | 1 - .../sdk/core/oauth/FileTokenCacheTest.java | 18 ------- .../databricks/sdk/core/oauth/TokenTest.java | 26 ---------- 6 files changed, 22 insertions(+), 102 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index 291da9e84..5270bf6cd 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -36,20 +36,20 @@ private enum TokenState { // Default duration before expiry to consider a token as 'stale'. private static final Duration DEFAULT_STALE_DURATION = Duration.ofMinutes(3); - /** The current OAuth token. May be null if not yet fetched. */ + // The current OAuth token. May be null if not yet fetched. protected Token token; - /** Whether asynchronous refresh is enabled. */ + // Whether asynchronous refresh is enabled. private boolean asyncEnabled = false; - /** Duration before expiry to consider a token as 'stale'. */ + // Duration before expiry to consider a token as 'stale'. private Duration staleDuration = DEFAULT_STALE_DURATION; - /** Additional buffer before expiry to consider a token as expired. */ - private Duration expiryBuffer = Duration.ZERO; - /** Whether a refresh is currently in progress (for async refresh). */ + // Additional buffer before expiry to consider a token as expired. + private Duration expiryBuffer = Duration.ofSeconds(40); + // Whether a refresh is currently in progress (for async refresh). private boolean refreshInProgress = false; - /** Whether the last refresh attempt succeeded. */ + // Whether the last refresh attempt succeeded. private boolean lastRefreshSucceeded = true; - /** Default constructor. */ + /** Constructs a new {@code RefreshableTokenSource} with no initial token. */ public RefreshableTokenSource() {} /** @@ -103,7 +103,7 @@ public RefreshableTokenSource withExpiryBuffer(Duration buffer) { * * @return The current valid token */ - public synchronized Token getToken() { + public Token getToken() { if (!asyncEnabled) { return getTokenBlocking(); } @@ -143,15 +143,10 @@ protected synchronized Token getTokenBlocking() { if (state != TokenState.EXPIRED) { return token; } - try { - Token newToken = refresh(); - token = newToken; - lastRefreshSucceeded = true; - return newToken; - } catch (Exception e) { - lastRefreshSucceeded = false; - throw e; - } + lastRefreshSucceeded = false; + token = refresh(); // May throw an exception + lastRefreshSucceeded = true; + return token; } /** @@ -170,16 +165,15 @@ protected Token getTokenAsync() { state = getTokenState(); currentToken = token; } - if (state == TokenState.FRESH) { - return currentToken; - } - if (state == TokenState.STALE) { - // Trigger background refresh, return current token - triggerAsyncRefresh(); - return token; - } else { - // Token is expired, block to refresh - return getTokenBlocking(); + switch (state) { + case FRESH: + return currentToken; + case STALE: + triggerAsyncRefresh(); + return currentToken; + case EXPIRED: + default: + return getTokenBlocking(); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java index ee5df2f5f..59afddd84 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java @@ -6,7 +6,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import java.time.Duration; import java.time.LocalDateTime; -import java.time.temporal.ChronoUnit; import java.util.Objects; public class Token { @@ -66,33 +65,6 @@ public Token( this.clockSupplier = clockSupplier; } - /** - * Checks if the token is expired. Tokens are considered expired 40 seconds before their actual - * expiry time to account for Azure Databricks rejecting tokens that expire in 30 seconds or less. - * - * @return true if the token is expired or about to expire, false otherwise - */ - public boolean isExpired() { - if (expiry == null) { - return false; - } - // Azure Databricks rejects tokens that expire in 30 seconds or less, - // so we refresh the token 40 seconds before it expires. - LocalDateTime potentiallyExpired = expiry.minus(40, ChronoUnit.SECONDS); - LocalDateTime now = LocalDateTime.now(clockSupplier.getClock()); - return potentiallyExpired.isBefore(now); - } - - /** - * Checks if the token is valid. A token is valid if it has a non-null access token and is not - * expired. - * - * @return true if the token is valid, false otherwise - */ - public boolean isValid() { - return accessToken != null && !isExpired(); - } - /** * Returns the type of the token (e.g., "Bearer"). * diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java index 5887c4ee1..d50452b2f 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java @@ -196,7 +196,6 @@ void testDataPlaneTokenSource( assertEquals(expectedToken.getAccessToken(), token.getAccessToken()); assertEquals(expectedToken.getTokenType(), token.getTokenType()); assertEquals(expectedToken.getRefreshToken(), token.getRefreshToken()); - assertTrue(token.isValid()); } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java index 8217179f2..ee226cd42 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java @@ -330,7 +330,6 @@ void testTokenSource(TestCase testCase) { assertEquals(TOKEN, token.getAccessToken()); assertEquals(TOKEN_TYPE, token.getTokenType()); assertEquals(REFRESH_TOKEN, token.getRefreshToken()); - assertFalse(token.isExpired()); // Verify correct audience was used verify(testCase.idTokenSource, atLeastOnce()).getIDToken(testCase.expectedAudience); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java index ede6cfd11..fb89e5b81 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java @@ -54,24 +54,6 @@ void testSaveAndLoadToken() { assertEquals("access-token", loadedToken.getAccessToken()); assertEquals("Bearer", loadedToken.getTokenType()); assertEquals("refresh-token", loadedToken.getRefreshToken()); - assertFalse(loadedToken.isExpired(), "Token should not be expired"); - } - - @Test - void testTokenExpiry() { - // Create an expired token - LocalDateTime pastTime = LocalDateTime.now().minusHours(1); - Token expiredToken = new Token("access-token", "Bearer", "refresh-token", pastTime); - - // Verify it's marked as expired - assertTrue(expiredToken.isExpired(), "Token should be expired"); - - // Create a valid token - LocalDateTime futureTime = LocalDateTime.now().plusMinutes(30); - Token validToken = new Token("access-token", "Bearer", "refresh-token", futureTime); - - // Verify it's not marked as expired - assertFalse(validToken.isExpired(), "Token should not be expired"); } @Test diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java index 91d687b2f..ed0e4b3bd 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java @@ -30,7 +30,6 @@ void createNonRefreshableToken() { assertEquals(accessToken, token.getAccessToken()); assertEquals(tokenType, token.getTokenType()); assertNull(token.getRefreshToken()); - assertTrue(token.isValid()); } @Test @@ -45,31 +44,6 @@ void createRefreshableToken() { assertEquals(accessToken, token.getAccessToken()); assertEquals(tokenType, token.getTokenType()); assertEquals(refreshToken, token.getRefreshToken()); - assertTrue(token.isValid()); - } - - @Test - void tokenExpiryMoreThan40Seconds() { - Token token = - new Token(accessToken, tokenType, currentLocalDateTime.plusSeconds(50), fakeClockSupplier); - assertFalse(token.isExpired()); - assertTrue(token.isValid()); - } - - @Test - void tokenExpiryLessThan40Seconds() { - Token token = - new Token(accessToken, tokenType, currentLocalDateTime.plusSeconds(30), fakeClockSupplier); - assertTrue(token.isExpired()); - assertFalse(token.isValid()); - } - - @Test - void expiredToken() { - Token token = - new Token(accessToken, tokenType, currentLocalDateTime.minusSeconds(10), fakeClockSupplier); - assertTrue(token.isExpired()); - assertFalse(token.isValid()); } @Test From 5bd4215a686793a2eb438a6049c829882a6f4490 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 3 Jun 2025 12:01:48 +0000 Subject: [PATCH 08/49] Update tests --- .../core/oauth/RefreshableTokenSource.java | 24 ++++- .../com/databricks/sdk/core/oauth/Token.java | 34 +------ .../oauth/RefreshableTokenSourceTest.java | 90 +++++++------------ .../databricks/sdk/core/oauth/TokenTest.java | 43 +-------- 4 files changed, 58 insertions(+), 133 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index 5270bf6cd..50e596348 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -5,6 +5,8 @@ import com.databricks.sdk.core.http.FormRequest; import com.databricks.sdk.core.http.HttpClient; import com.databricks.sdk.core.http.Request; +import com.databricks.sdk.core.utils.ClockSupplier; +import com.databricks.sdk.core.utils.SystemClockSupplier; import java.time.Duration; import java.time.LocalDateTime; import java.time.temporal.ChronoUnit; @@ -48,6 +50,8 @@ private enum TokenState { private boolean refreshInProgress = false; // Whether the last refresh attempt succeeded. private boolean lastRefreshSucceeded = true; + // Clock supplier for current time, for testing purposes. + private ClockSupplier clockSupplier = new SystemClockSupplier(); /** Constructs a new {@code RefreshableTokenSource} with no initial token. */ public RefreshableTokenSource() {} @@ -61,6 +65,17 @@ public RefreshableTokenSource(Token token) { this.token = token; } + /** + * Set the clock supplier for current time. + * + * @param clockSupplier The clock supplier to use. + * @return this instance for chaining + */ + public RefreshableTokenSource withClockSupplier(ClockSupplier clockSupplier) { + this.clockSupplier = clockSupplier; + return this; + } + /** * Enable or disable asynchronous token refresh. * @@ -119,7 +134,8 @@ protected TokenState getTokenState() { if (token == null) { return TokenState.EXPIRED; } - Duration lifeTime = token.getLifetime(); + Duration lifeTime = + Duration.between(LocalDateTime.now(clockSupplier.getClock()), token.getExpiry()); if (lifeTime.compareTo(expiryBuffer) <= 0) { return TokenState.EXPIRED; } @@ -138,7 +154,6 @@ protected TokenState getTokenState() { * @return The current valid token */ protected synchronized Token getTokenBlocking() { - refreshInProgress = false; TokenState state = getTokenState(); if (state != TokenState.EXPIRED) { return token; @@ -160,7 +175,7 @@ protected synchronized Token getTokenBlocking() { */ protected Token getTokenAsync() { TokenState state; - Token currentToken; + Token currentToken = token; synchronized (this) { state = getTokenState(); currentToken = token; @@ -172,8 +187,9 @@ protected Token getTokenAsync() { triggerAsyncRefresh(); return currentToken; case EXPIRED: - default: return getTokenBlocking(); + default: + throw new IllegalStateException("Invalid token state: " + state); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java index 59afddd84..296809f61 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java @@ -1,10 +1,7 @@ package com.databricks.sdk.core.oauth; -import com.databricks.sdk.core.utils.ClockSupplier; -import com.databricks.sdk.core.utils.SystemClockSupplier; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import java.time.Duration; import java.time.LocalDateTime; import java.util.Objects; @@ -24,17 +21,9 @@ public class Token { */ @JsonProperty private LocalDateTime expiry; - private final ClockSupplier clockSupplier; - /** Constructor for non-refreshable tokens (e.g. M2M). */ public Token(String accessToken, String tokenType, LocalDateTime expiry) { - this(accessToken, tokenType, null, expiry, new SystemClockSupplier()); - } - - /** Constructor for non-refreshable tokens (e.g. M2M) with ClockSupplier */ - public Token( - String accessToken, String tokenType, LocalDateTime expiry, ClockSupplier clockSupplier) { - this(accessToken, tokenType, null, expiry, clockSupplier); + this(accessToken, tokenType, null, expiry); } /** Constructor for refreshable tokens. */ @@ -44,25 +33,13 @@ public Token( @JsonProperty("tokenType") String tokenType, @JsonProperty("refreshToken") String refreshToken, @JsonProperty("expiry") LocalDateTime expiry) { - this(accessToken, tokenType, refreshToken, expiry, new SystemClockSupplier()); - } - - /** Constructor for refreshable tokens with ClockSupplier. */ - public Token( - String accessToken, - String tokenType, - String refreshToken, - LocalDateTime expiry, - ClockSupplier clockSupplier) { Objects.requireNonNull(accessToken, "accessToken must be defined"); Objects.requireNonNull(tokenType, "tokenType must be defined"); Objects.requireNonNull(expiry, "expiry must be defined"); - Objects.requireNonNull(clockSupplier, "clockSupplier must be defined"); this.accessToken = accessToken; this.tokenType = tokenType; this.refreshToken = refreshToken; this.expiry = expiry; - this.clockSupplier = clockSupplier; } /** @@ -100,13 +77,4 @@ public String getAccessToken() { public LocalDateTime getExpiry() { return this.expiry; } - - /** - * Returns the remaining lifetime of the token as a Duration. - * - * @return the duration between now and the token's expiry - */ - public Duration getLifetime() { - return Duration.between(LocalDateTime.now(clockSupplier.getClock()), this.expiry); - } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java index 8d991e003..aece009f8 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java @@ -2,10 +2,7 @@ import static org.junit.jupiter.api.Assertions.*; -import com.databricks.sdk.core.utils.FakeClockSupplier; -import java.time.Instant; import java.time.LocalDateTime; -import java.time.ZoneId; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.stream.Stream; @@ -17,9 +14,20 @@ public class RefreshableTokenSourceTest { private static final String TOKEN_TYPE = "Bearer"; private static final String INITIAL_TOKEN = "initial-token"; - private static final String REFRESHED_TOKEN = "refreshed-token"; - private static final Instant FIXED_INSTANT = Instant.parse("2023-10-18T12:00:00.00Z"); - private static final ZoneId ZONE_ID = ZoneId.of("UTC"); + private static final String REFRESH_TOKEN = "refreshed-token"; + private static final long FRESH_TIME_MINUTES = 10; + private static final long STALE_TIME_MINUTES = 1; + private static final long EXPIRED_TIME_MINUTES = -1; + + private static Stream provideAsyncRefreshScenarios() { + return Stream.of( + Arguments.of("Fresh token, async enabled", FRESH_TIME_MINUTES, true, false, false), + Arguments.of("Stale token, async enabled", STALE_TIME_MINUTES, true, true, false), + Arguments.of("Expired token, async enabled", EXPIRED_TIME_MINUTES, true, true, true), + Arguments.of("Fresh token, async disabled", FRESH_TIME_MINUTES, false, false, false), + Arguments.of("Stale token, async disabled", STALE_TIME_MINUTES, false, false, false), + Arguments.of("Expired token, async disabled", EXPIRED_TIME_MINUTES, false, true, true)); + } @ParameterizedTest(name = "{0}") @MethodSource("provideAsyncRefreshScenarios") @@ -30,17 +38,12 @@ void testAsyncRefreshParametrized( boolean expectRefresh, boolean expectRefreshedToken) throws Exception { - FakeClockSupplier fakeClock = new FakeClockSupplier(FIXED_INSTANT, ZONE_ID); - LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); + Token initialToken = new Token( - INITIAL_TOKEN, - TOKEN_TYPE, - null, - currentTime.plusMinutes(minutesUntilExpiry), - fakeClock); + INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(minutesUntilExpiry)); Token refreshedToken = - new Token(REFRESHED_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(10), fakeClock); + new Token(REFRESH_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(10)); CountDownLatch refreshCalled = new CountDownLatch(1); RefreshableTokenSource source = @@ -48,43 +51,27 @@ void testAsyncRefreshParametrized( @Override protected Token refresh() { refreshCalled.countDown(); + try { + Thread.sleep(500); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } return refreshedToken; } }.withAsyncRefresh(asyncEnabled); - Token token1 = source.getToken(); - if (expectRefresh) { - // Wait for async refresh if enabled, otherwise refresh is sync - boolean refreshed = refreshCalled.await(2, TimeUnit.SECONDS); - assertTrue(refreshed, "Refresh should have been triggered"); - } else { - assertEquals(1, refreshCalled.getCount(), "Refresh should NOT have been triggered"); - } + Token token = source.getToken(); + + boolean refreshed = refreshCalled.await(1, TimeUnit.SECONDS); + assertEquals(expectRefresh, refreshed, "Refresh should have been triggered"); + if (expectRefreshedToken) { - // Wait for async to complete if needed - for (int i = 0; i < 10; i++) { - Token token2 = source.getToken(); - if (REFRESHED_TOKEN.equals(token2.getAccessToken())) { - return; // Success - } - Thread.sleep(100); - } - fail("Token was not refreshed as expected"); + assertEquals(REFRESH_TOKEN, token.getAccessToken(), "Token was not refreshed as expected"); } else { - assertEquals(INITIAL_TOKEN, token1.getAccessToken(), "Should return the initial token"); + assertEquals(INITIAL_TOKEN, token.getAccessToken(), "Should return the initial token"); } } - private static Stream provideAsyncRefreshScenarios() { - return Stream.of( - Arguments.of("Fresh token, async enabled", 10, true, false, false), - Arguments.of("Stale token, async enabled", 1, true, true, true), - Arguments.of("Expired token, async enabled", -1, true, true, true), - Arguments.of("Fresh token, async disabled", 10, false, false, false), - Arguments.of("Stale token, async disabled", 1, false, false, false), - Arguments.of("Expired token, async disabled", -1, false, true, true)); - } - /** * This test verifies that if an asynchronous token refresh fails, the next refresh attempt is * forced to be synchronous. It ensures that after an async failure, the system does not @@ -94,10 +81,8 @@ private static Stream provideAsyncRefreshScenarios() { */ @Test void testAsyncRefreshFailureFallback() throws Exception { - FakeClockSupplier fakeClock = new FakeClockSupplier(FIXED_INSTANT, ZONE_ID); - LocalDateTime currentTime = LocalDateTime.now(fakeClock.getClock()); Token staleToken = - new Token(INITIAL_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(2), fakeClock); + new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(2)); class TestSource extends RefreshableTokenSource { int refreshCallCount = 0; @@ -114,12 +99,7 @@ protected Token refresh() { isFirstRefresh = false; throw new RuntimeException("Simulated async failure"); } - return new Token( - REFRESHED_TOKEN, - TOKEN_TYPE, - null, - LocalDateTime.now(fakeClock.getClock()).plusMinutes(10), - fakeClock); + return new Token(REFRESH_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(10)); } } @@ -142,21 +122,19 @@ protected Token refresh() { "refresh() should NOT be called again while stale after async failure"); // Advance the clock so the token is now expired - source.token = - new Token(INITIAL_TOKEN, TOKEN_TYPE, null, currentTime.minusMinutes(1), fakeClock); + source.token = new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().minusMinutes(1)); // Now getToken() should call refresh synchronously and return the refreshed token Token token = source.getToken(); assertEquals( - REFRESHED_TOKEN, + REFRESH_TOKEN, token.getAccessToken(), "Should return the refreshed token after sync refresh"); assertEquals( 2, source.refreshCallCount, "refresh() should have been called synchronously after expiry"); // Make the token stale again and trigger async refresh since the last sync refresh succeeded - source.token = - new Token(INITIAL_TOKEN, TOKEN_TYPE, null, currentTime.plusMinutes(2), fakeClock); + source.token = new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(2)); source.getToken(); Thread.sleep(300); assertEquals( diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java index ed0e4b3bd..d82dfab82 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java @@ -2,10 +2,7 @@ import static org.junit.jupiter.api.Assertions.*; -import com.databricks.sdk.core.utils.FakeClockSupplier; -import java.time.Instant; import java.time.LocalDateTime; -import java.time.ZoneId; import org.junit.jupiter.api.Test; class TokenTest { @@ -13,20 +10,11 @@ class TokenTest { private static final String accessToken = "testAccessToken"; private static final String refreshToken = "testRefreshToken"; private static final String tokenType = "testTokenType"; - private final LocalDateTime currentLocalDateTime; - private final FakeClockSupplier fakeClockSupplier; - - TokenTest() { - Instant instant = Instant.parse("2023-10-18T12:00:00.00Z"); - ZoneId zoneId = ZoneId.of("UTC"); - fakeClockSupplier = new FakeClockSupplier(instant, zoneId); - currentLocalDateTime = LocalDateTime.now(fakeClockSupplier.getClock()); - } + private static final LocalDateTime currentLocalDateTime = LocalDateTime.now(); @Test void createNonRefreshableToken() { - Token token = - new Token(accessToken, tokenType, currentLocalDateTime.plusMinutes(5), fakeClockSupplier); + Token token = new Token(accessToken, tokenType, currentLocalDateTime.plusMinutes(5)); assertEquals(accessToken, token.getAccessToken()); assertEquals(tokenType, token.getTokenType()); assertNull(token.getRefreshToken()); @@ -35,34 +23,9 @@ void createNonRefreshableToken() { @Test void createRefreshableToken() { Token token = - new Token( - accessToken, - tokenType, - refreshToken, - currentLocalDateTime.plusMinutes(5), - fakeClockSupplier); + new Token(accessToken, tokenType, refreshToken, currentLocalDateTime.plusMinutes(5)); assertEquals(accessToken, token.getAccessToken()); assertEquals(tokenType, token.getTokenType()); assertEquals(refreshToken, token.getRefreshToken()); } - - @Test - void tokenLifetimeInFuture() { - Token token = - new Token(accessToken, tokenType, currentLocalDateTime.plusMinutes(10), fakeClockSupplier); - assertEquals(java.time.Duration.ofMinutes(10), token.getLifetime()); - } - - @Test - void tokenLifetimeExpired() { - Token token = - new Token(accessToken, tokenType, currentLocalDateTime.minusMinutes(2), fakeClockSupplier); - assertEquals(java.time.Duration.ofMinutes(-2), token.getLifetime()); - } - - @Test - void tokenLifetimeZero() { - Token token = new Token(accessToken, tokenType, currentLocalDateTime, fakeClockSupplier); - assertEquals(java.time.Duration.ZERO, token.getLifetime()); - } } From 93e0bafcb95c16892cf65bf5a7db2cc69287ce76 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 3 Jun 2025 12:14:19 +0000 Subject: [PATCH 09/49] Clean up tests --- .../oauth/RefreshableTokenSourceTest.java | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java index aece009f8..11a791bc9 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java @@ -21,12 +21,16 @@ public class RefreshableTokenSourceTest { private static Stream provideAsyncRefreshScenarios() { return Stream.of( - Arguments.of("Fresh token, async enabled", FRESH_TIME_MINUTES, true, false, false), - Arguments.of("Stale token, async enabled", STALE_TIME_MINUTES, true, true, false), - Arguments.of("Expired token, async enabled", EXPIRED_TIME_MINUTES, true, true, true), - Arguments.of("Fresh token, async disabled", FRESH_TIME_MINUTES, false, false, false), - Arguments.of("Stale token, async disabled", STALE_TIME_MINUTES, false, false, false), - Arguments.of("Expired token, async disabled", EXPIRED_TIME_MINUTES, false, true, true)); + Arguments.of("Fresh token, async enabled", FRESH_TIME_MINUTES, true, false, INITIAL_TOKEN), + Arguments.of("Stale token, async enabled", STALE_TIME_MINUTES, true, true, INITIAL_TOKEN), + Arguments.of( + "Expired token, async enabled", EXPIRED_TIME_MINUTES, true, true, REFRESH_TOKEN), + Arguments.of( + "Fresh token, async disabled", FRESH_TIME_MINUTES, false, false, INITIAL_TOKEN), + Arguments.of( + "Stale token, async disabled", STALE_TIME_MINUTES, false, false, INITIAL_TOKEN), + Arguments.of( + "Expired token, async disabled", EXPIRED_TIME_MINUTES, false, true, REFRESH_TOKEN)); } @ParameterizedTest(name = "{0}") @@ -36,7 +40,7 @@ void testAsyncRefreshParametrized( long minutesUntilExpiry, boolean asyncEnabled, boolean expectRefresh, - boolean expectRefreshedToken) + String expectedToken) throws Exception { Token initialToken = @@ -64,12 +68,7 @@ protected Token refresh() { boolean refreshed = refreshCalled.await(1, TimeUnit.SECONDS); assertEquals(expectRefresh, refreshed, "Refresh should have been triggered"); - - if (expectRefreshedToken) { - assertEquals(REFRESH_TOKEN, token.getAccessToken(), "Token was not refreshed as expected"); - } else { - assertEquals(INITIAL_TOKEN, token.getAccessToken(), "Should return the initial token"); - } + assertEquals(expectedToken, token.getAccessToken(), "Token value did not match expected"); } /** From 15e221de95db5352512ff63afdd3f1e1056ccfe8 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 3 Jun 2025 12:24:03 +0000 Subject: [PATCH 10/49] Add logging --- .../core/oauth/RefreshableTokenSource.java | 11 +++++++++- .../oauth/RefreshableTokenSourceTest.java | 22 ++++++++----------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index 50e596348..24f88bbd0 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -14,6 +14,8 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; import org.apache.http.HttpHeaders; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * An OAuth TokenSource which can be refreshed. @@ -35,6 +37,7 @@ private enum TokenState { EXPIRED } + private static final Logger logger = LoggerFactory.getLogger(RefreshableTokenSource.class); // Default duration before expiry to consider a token as 'stale'. private static final Duration DEFAULT_STALE_DURATION = Duration.ofMinutes(3); @@ -159,7 +162,12 @@ protected synchronized Token getTokenBlocking() { return token; } lastRefreshSucceeded = false; - token = refresh(); // May throw an exception + try { + token = refresh(); // May throw an exception + } catch (Exception e) { + logger.error("Failed to refresh token synchronously", e); + throw e; + } lastRefreshSucceeded = true; return token; } @@ -213,6 +221,7 @@ protected synchronized void triggerAsyncRefresh() { synchronized (this) { lastRefreshSucceeded = false; refreshInProgress = false; + logger.error("Async token refresh failed", e); } } }); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java index 11a791bc9..c33e7fda5 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java @@ -15,22 +15,18 @@ public class RefreshableTokenSourceTest { private static final String TOKEN_TYPE = "Bearer"; private static final String INITIAL_TOKEN = "initial-token"; private static final String REFRESH_TOKEN = "refreshed-token"; - private static final long FRESH_TIME_MINUTES = 10; - private static final long STALE_TIME_MINUTES = 1; - private static final long EXPIRED_TIME_MINUTES = -1; + private static final long FRESH_MINUTES = 10; + private static final long STALE_MINUTES = 1; + private static final long EXPIRED_MINUTES = -1; private static Stream provideAsyncRefreshScenarios() { return Stream.of( - Arguments.of("Fresh token, async enabled", FRESH_TIME_MINUTES, true, false, INITIAL_TOKEN), - Arguments.of("Stale token, async enabled", STALE_TIME_MINUTES, true, true, INITIAL_TOKEN), - Arguments.of( - "Expired token, async enabled", EXPIRED_TIME_MINUTES, true, true, REFRESH_TOKEN), - Arguments.of( - "Fresh token, async disabled", FRESH_TIME_MINUTES, false, false, INITIAL_TOKEN), - Arguments.of( - "Stale token, async disabled", STALE_TIME_MINUTES, false, false, INITIAL_TOKEN), - Arguments.of( - "Expired token, async disabled", EXPIRED_TIME_MINUTES, false, true, REFRESH_TOKEN)); + Arguments.of("Fresh token, async enabled", FRESH_MINUTES, true, false, INITIAL_TOKEN), + Arguments.of("Stale token, async enabled", STALE_MINUTES, true, true, INITIAL_TOKEN), + Arguments.of("Expired token, async enabled", EXPIRED_MINUTES, true, true, REFRESH_TOKEN), + Arguments.of("Fresh token, async disabled", FRESH_MINUTES, false, false, INITIAL_TOKEN), + Arguments.of("Stale token, async disabled", STALE_MINUTES, false, false, INITIAL_TOKEN), + Arguments.of("Expired token, async disabled", EXPIRED_MINUTES, false, true, REFRESH_TOKEN)); } @ParameterizedTest(name = "{0}") From 7589dab5f40bd1cf86f46984b0b2268d47e02638 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 3 Jun 2025 12:39:02 +0000 Subject: [PATCH 11/49] Performance optimization --- .../core/oauth/RefreshableTokenSource.java | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index 24f88bbd0..769b88731 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -42,7 +42,7 @@ private enum TokenState { private static final Duration DEFAULT_STALE_DURATION = Duration.ofMinutes(3); // The current OAuth token. May be null if not yet fetched. - protected Token token; + protected volatile Token token; // Whether asynchronous refresh is enabled. private boolean asyncEnabled = false; // Duration before expiry to consider a token as 'stale'. @@ -133,12 +133,12 @@ public Token getToken() { * * @return The token state */ - protected TokenState getTokenState() { - if (token == null) { + protected TokenState getTokenState(Token t) { + if (t == null) { return TokenState.EXPIRED; } Duration lifeTime = - Duration.between(LocalDateTime.now(clockSupplier.getClock()), token.getExpiry()); + Duration.between(LocalDateTime.now(clockSupplier.getClock()), t.getExpiry()); if (lifeTime.compareTo(expiryBuffer) <= 0) { return TokenState.EXPIRED; } @@ -157,13 +157,12 @@ protected TokenState getTokenState() { * @return The current valid token */ protected synchronized Token getTokenBlocking() { - TokenState state = getTokenState(); - if (state != TokenState.EXPIRED) { + if (getTokenState(token) != TokenState.EXPIRED) { return token; } lastRefreshSucceeded = false; try { - token = refresh(); // May throw an exception + token = refresh(); } catch (Exception e) { logger.error("Failed to refresh token synchronously", e); throw e; @@ -182,13 +181,9 @@ protected synchronized Token getTokenBlocking() { * @return The current valid or stale token */ protected Token getTokenAsync() { - TokenState state; Token currentToken = token; - synchronized (this) { - state = getTokenState(); - currentToken = token; - } - switch (state) { + + switch (getTokenState(currentToken)) { case FRESH: return currentToken; case STALE: @@ -197,7 +192,7 @@ protected Token getTokenAsync() { case EXPIRED: return getTokenBlocking(); default: - throw new IllegalStateException("Invalid token state: " + state); + throw new IllegalStateException("Invalid token state."); } } From d97c7348bcc05c4da7d7c785e7a671b71cc8f5f2 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 3 Jun 2025 13:24:41 +0000 Subject: [PATCH 12/49] Furter optimizations --- .../core/oauth/RefreshableTokenSource.java | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index 769b88731..ce0f05657 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -71,6 +71,8 @@ public RefreshableTokenSource(Token token) { /** * Set the clock supplier for current time. * + *

Experimental: This method may change or be removed in future releases. + * * @param clockSupplier The clock supplier to use. * @return this instance for chaining */ @@ -82,6 +84,8 @@ public RefreshableTokenSource withClockSupplier(ClockSupplier clockSupplier) { /** * Enable or disable asynchronous token refresh. * + *

Experimental: This method may change or be removed in future releases. + * * @param enabled true to enable async refresh, false to disable * @return this instance for chaining */ @@ -94,6 +98,8 @@ public RefreshableTokenSource withAsyncRefresh(boolean enabled) { * Set the expiry buffer. If the token's lifetime is less than this buffer, it is considered * expired. * + *

Experimental: This method may change or be removed in future releases. + * * @param buffer the expiry buffer duration * @return this instance for chaining */ @@ -156,19 +162,28 @@ protected TokenState getTokenState(Token t) { * * @return The current valid token */ - protected synchronized Token getTokenBlocking() { + protected Token getTokenBlocking() { + // Use double-checked locking to minimize synchronization overhead on reads: + // 1. Check if the token is expired without locking. + // 2. If expired, synchronize and check again (another thread may have refreshed it). + // 3. If still expired, perform the refresh. if (getTokenState(token) != TokenState.EXPIRED) { return token; } - lastRefreshSucceeded = false; - try { - token = refresh(); - } catch (Exception e) { - logger.error("Failed to refresh token synchronously", e); - throw e; + synchronized (this) { + if (getTokenState(token) != TokenState.EXPIRED) { + return token; + } + lastRefreshSucceeded = false; + try { + token = refresh(); + } catch (Exception e) { + logger.error("Failed to refresh token synchronously", e); + throw e; + } + lastRefreshSucceeded = true; + return token; } - lastRefreshSucceeded = true; - return token; } /** From 105bc99b03eb346918c7047d3b3a550215ca9f06 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 3 Jun 2025 14:02:53 +0000 Subject: [PATCH 13/49] Add extra token state check in async refresh --- .../com/databricks/sdk/core/oauth/RefreshableTokenSource.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index ce0f05657..c62dccad7 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -216,7 +216,8 @@ protected Token getTokenAsync() { * succeeded. */ protected synchronized void triggerAsyncRefresh() { - if (!refreshInProgress && lastRefreshSucceeded) { + // Check token state to avoid triggering a refresh if another thread has already refreshed it + if (!refreshInProgress && lastRefreshSucceeded && getTokenState(token) != TokenState.FRESH) { refreshInProgress = true; CompletableFuture.runAsync( () -> { From b24a0fce8da59fee5bfe24f35b3e615ab4b91485 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 3 Jun 2025 17:37:45 +0000 Subject: [PATCH 14/49] Change LocalDateTime to Instant --- .../databricks/sdk/core/CliTokenSource.java | 32 +++++++++++++---- .../oauth/DatabricksOAuthTokenSource.java | 4 +-- .../sdk/core/oauth/EndpointTokenSource.java | 4 +-- .../sdk/core/oauth/OidcTokenSource.java | 4 +-- .../core/oauth/RefreshableTokenSource.java | 5 ++- .../com/databricks/sdk/core/oauth/Token.java | 22 ++++++------ .../core/AzureCliCredentialsProviderTest.java | 3 +- .../sdk/core/CliTokenSourceTest.java | 35 +++++++++++++++---- .../sdk/core/DatabricksConfigTest.java | 3 +- .../core/oauth/DataPlaneTokenSourceTest.java | 14 ++++---- .../core/oauth/EndpointTokenSourceTest.java | 4 +-- ...xternalBrowserCredentialsProviderTest.java | 16 ++++----- .../sdk/core/oauth/FileTokenCacheTest.java | 14 ++++---- .../core/oauth/OAuthHeaderFactoryTest.java | 6 ++-- .../TokenSourceCredentialsProviderTest.java | 4 +-- .../databricks/sdk/core/oauth/TokenTest.java | 15 ++++---- 16 files changed, 109 insertions(+), 76 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 8a7328904..3b3845e51 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -8,7 +8,9 @@ import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; import java.io.InputStream; +import java.time.Instant; import java.time.LocalDateTime; +import java.time.ZoneOffset; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeParseException; import java.util.Arrays; @@ -36,21 +38,39 @@ public CliTokenSource( this.env = env; } - static LocalDateTime parseExpiry(String expiry) { + /** + * Parses an expiry time string and returns the corresponding {@link Instant}. + * + *

The expiry time string is verified to always be in UTC. Any time zone or offset information + * present in the input is ignored, and the value is parsed as a UTC time. + * + *

The method attempts to parse the input using several common date-time formats, including + * ISO-8601 and patterns with varying sub-second precision. + * + * @param expiry the expiry time string to parse, which must represent a UTC time + * @return the parsed {@link Instant} representing the expiry time in UTC + * @throws DateTimeParseException if the input string cannot be parsed as a valid date-time + */ + static Instant parseExpiry(String expiry) { + DateTimeParseException lastException = null; + try { + return Instant.parse(expiry); + } catch (DateTimeParseException e) { + lastException = e; + } + String multiplePrecisionPattern = "[SSSSSSSSS][SSSSSSSS][SSSSSSS][SSSSSS][SSSSS][SSSS][SSS][SS][S]"; List datePatterns = Arrays.asList( "yyyy-MM-dd HH:mm:ss", "yyyy-MM-dd HH:mm:ss." + multiplePrecisionPattern, - "yyyy-MM-dd'T'HH:mm:ss." + multiplePrecisionPattern + "XXX", - "yyyy-MM-dd'T'HH:mm:ss." + multiplePrecisionPattern + "'Z'"); - DateTimeParseException lastException = null; + "yyyy-MM-dd'T'HH:mm:ss." + multiplePrecisionPattern + "XXX"); for (String pattern : datePatterns) { try { DateTimeFormatter formatter = DateTimeFormatter.ofPattern(pattern); LocalDateTime dateTime = LocalDateTime.parse(expiry, formatter); - return dateTime; + return dateTime.atZone(ZoneOffset.UTC).toInstant(); } catch (DateTimeParseException e) { lastException = e; } @@ -83,7 +103,7 @@ protected Token refresh() { String tokenType = jsonNode.get(tokenTypeField).asText(); String accessToken = jsonNode.get(accessTokenField).asText(); String expiry = jsonNode.get(expiryField).asText(); - LocalDateTime expiresOn = parseExpiry(expiry); + Instant expiresOn = parseExpiry(expiry); return new Token(accessToken, tokenType, expiresOn); } catch (DatabricksException e) { throw e; diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index f16ae2aed..484e0712e 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -3,7 +3,7 @@ import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.http.HttpClient; import com.google.common.base.Strings; -import java.time.LocalDateTime; +import java.time.Instant; import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -166,7 +166,7 @@ public Token refresh() { throw e; } - LocalDateTime expiry = LocalDateTime.now().plusSeconds(response.getExpiresIn()); + Instant expiry = Instant.now().plusSeconds(response.getExpiresIn()); return new Token( response.getAccessToken(), response.getTokenType(), response.getRefreshToken(), expiry); } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java index 3ca75c441..ed08f57d6 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java @@ -2,7 +2,7 @@ import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.http.HttpClient; -import java.time.LocalDateTime; +import java.time.Instant; import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -87,7 +87,7 @@ protected Token refresh() { throw e; } - LocalDateTime expiry = LocalDateTime.now().plusSeconds(oauthResponse.getExpiresIn()); + Instant expiry = Instant.now().plusSeconds(oauthResponse.getExpiresIn()); return new Token( oauthResponse.getAccessToken(), oauthResponse.getTokenType(), diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java index 719544ebf..b15f55ded 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java @@ -8,7 +8,7 @@ import com.google.common.base.Strings; import com.google.common.collect.ImmutableMap; import java.io.IOException; -import java.time.LocalDateTime; +import java.time.Instant; /** * {@code OidcTokenSource} is responsible for obtaining OAuth tokens using the OpenID Connect (OIDC) @@ -77,7 +77,7 @@ protected Token refresh() { if (resp.getErrorCode() != null) { throw new IllegalArgumentException(resp.getErrorCode() + ": " + resp.getErrorSummary()); } - LocalDateTime expiry = LocalDateTime.now().plusSeconds(resp.getExpiresIn()); + Instant expiry = Instant.now().plusSeconds(resp.getExpiresIn()); return new Token(resp.getAccessToken(), resp.getTokenType(), resp.getRefreshToken(), expiry); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index e93f91ae5..750aae967 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -5,8 +5,7 @@ import com.databricks.sdk.core.http.FormRequest; import com.databricks.sdk.core.http.HttpClient; import com.databricks.sdk.core.http.Request; -import java.time.LocalDateTime; -import java.time.temporal.ChronoUnit; +import java.time.Instant; import java.util.Base64; import java.util.Map; import org.apache.http.HttpHeaders; @@ -70,7 +69,7 @@ protected static Token retrieveToken( if (resp.getErrorCode() != null) { throw new IllegalArgumentException(resp.getErrorCode() + ": " + resp.getErrorSummary()); } - LocalDateTime expiry = LocalDateTime.now().plus(resp.getExpiresIn(), ChronoUnit.SECONDS); + Instant expiry = Instant.now().plusSeconds(resp.getExpiresIn()); return new Token(resp.getAccessToken(), resp.getTokenType(), resp.getRefreshToken(), expiry); } catch (Exception e) { throw new DatabricksException("Failed to refresh credentials: " + e.getMessage(), e); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java index f0fd72f68..ac6fbc3ac 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java @@ -4,8 +4,7 @@ import com.databricks.sdk.core.utils.SystemClockSupplier; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import java.time.LocalDateTime; -import java.time.temporal.ChronoUnit; +import java.time.Instant; import java.util.Objects; public class Token { @@ -19,21 +18,20 @@ public class Token { * The expiry time of the token. * *

OAuth token responses include the duration of the lifetime of the access token. When the - * token is retrieved, this is converted to a LocalDateTime tracking the expiry time of the token - * with respect to the current clock. + * token is retrieved, this is converted to an Instant tracking the expiry time of the token with + * respect to the current clock. */ - @JsonProperty private LocalDateTime expiry; + @JsonProperty private Instant expiry; private final ClockSupplier clockSupplier; /** Constructor for non-refreshable tokens (e.g. M2M). */ - public Token(String accessToken, String tokenType, LocalDateTime expiry) { + public Token(String accessToken, String tokenType, Instant expiry) { this(accessToken, tokenType, null, expiry, new SystemClockSupplier()); } /** Constructor for non-refreshable tokens (e.g. M2M) with ClockSupplier */ - public Token( - String accessToken, String tokenType, LocalDateTime expiry, ClockSupplier clockSupplier) { + public Token(String accessToken, String tokenType, Instant expiry, ClockSupplier clockSupplier) { this(accessToken, tokenType, null, expiry, clockSupplier); } @@ -43,7 +41,7 @@ public Token( @JsonProperty("accessToken") String accessToken, @JsonProperty("tokenType") String tokenType, @JsonProperty("refreshToken") String refreshToken, - @JsonProperty("expiry") LocalDateTime expiry) { + @JsonProperty("expiry") Instant expiry) { this(accessToken, tokenType, refreshToken, expiry, new SystemClockSupplier()); } @@ -52,7 +50,7 @@ public Token( String accessToken, String tokenType, String refreshToken, - LocalDateTime expiry, + Instant expiry, ClockSupplier clockSupplier) { Objects.requireNonNull(accessToken, "accessToken must be defined"); Objects.requireNonNull(tokenType, "tokenType must be defined"); @@ -71,8 +69,8 @@ public boolean isExpired() { } // Azure Databricks rejects tokens that expire in 30 seconds or less, // so we refresh the token 40 seconds before it expires. - LocalDateTime potentiallyExpired = expiry.minus(40, ChronoUnit.SECONDS); - LocalDateTime now = LocalDateTime.now(clockSupplier.getClock()); + Instant potentiallyExpired = expiry.minusSeconds(40); + Instant now = Instant.now(clockSupplier.getClock()); return potentiallyExpired.isBefore(now); } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java index 0212b7652..4e8a57b06 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java @@ -7,7 +7,6 @@ import com.databricks.sdk.core.oauth.Token; import com.databricks.sdk.core.oauth.TokenSource; -import java.time.LocalDateTime; import java.util.Arrays; import java.util.List; import org.junit.jupiter.api.Test; @@ -25,7 +24,7 @@ class AzureCliCredentialsProviderTest { private static CliTokenSource mockTokenSource() { CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); Mockito.when(tokenSource.getToken()) - .thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); + .thenReturn(new Token(TOKEN, TOKEN_TYPE, java.time.Instant.now())); return tokenSource; } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 20e2f6095..abe609b01 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -2,7 +2,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; -import java.time.LocalDateTime; +import java.time.Instant; import java.time.format.DateTimeParseException; import org.junit.jupiter.api.Test; @@ -10,20 +10,20 @@ public class CliTokenSourceTest { @Test public void testParseExpiryWithoutTruncate() { - LocalDateTime parsedDateTime = CliTokenSource.parseExpiry("2023-07-17T09:02:22.330612218Z"); - assertEquals(LocalDateTime.of(2023, 7, 17, 9, 2, 22, 330612218), parsedDateTime); + Instant parsedInstant = CliTokenSource.parseExpiry("2023-07-17T09:02:22.330612218Z"); + assertEquals(Instant.parse("2023-07-17T09:02:22.330612218Z"), parsedInstant); } @Test public void testParseExpiryWithTruncate() { - LocalDateTime parsedDateTime = CliTokenSource.parseExpiry("2023-07-17T09:02:22.33061221Z"); - assertEquals(LocalDateTime.of(2023, 7, 17, 9, 2, 22, 330612210), parsedDateTime); + Instant parsedInstant = CliTokenSource.parseExpiry("2023-07-17T09:02:22.33061221Z"); + assertEquals(Instant.parse("2023-07-17T09:02:22.330612210Z"), parsedInstant); } @Test public void testParseExpiryWithTruncateAndLessNanoSecondDigits() { - LocalDateTime parsedDateTime = CliTokenSource.parseExpiry("2023-07-17T09:02:22.330612Z"); - assertEquals(LocalDateTime.of(2023, 7, 17, 9, 2, 22, 330612000), parsedDateTime); + Instant parsedInstant = CliTokenSource.parseExpiry("2023-07-17T09:02:22.330612Z"); + assertEquals(Instant.parse("2023-07-17T09:02:22.330612000Z"), parsedInstant); } @Test @@ -34,4 +34,25 @@ public void testParseExpiryWithMoreThanNineNanoSecondDigits() { assert (e.getMessage().contains("could not be parsed")); } } + + @Test + public void testParseExpiryWithSpaceFormat() { + Instant parsedInstant = CliTokenSource.parseExpiry("2023-07-17 09:02:22"); + assertEquals(Instant.parse("2023-07-17T09:02:22Z"), parsedInstant); + } + + @Test + public void testParseExpiryWithSpaceFormatAndMillis() { + Instant parsedInstant = CliTokenSource.parseExpiry("2023-07-17 09:02:22.123"); + assertEquals(Instant.parse("2023-07-17T09:02:22.123Z"), parsedInstant); + } + + @Test + public void testParseExpiryWithInvalidFormat() { + try { + CliTokenSource.parseExpiry("17-07-2023 09:02:22"); + } catch (DateTimeParseException e) { + assert (e.getMessage().contains("could not be parsed")); + } + } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java index 38b6fcd9c..b3ac333a3 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java @@ -12,7 +12,6 @@ import com.databricks.sdk.core.oauth.TokenSource; import com.databricks.sdk.core.utils.Environment; import java.io.IOException; -import java.time.LocalDateTime; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -232,7 +231,7 @@ public void testGetTokenSourceWithOAuth() { HttpClient httpClient = mock(HttpClient.class); TokenSource mockTokenSource = mock(TokenSource.class); when(mockTokenSource.getToken()) - .thenReturn(new Token("test-token", "Bearer", LocalDateTime.now().plusHours(1))); + .thenReturn(new Token("test-token", "Bearer", java.time.Instant.now().plusSeconds(3600))); OAuthHeaderFactory mockHeaderFactory = OAuthHeaderFactory.fromTokenSource(mockTokenSource); CredentialsProvider mockProvider = mock(CredentialsProvider.class); when(mockProvider.authType()).thenReturn("test"); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java index 5887c4ee1..1fb96a559 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java @@ -8,7 +8,7 @@ import com.databricks.sdk.core.http.Response; import java.io.IOException; import java.net.URL; -import java.time.LocalDateTime; +import java.time.Instant; import java.util.stream.Stream; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -29,8 +29,7 @@ public class DataPlaneTokenSourceTest { private static Stream provideDataPlaneTokenScenarios() throws Exception { // Mock DatabricksOAuthTokenSource for control plane token - Token cpToken = - new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, null, LocalDateTime.now().plusSeconds(600)); + Token cpToken = new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, null, Instant.now().plusSeconds(600)); DatabricksOAuthTokenSource mockCpTokenSource = mock(DatabricksOAuthTokenSource.class); when(mockCpTokenSource.getToken()).thenReturn(cpToken); @@ -81,9 +80,8 @@ private static Stream provideDataPlaneTokenScenarios() throws Excepti "dp-access-token1", TEST_TOKEN_TYPE, TEST_REFRESH_TOKEN, - LocalDateTime.now().plusSeconds(TEST_EXPIRES_IN)), - null // No exception - ), + Instant.now().plusSeconds(TEST_EXPIRES_IN)), + null), Arguments.of( "Success: endpoint2/auth2 (different cache key)", TEST_ENDPOINT_2, @@ -95,7 +93,7 @@ private static Stream provideDataPlaneTokenScenarios() throws Excepti "dp-access-token2", TEST_TOKEN_TYPE, TEST_REFRESH_TOKEN, - LocalDateTime.now().plusSeconds(TEST_EXPIRES_IN)), + Instant.now().plusSeconds(TEST_EXPIRES_IN)), null), Arguments.of( "Error response from endpoint", @@ -203,7 +201,7 @@ void testDataPlaneTokenSource( @Test void testEndpointTokenSourceCaching() throws Exception { Token cpToken = - new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, null, LocalDateTime.now().plusSeconds(3600)); + new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, null, Instant.now().plusSeconds(3600)); DatabricksOAuthTokenSource mockCpTokenSource = mock(DatabricksOAuthTokenSource.class); when(mockCpTokenSource.getToken()).thenReturn(cpToken); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java index a3af2254f..5fdac9f2d 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java @@ -9,7 +9,7 @@ import com.databricks.sdk.core.http.Response; import java.io.IOException; import java.net.URL; -import java.time.LocalDateTime; +import java.time.Instant; import java.util.stream.Stream; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -48,7 +48,7 @@ private static Stream provideEndpointTokenScenarios() throws Exceptio String malformedJson = "{not valid json}"; // Mock DatabricksOAuthTokenSource for control plane token - Token cpToken = new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, LocalDateTime.now().plusMinutes(10)); + Token cpToken = new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, Instant.now().plusSeconds(600)); DatabricksOAuthTokenSource mockCpTokenSource = mock(DatabricksOAuthTokenSource.class); when(mockCpTokenSource.getToken()).thenReturn(cpToken); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java index 1714b731c..932690bd7 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java @@ -13,7 +13,7 @@ import com.databricks.sdk.core.http.Response; import java.io.IOException; import java.net.URL; -import java.time.LocalDateTime; +import java.time.Instant; import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -210,7 +210,7 @@ void sessionCredentials() throws IOException { "originalAccessToken", "originalTokenType", "originalRefreshToken", - LocalDateTime.MAX)) + Instant.MAX)) .build(); Token token = sessionCredentials.refresh(); @@ -234,7 +234,7 @@ void cacheWithValidTokenTest() throws IOException { .execute(any(Request.class)); // Create an valid token with valid refresh token - LocalDateTime futureTime = LocalDateTime.now().plusHours(1); + Instant futureTime = Instant.now().plusSeconds(3600); Token validToken = new Token("valid_access_token", "Bearer", "valid_refresh_token", futureTime); // Create mock token cache that returns the valid token @@ -314,7 +314,7 @@ void cacheWithInvalidAccessTokenValidRefreshTest() throws IOException { .execute(any(Request.class)); // Create an expired token with valid refresh token - LocalDateTime pastTime = LocalDateTime.now().minusHours(1); + Instant pastTime = Instant.now().minusSeconds(3600); Token expiredToken = new Token("expired_access_token", "Bearer", "valid_refresh_token", pastTime); @@ -392,7 +392,7 @@ void cacheWithInvalidAccessTokenRefreshFailingTest() throws IOException { .execute(any(Request.class)); // Create an expired token with invalid refresh token - LocalDateTime pastTime = LocalDateTime.now().minusHours(1); + Instant pastTime = Instant.now().minusSeconds(3600); Token expiredToken = new Token("expired_access_token", "Bearer", "invalid_refresh_token", pastTime); @@ -406,7 +406,7 @@ void cacheWithInvalidAccessTokenRefreshFailingTest() throws IOException { "browser_access_token", "Bearer", "browser_refresh_token", - LocalDateTime.now().plusHours(1)); + Instant.now().plusSeconds(3600)); SessionCredentials browserAuthCreds = new SessionCredentials.Builder() @@ -460,7 +460,7 @@ void cacheWithInvalidAccessTokenRefreshFailingTest() throws IOException { @Test void cacheWithInvalidTokensTest() throws IOException { // Create completely invalid token (no refresh token) - LocalDateTime pastTime = LocalDateTime.now().minusHours(1); + Instant pastTime = Instant.now().minusSeconds(3600); Token invalidToken = new Token("expired_access_token", "Bearer", null, pastTime); // Create mock token cache that returns the invalid token @@ -473,7 +473,7 @@ void cacheWithInvalidTokensTest() throws IOException { "browser_access_token", "Bearer", "browser_refresh_token", - LocalDateTime.now().plusHours(1)); + Instant.now().plusSeconds(3600)); SessionCredentials browserAuthCreds = new SessionCredentials.Builder() diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java index ede6cfd11..303f6de66 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java @@ -5,7 +5,7 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; -import java.time.LocalDateTime; +import java.time.Instant; import java.util.Arrays; import java.util.List; import org.junit.jupiter.api.AfterEach; @@ -42,7 +42,7 @@ void testEmptyCache() { @Test void testSaveAndLoadToken() { // Given a token - LocalDateTime expiry = LocalDateTime.now().plusHours(1); + Instant expiry = Instant.now().plusSeconds(3600); Token token = new Token("access-token", "Bearer", "refresh-token", expiry); // When saving and loading the token @@ -60,14 +60,14 @@ void testSaveAndLoadToken() { @Test void testTokenExpiry() { // Create an expired token - LocalDateTime pastTime = LocalDateTime.now().minusHours(1); + Instant pastTime = Instant.now().minusSeconds(3600); Token expiredToken = new Token("access-token", "Bearer", "refresh-token", pastTime); // Verify it's marked as expired assertTrue(expiredToken.isExpired(), "Token should be expired"); // Create a valid token - LocalDateTime futureTime = LocalDateTime.now().plusMinutes(30); + Instant futureTime = Instant.now().plusSeconds(1800); Token validToken = new Token("access-token", "Bearer", "refresh-token", futureTime); // Verify it's not marked as expired @@ -86,8 +86,8 @@ void testNullPathRejection() { @Test void testOverwriteToken() { // Given two tokens saved in sequence - Token token1 = new Token("token1", "Bearer", "refresh1", LocalDateTime.now().plusHours(1)); - Token token2 = new Token("token2", "Bearer", "refresh2", LocalDateTime.now().plusHours(2)); + Token token1 = new Token("token1", "Bearer", "refresh1", Instant.now().plusSeconds(3600)); + Token token2 = new Token("token2", "Bearer", "refresh2", Instant.now().plusSeconds(7200)); tokenCache.save(token1); tokenCache.save(token2); @@ -110,7 +110,7 @@ void testWithCustomPath(@TempDir Path tempDir) { // And a token Token testToken = new Token( - "test-access-token", "Bearer", "test-refresh-token", LocalDateTime.now().plusHours(1)); + "test-access-token", "Bearer", "test-refresh-token", Instant.now().plusSeconds(3600)); // When saving and loading cache.save(testToken); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/OAuthHeaderFactoryTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/OAuthHeaderFactoryTest.java index d0530b2c1..f0b83153c 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/OAuthHeaderFactoryTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/OAuthHeaderFactoryTest.java @@ -3,7 +3,7 @@ import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; -import java.time.LocalDateTime; +import java.time.Instant; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -25,7 +25,7 @@ public class OAuthHeaderFactoryTest { @Mock private TokenSource tokenSource; private static Stream provideTokenSourceTestCases() { - LocalDateTime expiry = LocalDateTime.now().plusHours(1); + Instant expiry = Instant.now().plusSeconds(3600); Token token = new Token(TOKEN_VALUE, TOKEN_TYPE, expiry); return Stream.of( @@ -57,7 +57,7 @@ public void testFromTokenSourceFactoryMethod( } private static Stream provideSuppliersTestCases() { - LocalDateTime expiry = LocalDateTime.now().plusHours(1); + Instant expiry = Instant.now().plusSeconds(3600); Token token = new Token(TOKEN_VALUE, TOKEN_TYPE, expiry); Map standardHeaders = new HashMap<>(); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProviderTest.java index 14eb3fa40..8d2d68fd4 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProviderTest.java @@ -6,7 +6,7 @@ import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.HeaderFactory; -import java.time.LocalDateTime; +import java.time.Instant; import java.util.Map; import java.util.stream.Stream; import org.junit.jupiter.params.ParameterizedTest; @@ -19,7 +19,7 @@ class TokenSourceCredentialsProviderTest { private static final String TEST_AUTH_TYPE = "test-auth-type"; private static final String TEST_ACCESS_TOKEN_VALUE = "test-access-token"; private static final Token TEST_TOKEN = - new Token(TEST_ACCESS_TOKEN_VALUE, "Bearer", LocalDateTime.now().plusHours(1)); + new Token(TEST_ACCESS_TOKEN_VALUE, "Bearer", Instant.now().plusSeconds(3600)); /** Tests token retrieval scenarios */ @ParameterizedTest(name = "{0}") diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java index 2d87a32c2..16b726222 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java @@ -4,7 +4,6 @@ import com.databricks.sdk.core.utils.FakeClockSupplier; import java.time.Instant; -import java.time.LocalDateTime; import java.time.ZoneId; import org.junit.jupiter.api.Test; @@ -13,20 +12,20 @@ class TokenTest { private static final String accessToken = "testAccessToken"; private static final String refreshToken = "testRefreshToken"; private static final String tokenType = "testTokenType"; - private final LocalDateTime currentLocalDateTime; + private final Instant currentInstant; private final FakeClockSupplier fakeClockSupplier; TokenTest() { Instant instant = Instant.parse("2023-10-18T12:00:00.00Z"); ZoneId zoneId = ZoneId.of("UTC"); fakeClockSupplier = new FakeClockSupplier(instant, zoneId); - currentLocalDateTime = LocalDateTime.now(fakeClockSupplier.getClock()); + currentInstant = Instant.now(fakeClockSupplier.getClock()); } @Test void createNonRefreshableToken() { Token token = - new Token(accessToken, tokenType, currentLocalDateTime.plusMinutes(5), fakeClockSupplier); + new Token(accessToken, tokenType, currentInstant.plusSeconds(300), fakeClockSupplier); assertEquals(accessToken, token.getAccessToken()); assertEquals(tokenType, token.getTokenType()); assertNull(token.getRefreshToken()); @@ -40,7 +39,7 @@ void createRefreshableToken() { accessToken, tokenType, refreshToken, - currentLocalDateTime.plusMinutes(5), + currentInstant.plusSeconds(300), fakeClockSupplier); assertEquals(accessToken, token.getAccessToken()); assertEquals(tokenType, token.getTokenType()); @@ -51,7 +50,7 @@ void createRefreshableToken() { @Test void tokenExpiryMoreThan40Seconds() { Token token = - new Token(accessToken, tokenType, currentLocalDateTime.plusSeconds(50), fakeClockSupplier); + new Token(accessToken, tokenType, currentInstant.plusSeconds(50), fakeClockSupplier); assertFalse(token.isExpired()); assertTrue(token.isValid()); } @@ -59,7 +58,7 @@ void tokenExpiryMoreThan40Seconds() { @Test void tokenExpiryLessThan40Seconds() { Token token = - new Token(accessToken, tokenType, currentLocalDateTime.plusSeconds(30), fakeClockSupplier); + new Token(accessToken, tokenType, currentInstant.plusSeconds(30), fakeClockSupplier); assertTrue(token.isExpired()); assertFalse(token.isValid()); } @@ -67,7 +66,7 @@ void tokenExpiryLessThan40Seconds() { @Test void expiredToken() { Token token = - new Token(accessToken, tokenType, currentLocalDateTime.minusSeconds(10), fakeClockSupplier); + new Token(accessToken, tokenType, currentInstant.minusSeconds(10), fakeClockSupplier); assertTrue(token.isExpired()); assertFalse(token.isValid()); } From 6f81b4cff985b9ad06665189c39eac630d1a5806 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Wed, 4 Jun 2025 08:51:23 +0000 Subject: [PATCH 15/49] Update parseExpiry in CilTokenSource --- .../main/java/com/databricks/sdk/core/CliTokenSource.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 3b3845e51..9eec8930f 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -41,8 +41,8 @@ public CliTokenSource( /** * Parses an expiry time string and returns the corresponding {@link Instant}. * - *

The expiry time string is verified to always be in UTC. Any time zone or offset information - * present in the input is ignored, and the value is parsed as a UTC time. + *

The expiry time string is always in UTC. Any time zone or offset information present in the + * input is ignored. * *

The method attempts to parse the input using several common date-time formats, including * ISO-8601 and patterns with varying sub-second precision. @@ -53,6 +53,7 @@ public CliTokenSource( */ static Instant parseExpiry(String expiry) { DateTimeParseException lastException = null; + // Try to parse the expiry as an ISO-8601 string in UTC first try { return Instant.parse(expiry); } catch (DateTimeParseException e) { From daac1b2f8ec1935c9ee4be4553fc615ac44915ff Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Wed, 4 Jun 2025 09:02:34 +0000 Subject: [PATCH 16/49] Update javadoc --- .../src/main/java/com/databricks/sdk/core/CliTokenSource.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 9eec8930f..d07bbf787 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -49,7 +49,7 @@ public CliTokenSource( * * @param expiry the expiry time string to parse, which must represent a UTC time * @return the parsed {@link Instant} representing the expiry time in UTC - * @throws DateTimeParseException if the input string cannot be parsed as a valid date-time + * @throws DateTimeParseException if the input string cannot be parsed */ static Instant parseExpiry(String expiry) { DateTimeParseException lastException = null; From f3d4b8a669d12f6bb3ffb56054e71c7f6cdb7eff Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Wed, 4 Jun 2025 09:03:51 +0000 Subject: [PATCH 17/49] Update javadoc --- .../src/main/java/com/databricks/sdk/core/CliTokenSource.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index d07bbf787..2e176c210 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -48,7 +48,7 @@ public CliTokenSource( * ISO-8601 and patterns with varying sub-second precision. * * @param expiry the expiry time string to parse, which must represent a UTC time - * @return the parsed {@link Instant} representing the expiry time in UTC + * @return the parsed {@link Instant} * @throws DateTimeParseException if the input string cannot be parsed */ static Instant parseExpiry(String expiry) { From 12123a957402ab4bc492de71cb240b7842211a84 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Thu, 5 Jun 2025 09:15:03 +0000 Subject: [PATCH 18/49] Retrigger tests From 8705ff53e7017d2212acb839edc2efdec6683410 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 6 Jun 2025 08:55:42 +0000 Subject: [PATCH 19/49] Save progress --- .../databricks/sdk/core/CliTokenSource.java | 6 +- .../databricks/sdk/core/DatabricksConfig.java | 8 +- ...reServicePrincipalCredentialsProvider.java | 6 +- .../sdk/core/oauth/CachedTokenSource.java | 220 ++++++++++++++ .../sdk/core/oauth/ClientCredentials.java | 7 +- .../databricks/sdk/core/oauth/Consent.java | 2 +- .../oauth/DatabricksOAuthTokenSource.java | 4 +- .../sdk/core/oauth/EndpointTokenSource.java | 4 +- .../ExternalBrowserCredentialsProvider.java | 2 +- .../sdk/core/oauth/OidcTokenSource.java | 5 +- .../core/oauth/RefreshableTokenSource.java | 287 +----------------- .../sdk/core/oauth/SessionCredentials.java | 10 +- .../sdk/core/oauth/TokenEndpointClient.java | 60 ++++ .../databricks/sdk/core/utils/AzureUtils.java | 2 +- ...ceTest.java => CachedTokenSourceTest.java} | 80 +++-- ...xternalBrowserCredentialsProviderTest.java | 4 +- 16 files changed, 348 insertions(+), 359 deletions(-) create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java rename databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/{RefreshableTokenSourceTest.java => CachedTokenSourceTest.java} (64%) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 8a7328904..2eb5c13f7 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -1,7 +1,7 @@ package com.databricks.sdk.core; -import com.databricks.sdk.core.oauth.RefreshableTokenSource; import com.databricks.sdk.core.oauth.Token; +import com.databricks.sdk.core.oauth.TokenSource; import com.databricks.sdk.core.utils.Environment; import com.databricks.sdk.core.utils.OSUtils; import com.fasterxml.jackson.databind.JsonNode; @@ -15,7 +15,7 @@ import java.util.List; import org.apache.commons.io.IOUtils; -public class CliTokenSource extends RefreshableTokenSource { +public class CliTokenSource implements TokenSource { private List cmd; private String tokenTypeField; private String accessTokenField; @@ -64,7 +64,7 @@ private String getProcessStream(InputStream stream) throws IOException { } @Override - protected Token refresh() { + public Token getToken() { try { ProcessBuilder processBuilder = new ProcessBuilder(cmd); processBuilder.environment().putAll(env.getEnv()); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java index de6548982..df16ebae3 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java @@ -410,13 +410,17 @@ public DatabricksConfig setAzureUseMsi(boolean azureUseMsi) { return this; } - /** @deprecated Use {@link #getAzureUseMsi()} instead. */ + /** + * @deprecated Use {@link #getAzureUseMsi()} instead. + */ @Deprecated() public boolean getAzureUseMSI() { return azureUseMsi; } - /** @deprecated Use {@link #setAzureUseMsi(boolean)} instead. */ + /** + * @deprecated Use {@link #setAzureUseMsi(boolean)} instead. + */ @Deprecated public DatabricksConfig setAzureUseMSI(boolean azureUseMsi) { this.azureUseMsi = azureUseMsi; diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java index c7c7bb672..d10fefd94 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java @@ -28,8 +28,8 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { } AzureUtils.ensureHostPresent( config, mapper, AzureServicePrincipalCredentialsProvider::tokenSourceFor); - RefreshableTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); - RefreshableTokenSource cloud = + TokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); + TokenSource cloud = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); return OAuthHeaderFactory.fromSuppliers( @@ -55,7 +55,7 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { * @return A RefreshableTokenSource instance capable of fetching OAuth tokens for the specified * Azure resource. */ - private static RefreshableTokenSource tokenSourceFor(DatabricksConfig config, String resource) { + private static TokenSource tokenSourceFor(DatabricksConfig config, String resource) { String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint(); String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token"; Map endpointParams = new HashMap<>(); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java new file mode 100644 index 000000000..72aa38f74 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java @@ -0,0 +1,220 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.utils.ClockSupplier; +import com.databricks.sdk.core.utils.SystemClockSupplier; +import java.time.Duration; +import java.time.LocalDateTime; +import java.util.concurrent.CompletableFuture; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An OAuth TokenSource which can be refreshed. + * + *

This class supports both synchronous and asynchronous token refresh. When async is enabled, + * stale tokens will trigger a background refresh, while expired tokens will block until a new token + * is fetched. + */ +public class CachedTokenSource implements TokenSource { + + /** + * Enum representing the state of the token. FRESH: Token is valid and not close to expiry. STALE: + * Token is valid but will expire soon - an async refresh will be triggered if enabled. EXPIRED: + * Token has expired and must be refreshed using a blocking call. + */ + private enum TokenState { + FRESH, + STALE, + EXPIRED + } + + private static final Logger logger = LoggerFactory.getLogger(RefreshableTokenSource.class); + // Default duration before expiry to consider a token as 'stale'. + private static final Duration DEFAULT_STALE_DURATION = Duration.ofMinutes(3); + private static final Duration DEFAULT_EXPIRY_BUFFER = Duration.ofSeconds(40); + + private final TokenSource tokenSource; + private final boolean asyncEnabled; + private final Duration staleDuration; + private final Duration expiryBuffer; + private final ClockSupplier clockSupplier; + + // The current OAuth token. May be null if not yet fetched. + private volatile Token token; + // Whether a refresh is currently in progress (for async refresh). + private volatile boolean refreshInProgress = false; + // Whether the last refresh attempt succeeded. + private volatile boolean lastRefreshSucceeded = true; + + private CachedTokenSource(Builder builder) { + this.tokenSource = builder.tokenSource; + this.asyncEnabled = builder.asyncEnabled; + this.staleDuration = builder.staleDuration; + this.expiryBuffer = builder.expiryBuffer; + this.clockSupplier = builder.clockSupplier; + this.token = builder.token; + } + + public static class Builder { + private final TokenSource tokenSource; + private Token token; + private boolean asyncEnabled = false; + private Duration staleDuration = DEFAULT_STALE_DURATION; + private Duration expiryBuffer = DEFAULT_EXPIRY_BUFFER; + private ClockSupplier clockSupplier = new SystemClockSupplier(); + + public Builder(TokenSource tokenSource) { + this.tokenSource = tokenSource; + } + + public Builder withToken(Token token) { + this.token = token; + return this; + } + + public Builder withAsyncEnabled(boolean asyncEnabled) { + this.asyncEnabled = asyncEnabled; + return this; + } + + public Builder withStaleDuration(Duration staleDuration) { + this.staleDuration = staleDuration; + return this; + } + + public Builder withExpiryBuffer(Duration expiryBuffer) { + this.expiryBuffer = expiryBuffer; + return this; + } + + public Builder withClockSupplier(ClockSupplier clockSupplier) { + this.clockSupplier = clockSupplier; + return this; + } + + public CachedTokenSource build() { + return new CachedTokenSource(this); + } + } + + /** + * Gets the current token, refreshing if necessary. If async refresh is enabled, may return a + * stale token while a refresh is in progress. + * + *

This method may throw an exception if the token cannot be refreshed, depending on the + * implementation of {@link #refresh()}. + * + * @return The current valid token + */ + public Token getToken() { + if (!asyncEnabled) { + return getTokenBlocking(); + } + return getTokenAsync(); + } + + /** + * Determine the state of the current token (fresh, stale, or expired). + * + * @return The token state + */ + protected TokenState getTokenState(Token t) { + if (t == null) { + return TokenState.EXPIRED; + } + Duration lifeTime = + Duration.between(LocalDateTime.now(clockSupplier.getClock()), t.getExpiry()); + if (lifeTime.compareTo(expiryBuffer) <= 0) { + return TokenState.EXPIRED; + } + if (lifeTime.compareTo(staleDuration) <= 0) { + return TokenState.STALE; + } + return TokenState.FRESH; + } + + /** + * Get the current token, blocking to refresh if expired. + * + *

This method may throw an exception if the token cannot be refreshed, depending on the + * implementation of {@link #refresh()}. + * + * @return The current valid token + */ + protected Token getTokenBlocking() { + // Use double-checked locking to minimize synchronization overhead on reads: + // 1. Check if the token is expired without locking. + // 2. If expired, synchronize and check again (another thread may have refreshed it). + // 3. If still expired, perform the refresh. + if (getTokenState(token) != TokenState.EXPIRED) { + return token; + } + synchronized (this) { + if (getTokenState(token) != TokenState.EXPIRED) { + return token; + } + lastRefreshSucceeded = false; + try { + token = tokenSource.getToken(); + } catch (Exception e) { + logger.error("Failed to refresh token synchronously", e); + throw e; + } + lastRefreshSucceeded = true; + return token; + } + } + + /** + * Get the current token, possibly triggering an async refresh if stale. If the token is expired, + * blocks to refresh. + * + *

This method may throw an exception if the token cannot be refreshed, depending on the + * implementation of {@link #refresh()}. + * + * @return The current valid or stale token + */ + protected Token getTokenAsync() { + Token currentToken = token; + + switch (getTokenState(currentToken)) { + case FRESH: + return currentToken; + case STALE: + triggerAsyncRefresh(); + return currentToken; + case EXPIRED: + return getTokenBlocking(); + default: + throw new IllegalStateException("Invalid token state."); + } + } + + /** + * Trigger an asynchronous refresh of the token if not already in progress and last refresh + * succeeded. + */ + protected synchronized void triggerAsyncRefresh() { + // Check token state to avoid triggering a refresh if another thread has already refreshed it + if (!refreshInProgress && lastRefreshSucceeded && getTokenState(token) != TokenState.FRESH) { + refreshInProgress = true; + CompletableFuture.runAsync( + () -> { + try { + // Attempt to refresh the token in the background + Token newToken = tokenSource.getToken(); + synchronized (this) { + token = newToken; + refreshInProgress = false; + } + } catch (Exception e) { + synchronized (this) { + lastRefreshSucceeded = false; + refreshInProgress = false; + logger.error("Async token refresh failed", e); + } + } + }); + } + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java index 1c4b7d6de..8cee3ef29 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java @@ -13,7 +13,7 @@ * support all OAuth endpoints, authentication parameters can be passed in the request body or in * the Authorization header. */ -public class ClientCredentials extends RefreshableTokenSource { +public class ClientCredentials implements TokenSource { public static class Builder { private String clientId; private String clientSecret; @@ -97,7 +97,7 @@ private ClientCredentials( } @Override - protected Token refresh() { + public Token getToken() { Map params = new HashMap<>(); params.put("grant_type", "client_credentials"); if (scopes != null) { @@ -106,6 +106,7 @@ protected Token refresh() { if (endpointParamsSupplier != null) { params.putAll(endpointParamsSupplier.get()); } - return retrieveToken(hc, clientId, clientSecret, tokenUrl, params, new HashMap<>(), position); + return TokenEndpointClient.retrieveToken( + hc, clientId, clientSecret, tokenUrl, params, new HashMap<>(), position); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java index 77045df97..68dc6f176 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java @@ -313,7 +313,7 @@ public SessionCredentials exchange(String code, String state) { headers.put("Origin", this.redirectUrl); } Token token = - RefreshableTokenSource.retrieveToken( + TokenEndpointClient.retrieveToken( this.hc, this.clientId, this.clientSecret, diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index f16ae2aed..ac76234a0 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -14,7 +14,7 @@ * Implementation of TokenSource that handles OAuth token exchange for Databricks authentication. * This class manages the OAuth token exchange flow using ID tokens to obtain access tokens. */ -public class DatabricksOAuthTokenSource extends RefreshableTokenSource { +public class DatabricksOAuthTokenSource implements TokenSource { private static final Logger LOG = LoggerFactory.getLogger(DatabricksOAuthTokenSource.class); /** OAuth client ID used for token exchange. */ @@ -128,7 +128,7 @@ public DatabricksOAuthTokenSource build() { * @throws NullPointerException when any of the required parameters are null. */ @Override - public Token refresh() { + public Token getToken() { Objects.requireNonNull(clientId, "ClientID cannot be null"); Objects.requireNonNull(host, "Host cannot be null"); Objects.requireNonNull(endpoints, "Endpoints cannot be null"); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java index 3ca75c441..7bcddacdd 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java @@ -13,7 +13,7 @@ * Represents a token source that exchanges a control plane token for an endpoint-specific dataplane * token. It utilizes an underlying {@link TokenSource} to obtain the initial control plane token. */ -public class EndpointTokenSource extends RefreshableTokenSource { +public class EndpointTokenSource implements TokenSource { private static final Logger LOG = LoggerFactory.getLogger(EndpointTokenSource.class); private static final String JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"; private static final String GRANT_TYPE_PARAM = "grant_type"; @@ -67,7 +67,7 @@ public EndpointTokenSource( * @throws NullPointerException if any of the parameters are null. */ @Override - protected Token refresh() { + public Token getToken() { Token cpToken = cpTokenSource.getToken(); Map params = new HashMap<>(); params.put(GRANT_TYPE_PARAM, JWT_GRANT_TYPE); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java index 7bae60022..494ad69cd 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java @@ -79,7 +79,7 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { .build(); LOGGER.debug("Using cached token, will immediately refresh"); - cachedCreds.token = cachedCreds.refresh(); + cachedCreds.token = cachedCreds.getToken(); return cachedCreds.configure(config); } catch (Exception e) { // If token refresh fails, log and continue to browser auth diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java index 719544ebf..1739c5c6c 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java @@ -15,7 +15,7 @@ * protocol. It communicates with an OAuth server to request access tokens using the client * credentials grant type instead of a client secret. */ -class OidcTokenSource extends RefreshableTokenSource { +class OidcTokenSource implements TokenSource { private final HttpClient httpClient; private final String tokenUrl; @@ -58,7 +58,8 @@ private static void putIfDefined( } } - protected Token refresh() { + @Override + public Token getToken() { Response rawResp; try { rawResp = httpClient.execute(new FormRequest(tokenUrl, params)); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index c62dccad7..c0cf080ce 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -1,22 +1,5 @@ package com.databricks.sdk.core.oauth; -import com.databricks.sdk.core.ApiClient; -import com.databricks.sdk.core.DatabricksException; -import com.databricks.sdk.core.http.FormRequest; -import com.databricks.sdk.core.http.HttpClient; -import com.databricks.sdk.core.http.Request; -import com.databricks.sdk.core.utils.ClockSupplier; -import com.databricks.sdk.core.utils.SystemClockSupplier; -import java.time.Duration; -import java.time.LocalDateTime; -import java.time.temporal.ChronoUnit; -import java.util.Base64; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import org.apache.http.HttpHeaders; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - /** * An OAuth TokenSource which can be refreshed. * @@ -24,272 +7,4 @@ * stale tokens will trigger a background refresh, while expired tokens will block until a new token * is fetched. */ -public abstract class RefreshableTokenSource implements TokenSource { - - /** - * Enum representing the state of the token. FRESH: Token is valid and not close to expiry. STALE: - * Token is valid but will expire soon - an async refresh will be triggered if enabled. EXPIRED: - * Token has expired and must be refreshed using a blocking call. - */ - private enum TokenState { - FRESH, - STALE, - EXPIRED - } - - private static final Logger logger = LoggerFactory.getLogger(RefreshableTokenSource.class); - // Default duration before expiry to consider a token as 'stale'. - private static final Duration DEFAULT_STALE_DURATION = Duration.ofMinutes(3); - - // The current OAuth token. May be null if not yet fetched. - protected volatile Token token; - // Whether asynchronous refresh is enabled. - private boolean asyncEnabled = false; - // Duration before expiry to consider a token as 'stale'. - private Duration staleDuration = DEFAULT_STALE_DURATION; - // Additional buffer before expiry to consider a token as expired. - private Duration expiryBuffer = Duration.ofSeconds(40); - // Whether a refresh is currently in progress (for async refresh). - private boolean refreshInProgress = false; - // Whether the last refresh attempt succeeded. - private boolean lastRefreshSucceeded = true; - // Clock supplier for current time, for testing purposes. - private ClockSupplier clockSupplier = new SystemClockSupplier(); - - /** Constructs a new {@code RefreshableTokenSource} with no initial token. */ - public RefreshableTokenSource() {} - - /** - * Constructor with initial token. - * - * @param token The initial token to use. - */ - public RefreshableTokenSource(Token token) { - this.token = token; - } - - /** - * Set the clock supplier for current time. - * - *

Experimental: This method may change or be removed in future releases. - * - * @param clockSupplier The clock supplier to use. - * @return this instance for chaining - */ - public RefreshableTokenSource withClockSupplier(ClockSupplier clockSupplier) { - this.clockSupplier = clockSupplier; - return this; - } - - /** - * Enable or disable asynchronous token refresh. - * - *

Experimental: This method may change or be removed in future releases. - * - * @param enabled true to enable async refresh, false to disable - * @return this instance for chaining - */ - public RefreshableTokenSource withAsyncRefresh(boolean enabled) { - this.asyncEnabled = enabled; - return this; - } - - /** - * Set the expiry buffer. If the token's lifetime is less than this buffer, it is considered - * expired. - * - *

Experimental: This method may change or be removed in future releases. - * - * @param buffer the expiry buffer duration - * @return this instance for chaining - */ - public RefreshableTokenSource withExpiryBuffer(Duration buffer) { - this.expiryBuffer = buffer; - return this; - } - - /** - * Refresh the OAuth token. Subclasses must implement this to define how the token is refreshed. - * - *

This method may throw an exception if the token cannot be refreshed. The specific exception - * type depends on the implementation. - * - * @return The newly refreshed Token. - */ - protected abstract Token refresh(); - - /** - * Gets the current token, refreshing if necessary. If async refresh is enabled, may return a - * stale token while a refresh is in progress. - * - *

This method may throw an exception if the token cannot be refreshed, depending on the - * implementation of {@link #refresh()}. - * - * @return The current valid token - */ - public Token getToken() { - if (!asyncEnabled) { - return getTokenBlocking(); - } - return getTokenAsync(); - } - - /** - * Determine the state of the current token (fresh, stale, or expired). - * - * @return The token state - */ - protected TokenState getTokenState(Token t) { - if (t == null) { - return TokenState.EXPIRED; - } - Duration lifeTime = - Duration.between(LocalDateTime.now(clockSupplier.getClock()), t.getExpiry()); - if (lifeTime.compareTo(expiryBuffer) <= 0) { - return TokenState.EXPIRED; - } - if (lifeTime.compareTo(staleDuration) <= 0) { - return TokenState.STALE; - } - return TokenState.FRESH; - } - - /** - * Get the current token, blocking to refresh if expired. - * - *

This method may throw an exception if the token cannot be refreshed, depending on the - * implementation of {@link #refresh()}. - * - * @return The current valid token - */ - protected Token getTokenBlocking() { - // Use double-checked locking to minimize synchronization overhead on reads: - // 1. Check if the token is expired without locking. - // 2. If expired, synchronize and check again (another thread may have refreshed it). - // 3. If still expired, perform the refresh. - if (getTokenState(token) != TokenState.EXPIRED) { - return token; - } - synchronized (this) { - if (getTokenState(token) != TokenState.EXPIRED) { - return token; - } - lastRefreshSucceeded = false; - try { - token = refresh(); - } catch (Exception e) { - logger.error("Failed to refresh token synchronously", e); - throw e; - } - lastRefreshSucceeded = true; - return token; - } - } - - /** - * Get the current token, possibly triggering an async refresh if stale. If the token is expired, - * blocks to refresh. - * - *

This method may throw an exception if the token cannot be refreshed, depending on the - * implementation of {@link #refresh()}. - * - * @return The current valid or stale token - */ - protected Token getTokenAsync() { - Token currentToken = token; - - switch (getTokenState(currentToken)) { - case FRESH: - return currentToken; - case STALE: - triggerAsyncRefresh(); - return currentToken; - case EXPIRED: - return getTokenBlocking(); - default: - throw new IllegalStateException("Invalid token state."); - } - } - - /** - * Trigger an asynchronous refresh of the token if not already in progress and last refresh - * succeeded. - */ - protected synchronized void triggerAsyncRefresh() { - // Check token state to avoid triggering a refresh if another thread has already refreshed it - if (!refreshInProgress && lastRefreshSucceeded && getTokenState(token) != TokenState.FRESH) { - refreshInProgress = true; - CompletableFuture.runAsync( - () -> { - try { - // Attempt to refresh the token in the background - Token newToken = refresh(); - synchronized (this) { - token = newToken; - refreshInProgress = false; - } - } catch (Exception e) { - synchronized (this) { - lastRefreshSucceeded = false; - refreshInProgress = false; - logger.error("Async token refresh failed", e); - } - } - }); - } - } - - /** - * Helper method implementing OAuth token refresh. - * - * @param hc The HTTP client to use for the request. - * @param clientId The client ID to authenticate with. - * @param clientSecret The client secret to authenticate with. - * @param tokenUrl The authorization URL for fetching tokens. - * @param params Additional request parameters. - * @param headers Additional headers. - * @param position The position of the authentication parameters in the request. - * @return The newly fetched Token. - * @throws DatabricksException if the refresh fails - * @throws IllegalArgumentException if the OAuth response contains an error - */ - protected static Token retrieveToken( - HttpClient hc, - String clientId, - String clientSecret, - String tokenUrl, - Map params, - Map headers, - AuthParameterPosition position) { - switch (position) { - case BODY: - if (clientId != null) { - params.put("client_id", clientId); - } - if (clientSecret != null) { - params.put("client_secret", clientSecret); - } - break; - case HEADER: - String authHeaderValue = - "Basic " - + Base64.getEncoder().encodeToString((clientId + ":" + clientSecret).getBytes()); - headers.put(HttpHeaders.AUTHORIZATION, authHeaderValue); - break; - } - headers.put("Content-Type", "application/x-www-form-urlencoded"); - Request req = new Request("POST", tokenUrl, FormRequest.wrapValuesInList(params)); - req.withHeaders(headers); - try { - ApiClient apiClient = new ApiClient.Builder().withHttpClient(hc).build(); - OAuthResponse resp = apiClient.execute(req, OAuthResponse.class); - if (resp.getErrorCode() != null) { - throw new IllegalArgumentException(resp.getErrorCode() + ": " + resp.getErrorSummary()); - } - LocalDateTime expiry = LocalDateTime.now().plus(resp.getExpiresIn(), ChronoUnit.SECONDS); - return new Token(resp.getAccessToken(), resp.getTokenType(), resp.getRefreshToken(), expiry); - } catch (Exception e) { - throw new DatabricksException("Failed to refresh credentials: " + e.getMessage(), e); - } - } -} +public abstract class RefreshableTokenSource implements TokenSource {} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java index 4d2d512e3..3bde86994 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java @@ -17,8 +17,7 @@ * requests to an API, and a long-lived refresh token, which can be used to fetch new access tokens. * Calling refresh() uses the refresh token to retrieve a new access token to authenticate to APIs. */ -public class SessionCredentials extends RefreshableTokenSource - implements CredentialsProvider, Serializable { +public class SessionCredentials implements TokenSource, CredentialsProvider, Serializable { private static final long serialVersionUID = 3083941540130596650L; private static final Logger LOGGER = LoggerFactory.getLogger(SessionCredentials.class); @@ -87,9 +86,10 @@ public SessionCredentials build() { private final String clientId; private final String clientSecret; private final TokenCache tokenCache; + protected Token token; private SessionCredentials(Builder b) { - super(b.token); + this.token = b.token; this.hc = b.hc; this.tokenUrl = b.tokenUrl; this.redirectUrl = b.redirectUrl; @@ -99,7 +99,7 @@ private SessionCredentials(Builder b) { } @Override - protected Token refresh() { + public Token getToken() { if (this.token == null) { throw new DatabricksException("oauth2: token is not set"); } @@ -118,7 +118,7 @@ protected Token refresh() { headers.put("Origin", redirectUrl); } Token newToken = - retrieveToken( + TokenEndpointClient.retrieveToken( hc, clientId, clientSecret, tokenUrl, params, headers, AuthParameterPosition.BODY); // Save the refreshed token directly to cache diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java index 69883dd24..ace0c9314 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java @@ -1,5 +1,6 @@ package com.databricks.sdk.core.oauth; +import com.databricks.sdk.core.ApiClient; import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.http.FormRequest; import com.databricks.sdk.core.http.HttpClient; @@ -88,4 +89,63 @@ public static OAuthResponse requestToken( LOG.debug("Successfully obtained token response from {}", tokenEndpointUrl); return response; } + + /** + * Helper method implementing OAuth token refresh. + * + * @param hc The HTTP client to use for the request. + * @param clientId The client ID to authenticate with. + * @param clientSecret The client secret to authenticate with. + * @param tokenUrl The authorization URL for fetching tokens. + * @param params Additional request parameters. + * @param headers Additional headers. + * @param position The position of the authentication parameters in the request. + * @return The newly fetched Token. + * @throws DatabricksException if the refresh fails + * @throws IllegalArgumentException if the OAuth response contains an error + */ + public static Token retrieveToken( + HttpClient hc, + String clientId, + String clientSecret, + String tokenUrl, + Map params, + Map headers, + AuthParameterPosition position) { + switch (position) { + case BODY: + if (clientId != null) { + params.put("client_id", clientId); + } + if (clientSecret != null) { + params.put("client_secret", clientSecret); + } + break; + case HEADER: + String authHeaderValue = + "Basic " + + java.util.Base64.getEncoder() + .encodeToString((clientId + ":" + clientSecret).getBytes()); + headers.put(org.apache.http.HttpHeaders.AUTHORIZATION, authHeaderValue); + break; + } + headers.put("Content-Type", "application/x-www-form-urlencoded"); + com.databricks.sdk.core.http.Request req = + new com.databricks.sdk.core.http.Request( + "POST", tokenUrl, com.databricks.sdk.core.http.FormRequest.wrapValuesInList(params)); + req.withHeaders(headers); + try { + ApiClient apiClient = new ApiClient.Builder().withHttpClient(hc).build(); + OAuthResponse resp = apiClient.execute(req, OAuthResponse.class); + if (resp.getErrorCode() != null) { + throw new IllegalArgumentException(resp.getErrorCode() + ": " + resp.getErrorSummary()); + } + java.time.LocalDateTime expiry = + java.time.LocalDateTime.now() + .plus(resp.getExpiresIn(), java.time.temporal.ChronoUnit.SECONDS); + return new Token(resp.getAccessToken(), resp.getTokenType(), resp.getRefreshToken(), expiry); + } catch (Exception e) { + throw new DatabricksException("Failed to refresh credentials: " + e.getMessage(), e); + } + } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java index 96dc116c2..09cea6e86 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java @@ -77,7 +77,7 @@ public static Map addWorkspaceResourceId( } public static Map addSpManagementToken( - RefreshableTokenSource tokenSource, Map headers) { + TokenSource tokenSource, Map headers) { headers.put("X-Databricks-Azure-SP-Management-Token", tokenSource.getToken().getAccessToken()); return headers; } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/CachedTokenSourceTest.java similarity index 64% rename from databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java rename to databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/CachedTokenSourceTest.java index c33e7fda5..ba9a41753 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/CachedTokenSourceTest.java @@ -11,7 +11,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -public class RefreshableTokenSourceTest { +public class CachedTokenSourceTest { private static final String TOKEN_TYPE = "Bearer"; private static final String INITIAL_TOKEN = "initial-token"; private static final String REFRESH_TOKEN = "refreshed-token"; @@ -46,10 +46,10 @@ void testAsyncRefreshParametrized( new Token(REFRESH_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(10)); CountDownLatch refreshCalled = new CountDownLatch(1); - RefreshableTokenSource source = - new RefreshableTokenSource(initialToken) { + TokenSource tokenSource = + new TokenSource() { @Override - protected Token refresh() { + public Token getToken() { refreshCalled.countDown(); try { Thread.sleep(500); @@ -58,7 +58,13 @@ protected Token refresh() { } return refreshedToken; } - }.withAsyncRefresh(asyncEnabled); + }; + + CachedTokenSource source = + new CachedTokenSource.Builder(tokenSource) + .withAsyncEnabled(asyncEnabled) + .withToken(initialToken) + .build(); Token token = source.getToken(); @@ -67,74 +73,56 @@ protected Token refresh() { assertEquals(expectedToken, token.getAccessToken(), "Token value did not match expected"); } - /** - * This test verifies that if an asynchronous token refresh fails, the next refresh attempt is - * forced to be synchronous. It ensures that after an async failure, the system does not - * repeatedly attempt async refreshes while the token is stale, and only performs a synchronous - * refresh when the token is expired. After a successful sync refresh, async refreshes resume as - * normal. - */ @Test void testAsyncRefreshFailureFallback() throws Exception { - Token staleToken = - new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(2)); - - class TestSource extends RefreshableTokenSource { + class MutableTokenSource implements TokenSource { int refreshCallCount = 0; boolean isFirstRefresh = true; - - TestSource(Token token) { - super(token); - } + Token token = new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(2)); @Override - protected Token refresh() { + public Token getToken() { refreshCallCount++; if (isFirstRefresh) { isFirstRefresh = false; throw new RuntimeException("Simulated async failure"); } - return new Token(REFRESH_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(10)); + token = new Token(REFRESH_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(10)); + return token; } } - TestSource source = new TestSource(staleToken); - source.withAsyncRefresh(true); + MutableTokenSource mutableTokenSource = new MutableTokenSource(); + + CachedTokenSource source = + new CachedTokenSource.Builder(mutableTokenSource) + .withAsyncEnabled(true) + .withToken(mutableTokenSource.token) + .build(); // First call triggers async refresh, which fails source.getToken(); - Thread.sleep(300); - assertEquals( - 1, source.refreshCallCount, "refresh() should have been called once (async, failed)"); + assertEquals(1, mutableTokenSource.refreshCallCount); // Token is still stale, so next call should NOT trigger another refresh since the last refresh // failed source.getToken(); - Thread.sleep(200); - assertEquals( - 1, - source.refreshCallCount, - "refresh() should NOT be called again while stale after async failure"); + assertEquals(1, mutableTokenSource.refreshCallCount); // Advance the clock so the token is now expired - source.token = new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().minusMinutes(1)); + mutableTokenSource.token = + new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().minusMinutes(1)); // Now getToken() should call refresh synchronously and return the refreshed token - Token token = source.getToken(); - assertEquals( - REFRESH_TOKEN, - token.getAccessToken(), - "Should return the refreshed token after sync refresh"); - assertEquals( - 2, source.refreshCallCount, "refresh() should have been called synchronously after expiry"); + Token token; + token = source.getToken(); + assertEquals(REFRESH_TOKEN, token.getAccessToken()); + assertEquals(2, mutableTokenSource.refreshCallCount); // Make the token stale again and trigger async refresh since the last sync refresh succeeded - source.token = new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(2)); + mutableTokenSource.token = + new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(2)); source.getToken(); - Thread.sleep(300); - assertEquals( - 3, - source.refreshCallCount, - "refresh() should have been called again asynchronously after making token stale"); + assertEquals(3, mutableTokenSource.refreshCallCount); } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java index 1714b731c..e8f6faa2f 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java @@ -186,7 +186,7 @@ void clientCredentials() throws IOException { .withClientSecret("abc") .withTokenUrl("https://tokenUrl") .build(); - Token token = clientCredentials.refresh(); + Token token = clientCredentials.getToken(); assertEquals("accessTokenFromServer", token.getAccessToken()); assertEquals("refreshTokenFromServer", token.getRefreshToken()); } @@ -212,7 +212,7 @@ void sessionCredentials() throws IOException { "originalRefreshToken", LocalDateTime.MAX)) .build(); - Token token = sessionCredentials.refresh(); + Token token = sessionCredentials.getToken(); // We check that we are actually getting the token from server response (that is defined // above) rather than what was given while creating session credentials From 0ce06d8a0f6c096e3e2a88a304c073ff70d44214 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 6 Jun 2025 08:58:29 +0000 Subject: [PATCH 20/49] Revert "Save progress" This reverts commit 8705ff53e7017d2212acb839edc2efdec6683410. --- .../databricks/sdk/core/CliTokenSource.java | 6 +- .../databricks/sdk/core/DatabricksConfig.java | 8 +- ...reServicePrincipalCredentialsProvider.java | 6 +- .../sdk/core/oauth/CachedTokenSource.java | 220 -------------- .../sdk/core/oauth/ClientCredentials.java | 7 +- .../databricks/sdk/core/oauth/Consent.java | 2 +- .../oauth/DatabricksOAuthTokenSource.java | 4 +- .../sdk/core/oauth/EndpointTokenSource.java | 4 +- .../ExternalBrowserCredentialsProvider.java | 2 +- .../sdk/core/oauth/OidcTokenSource.java | 5 +- .../core/oauth/RefreshableTokenSource.java | 287 +++++++++++++++++- .../sdk/core/oauth/SessionCredentials.java | 10 +- .../sdk/core/oauth/TokenEndpointClient.java | 60 ---- .../databricks/sdk/core/utils/AzureUtils.java | 2 +- ...xternalBrowserCredentialsProviderTest.java | 4 +- ...t.java => RefreshableTokenSourceTest.java} | 80 ++--- 16 files changed, 359 insertions(+), 348 deletions(-) delete mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java rename databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/{CachedTokenSourceTest.java => RefreshableTokenSourceTest.java} (64%) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 2eb5c13f7..8a7328904 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -1,7 +1,7 @@ package com.databricks.sdk.core; +import com.databricks.sdk.core.oauth.RefreshableTokenSource; import com.databricks.sdk.core.oauth.Token; -import com.databricks.sdk.core.oauth.TokenSource; import com.databricks.sdk.core.utils.Environment; import com.databricks.sdk.core.utils.OSUtils; import com.fasterxml.jackson.databind.JsonNode; @@ -15,7 +15,7 @@ import java.util.List; import org.apache.commons.io.IOUtils; -public class CliTokenSource implements TokenSource { +public class CliTokenSource extends RefreshableTokenSource { private List cmd; private String tokenTypeField; private String accessTokenField; @@ -64,7 +64,7 @@ private String getProcessStream(InputStream stream) throws IOException { } @Override - public Token getToken() { + protected Token refresh() { try { ProcessBuilder processBuilder = new ProcessBuilder(cmd); processBuilder.environment().putAll(env.getEnv()); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java index df16ebae3..de6548982 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java @@ -410,17 +410,13 @@ public DatabricksConfig setAzureUseMsi(boolean azureUseMsi) { return this; } - /** - * @deprecated Use {@link #getAzureUseMsi()} instead. - */ + /** @deprecated Use {@link #getAzureUseMsi()} instead. */ @Deprecated() public boolean getAzureUseMSI() { return azureUseMsi; } - /** - * @deprecated Use {@link #setAzureUseMsi(boolean)} instead. - */ + /** @deprecated Use {@link #setAzureUseMsi(boolean)} instead. */ @Deprecated public DatabricksConfig setAzureUseMSI(boolean azureUseMsi) { this.azureUseMsi = azureUseMsi; diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java index d10fefd94..c7c7bb672 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java @@ -28,8 +28,8 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { } AzureUtils.ensureHostPresent( config, mapper, AzureServicePrincipalCredentialsProvider::tokenSourceFor); - TokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); - TokenSource cloud = + RefreshableTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); + RefreshableTokenSource cloud = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); return OAuthHeaderFactory.fromSuppliers( @@ -55,7 +55,7 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { * @return A RefreshableTokenSource instance capable of fetching OAuth tokens for the specified * Azure resource. */ - private static TokenSource tokenSourceFor(DatabricksConfig config, String resource) { + private static RefreshableTokenSource tokenSourceFor(DatabricksConfig config, String resource) { String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint(); String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token"; Map endpointParams = new HashMap<>(); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java deleted file mode 100644 index 72aa38f74..000000000 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java +++ /dev/null @@ -1,220 +0,0 @@ -package com.databricks.sdk.core.oauth; - -import com.databricks.sdk.core.utils.ClockSupplier; -import com.databricks.sdk.core.utils.SystemClockSupplier; -import java.time.Duration; -import java.time.LocalDateTime; -import java.util.concurrent.CompletableFuture; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * An OAuth TokenSource which can be refreshed. - * - *

This class supports both synchronous and asynchronous token refresh. When async is enabled, - * stale tokens will trigger a background refresh, while expired tokens will block until a new token - * is fetched. - */ -public class CachedTokenSource implements TokenSource { - - /** - * Enum representing the state of the token. FRESH: Token is valid and not close to expiry. STALE: - * Token is valid but will expire soon - an async refresh will be triggered if enabled. EXPIRED: - * Token has expired and must be refreshed using a blocking call. - */ - private enum TokenState { - FRESH, - STALE, - EXPIRED - } - - private static final Logger logger = LoggerFactory.getLogger(RefreshableTokenSource.class); - // Default duration before expiry to consider a token as 'stale'. - private static final Duration DEFAULT_STALE_DURATION = Duration.ofMinutes(3); - private static final Duration DEFAULT_EXPIRY_BUFFER = Duration.ofSeconds(40); - - private final TokenSource tokenSource; - private final boolean asyncEnabled; - private final Duration staleDuration; - private final Duration expiryBuffer; - private final ClockSupplier clockSupplier; - - // The current OAuth token. May be null if not yet fetched. - private volatile Token token; - // Whether a refresh is currently in progress (for async refresh). - private volatile boolean refreshInProgress = false; - // Whether the last refresh attempt succeeded. - private volatile boolean lastRefreshSucceeded = true; - - private CachedTokenSource(Builder builder) { - this.tokenSource = builder.tokenSource; - this.asyncEnabled = builder.asyncEnabled; - this.staleDuration = builder.staleDuration; - this.expiryBuffer = builder.expiryBuffer; - this.clockSupplier = builder.clockSupplier; - this.token = builder.token; - } - - public static class Builder { - private final TokenSource tokenSource; - private Token token; - private boolean asyncEnabled = false; - private Duration staleDuration = DEFAULT_STALE_DURATION; - private Duration expiryBuffer = DEFAULT_EXPIRY_BUFFER; - private ClockSupplier clockSupplier = new SystemClockSupplier(); - - public Builder(TokenSource tokenSource) { - this.tokenSource = tokenSource; - } - - public Builder withToken(Token token) { - this.token = token; - return this; - } - - public Builder withAsyncEnabled(boolean asyncEnabled) { - this.asyncEnabled = asyncEnabled; - return this; - } - - public Builder withStaleDuration(Duration staleDuration) { - this.staleDuration = staleDuration; - return this; - } - - public Builder withExpiryBuffer(Duration expiryBuffer) { - this.expiryBuffer = expiryBuffer; - return this; - } - - public Builder withClockSupplier(ClockSupplier clockSupplier) { - this.clockSupplier = clockSupplier; - return this; - } - - public CachedTokenSource build() { - return new CachedTokenSource(this); - } - } - - /** - * Gets the current token, refreshing if necessary. If async refresh is enabled, may return a - * stale token while a refresh is in progress. - * - *

This method may throw an exception if the token cannot be refreshed, depending on the - * implementation of {@link #refresh()}. - * - * @return The current valid token - */ - public Token getToken() { - if (!asyncEnabled) { - return getTokenBlocking(); - } - return getTokenAsync(); - } - - /** - * Determine the state of the current token (fresh, stale, or expired). - * - * @return The token state - */ - protected TokenState getTokenState(Token t) { - if (t == null) { - return TokenState.EXPIRED; - } - Duration lifeTime = - Duration.between(LocalDateTime.now(clockSupplier.getClock()), t.getExpiry()); - if (lifeTime.compareTo(expiryBuffer) <= 0) { - return TokenState.EXPIRED; - } - if (lifeTime.compareTo(staleDuration) <= 0) { - return TokenState.STALE; - } - return TokenState.FRESH; - } - - /** - * Get the current token, blocking to refresh if expired. - * - *

This method may throw an exception if the token cannot be refreshed, depending on the - * implementation of {@link #refresh()}. - * - * @return The current valid token - */ - protected Token getTokenBlocking() { - // Use double-checked locking to minimize synchronization overhead on reads: - // 1. Check if the token is expired without locking. - // 2. If expired, synchronize and check again (another thread may have refreshed it). - // 3. If still expired, perform the refresh. - if (getTokenState(token) != TokenState.EXPIRED) { - return token; - } - synchronized (this) { - if (getTokenState(token) != TokenState.EXPIRED) { - return token; - } - lastRefreshSucceeded = false; - try { - token = tokenSource.getToken(); - } catch (Exception e) { - logger.error("Failed to refresh token synchronously", e); - throw e; - } - lastRefreshSucceeded = true; - return token; - } - } - - /** - * Get the current token, possibly triggering an async refresh if stale. If the token is expired, - * blocks to refresh. - * - *

This method may throw an exception if the token cannot be refreshed, depending on the - * implementation of {@link #refresh()}. - * - * @return The current valid or stale token - */ - protected Token getTokenAsync() { - Token currentToken = token; - - switch (getTokenState(currentToken)) { - case FRESH: - return currentToken; - case STALE: - triggerAsyncRefresh(); - return currentToken; - case EXPIRED: - return getTokenBlocking(); - default: - throw new IllegalStateException("Invalid token state."); - } - } - - /** - * Trigger an asynchronous refresh of the token if not already in progress and last refresh - * succeeded. - */ - protected synchronized void triggerAsyncRefresh() { - // Check token state to avoid triggering a refresh if another thread has already refreshed it - if (!refreshInProgress && lastRefreshSucceeded && getTokenState(token) != TokenState.FRESH) { - refreshInProgress = true; - CompletableFuture.runAsync( - () -> { - try { - // Attempt to refresh the token in the background - Token newToken = tokenSource.getToken(); - synchronized (this) { - token = newToken; - refreshInProgress = false; - } - } catch (Exception e) { - synchronized (this) { - lastRefreshSucceeded = false; - refreshInProgress = false; - logger.error("Async token refresh failed", e); - } - } - }); - } - } -} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java index 8cee3ef29..1c4b7d6de 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java @@ -13,7 +13,7 @@ * support all OAuth endpoints, authentication parameters can be passed in the request body or in * the Authorization header. */ -public class ClientCredentials implements TokenSource { +public class ClientCredentials extends RefreshableTokenSource { public static class Builder { private String clientId; private String clientSecret; @@ -97,7 +97,7 @@ private ClientCredentials( } @Override - public Token getToken() { + protected Token refresh() { Map params = new HashMap<>(); params.put("grant_type", "client_credentials"); if (scopes != null) { @@ -106,7 +106,6 @@ public Token getToken() { if (endpointParamsSupplier != null) { params.putAll(endpointParamsSupplier.get()); } - return TokenEndpointClient.retrieveToken( - hc, clientId, clientSecret, tokenUrl, params, new HashMap<>(), position); + return retrieveToken(hc, clientId, clientSecret, tokenUrl, params, new HashMap<>(), position); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java index 68dc6f176..77045df97 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java @@ -313,7 +313,7 @@ public SessionCredentials exchange(String code, String state) { headers.put("Origin", this.redirectUrl); } Token token = - TokenEndpointClient.retrieveToken( + RefreshableTokenSource.retrieveToken( this.hc, this.clientId, this.clientSecret, diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index ac76234a0..f16ae2aed 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -14,7 +14,7 @@ * Implementation of TokenSource that handles OAuth token exchange for Databricks authentication. * This class manages the OAuth token exchange flow using ID tokens to obtain access tokens. */ -public class DatabricksOAuthTokenSource implements TokenSource { +public class DatabricksOAuthTokenSource extends RefreshableTokenSource { private static final Logger LOG = LoggerFactory.getLogger(DatabricksOAuthTokenSource.class); /** OAuth client ID used for token exchange. */ @@ -128,7 +128,7 @@ public DatabricksOAuthTokenSource build() { * @throws NullPointerException when any of the required parameters are null. */ @Override - public Token getToken() { + public Token refresh() { Objects.requireNonNull(clientId, "ClientID cannot be null"); Objects.requireNonNull(host, "Host cannot be null"); Objects.requireNonNull(endpoints, "Endpoints cannot be null"); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java index 7bcddacdd..3ca75c441 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java @@ -13,7 +13,7 @@ * Represents a token source that exchanges a control plane token for an endpoint-specific dataplane * token. It utilizes an underlying {@link TokenSource} to obtain the initial control plane token. */ -public class EndpointTokenSource implements TokenSource { +public class EndpointTokenSource extends RefreshableTokenSource { private static final Logger LOG = LoggerFactory.getLogger(EndpointTokenSource.class); private static final String JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"; private static final String GRANT_TYPE_PARAM = "grant_type"; @@ -67,7 +67,7 @@ public EndpointTokenSource( * @throws NullPointerException if any of the parameters are null. */ @Override - public Token getToken() { + protected Token refresh() { Token cpToken = cpTokenSource.getToken(); Map params = new HashMap<>(); params.put(GRANT_TYPE_PARAM, JWT_GRANT_TYPE); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java index 494ad69cd..7bae60022 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java @@ -79,7 +79,7 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { .build(); LOGGER.debug("Using cached token, will immediately refresh"); - cachedCreds.token = cachedCreds.getToken(); + cachedCreds.token = cachedCreds.refresh(); return cachedCreds.configure(config); } catch (Exception e) { // If token refresh fails, log and continue to browser auth diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java index 1739c5c6c..719544ebf 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java @@ -15,7 +15,7 @@ * protocol. It communicates with an OAuth server to request access tokens using the client * credentials grant type instead of a client secret. */ -class OidcTokenSource implements TokenSource { +class OidcTokenSource extends RefreshableTokenSource { private final HttpClient httpClient; private final String tokenUrl; @@ -58,8 +58,7 @@ private static void putIfDefined( } } - @Override - public Token getToken() { + protected Token refresh() { Response rawResp; try { rawResp = httpClient.execute(new FormRequest(tokenUrl, params)); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index c0cf080ce..c62dccad7 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -1,5 +1,22 @@ package com.databricks.sdk.core.oauth; +import com.databricks.sdk.core.ApiClient; +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.FormRequest; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Request; +import com.databricks.sdk.core.utils.ClockSupplier; +import com.databricks.sdk.core.utils.SystemClockSupplier; +import java.time.Duration; +import java.time.LocalDateTime; +import java.time.temporal.ChronoUnit; +import java.util.Base64; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.apache.http.HttpHeaders; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + /** * An OAuth TokenSource which can be refreshed. * @@ -7,4 +24,272 @@ * stale tokens will trigger a background refresh, while expired tokens will block until a new token * is fetched. */ -public abstract class RefreshableTokenSource implements TokenSource {} +public abstract class RefreshableTokenSource implements TokenSource { + + /** + * Enum representing the state of the token. FRESH: Token is valid and not close to expiry. STALE: + * Token is valid but will expire soon - an async refresh will be triggered if enabled. EXPIRED: + * Token has expired and must be refreshed using a blocking call. + */ + private enum TokenState { + FRESH, + STALE, + EXPIRED + } + + private static final Logger logger = LoggerFactory.getLogger(RefreshableTokenSource.class); + // Default duration before expiry to consider a token as 'stale'. + private static final Duration DEFAULT_STALE_DURATION = Duration.ofMinutes(3); + + // The current OAuth token. May be null if not yet fetched. + protected volatile Token token; + // Whether asynchronous refresh is enabled. + private boolean asyncEnabled = false; + // Duration before expiry to consider a token as 'stale'. + private Duration staleDuration = DEFAULT_STALE_DURATION; + // Additional buffer before expiry to consider a token as expired. + private Duration expiryBuffer = Duration.ofSeconds(40); + // Whether a refresh is currently in progress (for async refresh). + private boolean refreshInProgress = false; + // Whether the last refresh attempt succeeded. + private boolean lastRefreshSucceeded = true; + // Clock supplier for current time, for testing purposes. + private ClockSupplier clockSupplier = new SystemClockSupplier(); + + /** Constructs a new {@code RefreshableTokenSource} with no initial token. */ + public RefreshableTokenSource() {} + + /** + * Constructor with initial token. + * + * @param token The initial token to use. + */ + public RefreshableTokenSource(Token token) { + this.token = token; + } + + /** + * Set the clock supplier for current time. + * + *

Experimental: This method may change or be removed in future releases. + * + * @param clockSupplier The clock supplier to use. + * @return this instance for chaining + */ + public RefreshableTokenSource withClockSupplier(ClockSupplier clockSupplier) { + this.clockSupplier = clockSupplier; + return this; + } + + /** + * Enable or disable asynchronous token refresh. + * + *

Experimental: This method may change or be removed in future releases. + * + * @param enabled true to enable async refresh, false to disable + * @return this instance for chaining + */ + public RefreshableTokenSource withAsyncRefresh(boolean enabled) { + this.asyncEnabled = enabled; + return this; + } + + /** + * Set the expiry buffer. If the token's lifetime is less than this buffer, it is considered + * expired. + * + *

Experimental: This method may change or be removed in future releases. + * + * @param buffer the expiry buffer duration + * @return this instance for chaining + */ + public RefreshableTokenSource withExpiryBuffer(Duration buffer) { + this.expiryBuffer = buffer; + return this; + } + + /** + * Refresh the OAuth token. Subclasses must implement this to define how the token is refreshed. + * + *

This method may throw an exception if the token cannot be refreshed. The specific exception + * type depends on the implementation. + * + * @return The newly refreshed Token. + */ + protected abstract Token refresh(); + + /** + * Gets the current token, refreshing if necessary. If async refresh is enabled, may return a + * stale token while a refresh is in progress. + * + *

This method may throw an exception if the token cannot be refreshed, depending on the + * implementation of {@link #refresh()}. + * + * @return The current valid token + */ + public Token getToken() { + if (!asyncEnabled) { + return getTokenBlocking(); + } + return getTokenAsync(); + } + + /** + * Determine the state of the current token (fresh, stale, or expired). + * + * @return The token state + */ + protected TokenState getTokenState(Token t) { + if (t == null) { + return TokenState.EXPIRED; + } + Duration lifeTime = + Duration.between(LocalDateTime.now(clockSupplier.getClock()), t.getExpiry()); + if (lifeTime.compareTo(expiryBuffer) <= 0) { + return TokenState.EXPIRED; + } + if (lifeTime.compareTo(staleDuration) <= 0) { + return TokenState.STALE; + } + return TokenState.FRESH; + } + + /** + * Get the current token, blocking to refresh if expired. + * + *

This method may throw an exception if the token cannot be refreshed, depending on the + * implementation of {@link #refresh()}. + * + * @return The current valid token + */ + protected Token getTokenBlocking() { + // Use double-checked locking to minimize synchronization overhead on reads: + // 1. Check if the token is expired without locking. + // 2. If expired, synchronize and check again (another thread may have refreshed it). + // 3. If still expired, perform the refresh. + if (getTokenState(token) != TokenState.EXPIRED) { + return token; + } + synchronized (this) { + if (getTokenState(token) != TokenState.EXPIRED) { + return token; + } + lastRefreshSucceeded = false; + try { + token = refresh(); + } catch (Exception e) { + logger.error("Failed to refresh token synchronously", e); + throw e; + } + lastRefreshSucceeded = true; + return token; + } + } + + /** + * Get the current token, possibly triggering an async refresh if stale. If the token is expired, + * blocks to refresh. + * + *

This method may throw an exception if the token cannot be refreshed, depending on the + * implementation of {@link #refresh()}. + * + * @return The current valid or stale token + */ + protected Token getTokenAsync() { + Token currentToken = token; + + switch (getTokenState(currentToken)) { + case FRESH: + return currentToken; + case STALE: + triggerAsyncRefresh(); + return currentToken; + case EXPIRED: + return getTokenBlocking(); + default: + throw new IllegalStateException("Invalid token state."); + } + } + + /** + * Trigger an asynchronous refresh of the token if not already in progress and last refresh + * succeeded. + */ + protected synchronized void triggerAsyncRefresh() { + // Check token state to avoid triggering a refresh if another thread has already refreshed it + if (!refreshInProgress && lastRefreshSucceeded && getTokenState(token) != TokenState.FRESH) { + refreshInProgress = true; + CompletableFuture.runAsync( + () -> { + try { + // Attempt to refresh the token in the background + Token newToken = refresh(); + synchronized (this) { + token = newToken; + refreshInProgress = false; + } + } catch (Exception e) { + synchronized (this) { + lastRefreshSucceeded = false; + refreshInProgress = false; + logger.error("Async token refresh failed", e); + } + } + }); + } + } + + /** + * Helper method implementing OAuth token refresh. + * + * @param hc The HTTP client to use for the request. + * @param clientId The client ID to authenticate with. + * @param clientSecret The client secret to authenticate with. + * @param tokenUrl The authorization URL for fetching tokens. + * @param params Additional request parameters. + * @param headers Additional headers. + * @param position The position of the authentication parameters in the request. + * @return The newly fetched Token. + * @throws DatabricksException if the refresh fails + * @throws IllegalArgumentException if the OAuth response contains an error + */ + protected static Token retrieveToken( + HttpClient hc, + String clientId, + String clientSecret, + String tokenUrl, + Map params, + Map headers, + AuthParameterPosition position) { + switch (position) { + case BODY: + if (clientId != null) { + params.put("client_id", clientId); + } + if (clientSecret != null) { + params.put("client_secret", clientSecret); + } + break; + case HEADER: + String authHeaderValue = + "Basic " + + Base64.getEncoder().encodeToString((clientId + ":" + clientSecret).getBytes()); + headers.put(HttpHeaders.AUTHORIZATION, authHeaderValue); + break; + } + headers.put("Content-Type", "application/x-www-form-urlencoded"); + Request req = new Request("POST", tokenUrl, FormRequest.wrapValuesInList(params)); + req.withHeaders(headers); + try { + ApiClient apiClient = new ApiClient.Builder().withHttpClient(hc).build(); + OAuthResponse resp = apiClient.execute(req, OAuthResponse.class); + if (resp.getErrorCode() != null) { + throw new IllegalArgumentException(resp.getErrorCode() + ": " + resp.getErrorSummary()); + } + LocalDateTime expiry = LocalDateTime.now().plus(resp.getExpiresIn(), ChronoUnit.SECONDS); + return new Token(resp.getAccessToken(), resp.getTokenType(), resp.getRefreshToken(), expiry); + } catch (Exception e) { + throw new DatabricksException("Failed to refresh credentials: " + e.getMessage(), e); + } + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java index 3bde86994..4d2d512e3 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java @@ -17,7 +17,8 @@ * requests to an API, and a long-lived refresh token, which can be used to fetch new access tokens. * Calling refresh() uses the refresh token to retrieve a new access token to authenticate to APIs. */ -public class SessionCredentials implements TokenSource, CredentialsProvider, Serializable { +public class SessionCredentials extends RefreshableTokenSource + implements CredentialsProvider, Serializable { private static final long serialVersionUID = 3083941540130596650L; private static final Logger LOGGER = LoggerFactory.getLogger(SessionCredentials.class); @@ -86,10 +87,9 @@ public SessionCredentials build() { private final String clientId; private final String clientSecret; private final TokenCache tokenCache; - protected Token token; private SessionCredentials(Builder b) { - this.token = b.token; + super(b.token); this.hc = b.hc; this.tokenUrl = b.tokenUrl; this.redirectUrl = b.redirectUrl; @@ -99,7 +99,7 @@ private SessionCredentials(Builder b) { } @Override - public Token getToken() { + protected Token refresh() { if (this.token == null) { throw new DatabricksException("oauth2: token is not set"); } @@ -118,7 +118,7 @@ public Token getToken() { headers.put("Origin", redirectUrl); } Token newToken = - TokenEndpointClient.retrieveToken( + retrieveToken( hc, clientId, clientSecret, tokenUrl, params, headers, AuthParameterPosition.BODY); // Save the refreshed token directly to cache diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java index ace0c9314..69883dd24 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java @@ -1,6 +1,5 @@ package com.databricks.sdk.core.oauth; -import com.databricks.sdk.core.ApiClient; import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.http.FormRequest; import com.databricks.sdk.core.http.HttpClient; @@ -89,63 +88,4 @@ public static OAuthResponse requestToken( LOG.debug("Successfully obtained token response from {}", tokenEndpointUrl); return response; } - - /** - * Helper method implementing OAuth token refresh. - * - * @param hc The HTTP client to use for the request. - * @param clientId The client ID to authenticate with. - * @param clientSecret The client secret to authenticate with. - * @param tokenUrl The authorization URL for fetching tokens. - * @param params Additional request parameters. - * @param headers Additional headers. - * @param position The position of the authentication parameters in the request. - * @return The newly fetched Token. - * @throws DatabricksException if the refresh fails - * @throws IllegalArgumentException if the OAuth response contains an error - */ - public static Token retrieveToken( - HttpClient hc, - String clientId, - String clientSecret, - String tokenUrl, - Map params, - Map headers, - AuthParameterPosition position) { - switch (position) { - case BODY: - if (clientId != null) { - params.put("client_id", clientId); - } - if (clientSecret != null) { - params.put("client_secret", clientSecret); - } - break; - case HEADER: - String authHeaderValue = - "Basic " - + java.util.Base64.getEncoder() - .encodeToString((clientId + ":" + clientSecret).getBytes()); - headers.put(org.apache.http.HttpHeaders.AUTHORIZATION, authHeaderValue); - break; - } - headers.put("Content-Type", "application/x-www-form-urlencoded"); - com.databricks.sdk.core.http.Request req = - new com.databricks.sdk.core.http.Request( - "POST", tokenUrl, com.databricks.sdk.core.http.FormRequest.wrapValuesInList(params)); - req.withHeaders(headers); - try { - ApiClient apiClient = new ApiClient.Builder().withHttpClient(hc).build(); - OAuthResponse resp = apiClient.execute(req, OAuthResponse.class); - if (resp.getErrorCode() != null) { - throw new IllegalArgumentException(resp.getErrorCode() + ": " + resp.getErrorSummary()); - } - java.time.LocalDateTime expiry = - java.time.LocalDateTime.now() - .plus(resp.getExpiresIn(), java.time.temporal.ChronoUnit.SECONDS); - return new Token(resp.getAccessToken(), resp.getTokenType(), resp.getRefreshToken(), expiry); - } catch (Exception e) { - throw new DatabricksException("Failed to refresh credentials: " + e.getMessage(), e); - } - } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java index 09cea6e86..96dc116c2 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java @@ -77,7 +77,7 @@ public static Map addWorkspaceResourceId( } public static Map addSpManagementToken( - TokenSource tokenSource, Map headers) { + RefreshableTokenSource tokenSource, Map headers) { headers.put("X-Databricks-Azure-SP-Management-Token", tokenSource.getToken().getAccessToken()); return headers; } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java index e8f6faa2f..1714b731c 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java @@ -186,7 +186,7 @@ void clientCredentials() throws IOException { .withClientSecret("abc") .withTokenUrl("https://tokenUrl") .build(); - Token token = clientCredentials.getToken(); + Token token = clientCredentials.refresh(); assertEquals("accessTokenFromServer", token.getAccessToken()); assertEquals("refreshTokenFromServer", token.getRefreshToken()); } @@ -212,7 +212,7 @@ void sessionCredentials() throws IOException { "originalRefreshToken", LocalDateTime.MAX)) .build(); - Token token = sessionCredentials.getToken(); + Token token = sessionCredentials.refresh(); // We check that we are actually getting the token from server response (that is defined // above) rather than what was given while creating session credentials diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/CachedTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java similarity index 64% rename from databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/CachedTokenSourceTest.java rename to databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java index ba9a41753..c33e7fda5 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/CachedTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java @@ -11,7 +11,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -public class CachedTokenSourceTest { +public class RefreshableTokenSourceTest { private static final String TOKEN_TYPE = "Bearer"; private static final String INITIAL_TOKEN = "initial-token"; private static final String REFRESH_TOKEN = "refreshed-token"; @@ -46,10 +46,10 @@ void testAsyncRefreshParametrized( new Token(REFRESH_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(10)); CountDownLatch refreshCalled = new CountDownLatch(1); - TokenSource tokenSource = - new TokenSource() { + RefreshableTokenSource source = + new RefreshableTokenSource(initialToken) { @Override - public Token getToken() { + protected Token refresh() { refreshCalled.countDown(); try { Thread.sleep(500); @@ -58,13 +58,7 @@ public Token getToken() { } return refreshedToken; } - }; - - CachedTokenSource source = - new CachedTokenSource.Builder(tokenSource) - .withAsyncEnabled(asyncEnabled) - .withToken(initialToken) - .build(); + }.withAsyncRefresh(asyncEnabled); Token token = source.getToken(); @@ -73,56 +67,74 @@ public Token getToken() { assertEquals(expectedToken, token.getAccessToken(), "Token value did not match expected"); } + /** + * This test verifies that if an asynchronous token refresh fails, the next refresh attempt is + * forced to be synchronous. It ensures that after an async failure, the system does not + * repeatedly attempt async refreshes while the token is stale, and only performs a synchronous + * refresh when the token is expired. After a successful sync refresh, async refreshes resume as + * normal. + */ @Test void testAsyncRefreshFailureFallback() throws Exception { - class MutableTokenSource implements TokenSource { + Token staleToken = + new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(2)); + + class TestSource extends RefreshableTokenSource { int refreshCallCount = 0; boolean isFirstRefresh = true; - Token token = new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(2)); + + TestSource(Token token) { + super(token); + } @Override - public Token getToken() { + protected Token refresh() { refreshCallCount++; if (isFirstRefresh) { isFirstRefresh = false; throw new RuntimeException("Simulated async failure"); } - token = new Token(REFRESH_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(10)); - return token; + return new Token(REFRESH_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(10)); } } - MutableTokenSource mutableTokenSource = new MutableTokenSource(); - - CachedTokenSource source = - new CachedTokenSource.Builder(mutableTokenSource) - .withAsyncEnabled(true) - .withToken(mutableTokenSource.token) - .build(); + TestSource source = new TestSource(staleToken); + source.withAsyncRefresh(true); // First call triggers async refresh, which fails source.getToken(); - assertEquals(1, mutableTokenSource.refreshCallCount); + Thread.sleep(300); + assertEquals( + 1, source.refreshCallCount, "refresh() should have been called once (async, failed)"); // Token is still stale, so next call should NOT trigger another refresh since the last refresh // failed source.getToken(); - assertEquals(1, mutableTokenSource.refreshCallCount); + Thread.sleep(200); + assertEquals( + 1, + source.refreshCallCount, + "refresh() should NOT be called again while stale after async failure"); // Advance the clock so the token is now expired - mutableTokenSource.token = - new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().minusMinutes(1)); + source.token = new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().minusMinutes(1)); // Now getToken() should call refresh synchronously and return the refreshed token - Token token; - token = source.getToken(); - assertEquals(REFRESH_TOKEN, token.getAccessToken()); - assertEquals(2, mutableTokenSource.refreshCallCount); + Token token = source.getToken(); + assertEquals( + REFRESH_TOKEN, + token.getAccessToken(), + "Should return the refreshed token after sync refresh"); + assertEquals( + 2, source.refreshCallCount, "refresh() should have been called synchronously after expiry"); // Make the token stale again and trigger async refresh since the last sync refresh succeeded - mutableTokenSource.token = - new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(2)); + source.token = new Token(INITIAL_TOKEN, TOKEN_TYPE, null, LocalDateTime.now().plusMinutes(2)); source.getToken(); - assertEquals(3, mutableTokenSource.refreshCallCount); + Thread.sleep(300); + assertEquals( + 3, + source.refreshCallCount, + "refresh() should have been called again asynchronously after making token stale"); } } From fdc50effa6c3d0c22bffe247e55ee3bb2acb18c8 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 6 Jun 2025 11:00:55 +0000 Subject: [PATCH 21/49] Removed redundant date formattters --- .../databricks/sdk/core/CliTokenSource.java | 42 ++-------- .../sdk/core/CliTokenSourceTest.java | 79 ++++++++----------- .../src/test/resources/testdata/az | 8 +- 3 files changed, 45 insertions(+), 84 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 2e176c210..0e68dc7fc 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -9,11 +9,8 @@ import java.io.IOException; import java.io.InputStream; import java.time.Instant; -import java.time.LocalDateTime; -import java.time.ZoneOffset; -import java.time.format.DateTimeFormatter; +import java.time.OffsetDateTime; import java.time.format.DateTimeParseException; -import java.util.Arrays; import java.util.List; import org.apache.commons.io.IOUtils; @@ -39,44 +36,15 @@ public CliTokenSource( } /** - * Parses an expiry time string and returns the corresponding {@link Instant}. + * Parses an expiry string in RFC 3339/ISO 8601 format (with or without offset) and returns the + * corresponding {@link Instant}. Any specified time zone or offset is converted to UTC. * - *

The expiry time string is always in UTC. Any time zone or offset information present in the - * input is ignored. - * - *

The method attempts to parse the input using several common date-time formats, including - * ISO-8601 and patterns with varying sub-second precision. - * - * @param expiry the expiry time string to parse, which must represent a UTC time + * @param expiry expiry time string in RFC 3339/ISO 8601 format * @return the parsed {@link Instant} * @throws DateTimeParseException if the input string cannot be parsed */ static Instant parseExpiry(String expiry) { - DateTimeParseException lastException = null; - // Try to parse the expiry as an ISO-8601 string in UTC first - try { - return Instant.parse(expiry); - } catch (DateTimeParseException e) { - lastException = e; - } - - String multiplePrecisionPattern = - "[SSSSSSSSS][SSSSSSSS][SSSSSSS][SSSSSS][SSSSS][SSSS][SSS][SS][S]"; - List datePatterns = - Arrays.asList( - "yyyy-MM-dd HH:mm:ss", - "yyyy-MM-dd HH:mm:ss." + multiplePrecisionPattern, - "yyyy-MM-dd'T'HH:mm:ss." + multiplePrecisionPattern + "XXX"); - for (String pattern : datePatterns) { - try { - DateTimeFormatter formatter = DateTimeFormatter.ofPattern(pattern); - LocalDateTime dateTime = LocalDateTime.parse(expiry, formatter); - return dateTime.atZone(ZoneOffset.UTC).toInstant(); - } catch (DateTimeParseException e) { - lastException = e; - } - } - throw lastException; + return OffsetDateTime.parse(expiry).toInstant(); } private String getProcessStream(InputStream stream) throws IOException { diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index abe609b01..cc177881c 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -1,58 +1,49 @@ package com.databricks.sdk.core; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import java.time.Instant; import java.time.format.DateTimeParseException; -import org.junit.jupiter.api.Test; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; public class CliTokenSourceTest { - @Test - public void testParseExpiryWithoutTruncate() { - Instant parsedInstant = CliTokenSource.parseExpiry("2023-07-17T09:02:22.330612218Z"); - assertEquals(Instant.parse("2023-07-17T09:02:22.330612218Z"), parsedInstant); + private static Stream expiryProvider() { + return Stream.of( + Arguments.of( + "2023-07-17T09:02:22.330612218Z", "2023-07-17T09:02:22.330612218Z", "9-digit nanos"), + Arguments.of( + "2023-07-17T09:02:22.33061221Z", "2023-07-17T09:02:22.330612210Z", "8-digit nanos"), + Arguments.of( + "2023-07-17T09:02:22.330612Z", "2023-07-17T09:02:22.330612000Z", "6-digit nanos"), + Arguments.of( + "2023-07-17T10:02:22.330612218+01:00", + "2023-07-17T09:02:22.330612218Z", + "+01:00 offset, 9-digit nanos"), + Arguments.of( + "2023-07-17T04:02:22.330612218-05:00", + "2023-07-17T09:02:22.330612218Z", + "-05:00 offset, 9-digit nanos"), + Arguments.of( + "2023-07-17T10:02:22.330612+01:00", + "2023-07-17T09:02:22.330612000Z", + "+01:00 offset, 6-digit nanos"), + Arguments.of("2023-07-17T09:02:22.33061221987Z", null, "Invalid: >9 nanos"), + Arguments.of("17-07-2023 09:02:22", null, "Invalid: non-ISO8601 format")); } - @Test - public void testParseExpiryWithTruncate() { - Instant parsedInstant = CliTokenSource.parseExpiry("2023-07-17T09:02:22.33061221Z"); - assertEquals(Instant.parse("2023-07-17T09:02:22.330612210Z"), parsedInstant); - } - - @Test - public void testParseExpiryWithTruncateAndLessNanoSecondDigits() { - Instant parsedInstant = CliTokenSource.parseExpiry("2023-07-17T09:02:22.330612Z"); - assertEquals(Instant.parse("2023-07-17T09:02:22.330612000Z"), parsedInstant); - } - - @Test - public void testParseExpiryWithMoreThanNineNanoSecondDigits() { - try { - CliTokenSource.parseExpiry("2023-07-17T09:02:22.33061221987Z"); - } catch (DateTimeParseException e) { - assert (e.getMessage().contains("could not be parsed")); - } - } - - @Test - public void testParseExpiryWithSpaceFormat() { - Instant parsedInstant = CliTokenSource.parseExpiry("2023-07-17 09:02:22"); - assertEquals(Instant.parse("2023-07-17T09:02:22Z"), parsedInstant); - } - - @Test - public void testParseExpiryWithSpaceFormatAndMillis() { - Instant parsedInstant = CliTokenSource.parseExpiry("2023-07-17 09:02:22.123"); - assertEquals(Instant.parse("2023-07-17T09:02:22.123Z"), parsedInstant); - } - - @Test - public void testParseExpiryWithInvalidFormat() { - try { - CliTokenSource.parseExpiry("17-07-2023 09:02:22"); - } catch (DateTimeParseException e) { - assert (e.getMessage().contains("could not be parsed")); + @ParameterizedTest(name = "{2}") + @MethodSource("expiryProvider") + public void testParseExpiry(String input, String expectedInstant, String description) { + if (expectedInstant == null) { + assertThrows(DateTimeParseException.class, () -> CliTokenSource.parseExpiry(input)); + } else { + Instant parsedInstant = CliTokenSource.parseExpiry(input); + assertEquals(Instant.parse(expectedInstant), parsedInstant); } } } diff --git a/databricks-sdk-java/src/test/resources/testdata/az b/databricks-sdk-java/src/test/resources/testdata/az index 29b824ed7..b24bddea9 100755 --- a/databricks-sdk-java/src/test/resources/testdata/az +++ b/databricks-sdk-java/src/test/resources/testdata/az @@ -22,14 +22,16 @@ for arg in "$@"; do fi done -# Macos -EXP="$(/bin/date -v+${EXPIRE:=10S} +'%F %T' 2>/dev/null)" +# MacOS +EXP="$(/bin/date -v+${EXPIRE:=10S} +'%FT%T%z' 2>/dev/null)" if [ -z "${EXP}" ]; then # Linux EXPIRE=$(/bin/echo $EXPIRE | /bin/sed 's/S/seconds/') EXPIRE=$(/bin/echo $EXPIRE | /bin/sed 's/M/minutes/') - EXP=$(/bin/date --date=+${EXPIRE:=10seconds} +'%F %T') + EXP=$(/bin/date --date=+${EXPIRE:=10seconds} +'%FT%T%z') fi +# Insert colon in timezone offset for ISO 8601 compliance +EXP="${EXP:0:19}${EXP:19:3}:${EXP:22:2}" if [ -z "${TF_AAD_TOKEN}" ]; then TF_AAD_TOKEN="..." From 70934b27ca4769d241bd3d02ff02e1d1e58c13d5 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 6 Jun 2025 14:08:29 +0000 Subject: [PATCH 22/49] Change clock supplier to use UTC time --- .../java/com/databricks/sdk/core/utils/SystemClockSupplier.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/SystemClockSupplier.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/SystemClockSupplier.java index edac0cd94..c79e6884d 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/SystemClockSupplier.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/SystemClockSupplier.java @@ -5,6 +5,6 @@ public class SystemClockSupplier implements ClockSupplier { @Override public Clock getClock() { - return Clock.systemDefaultZone(); + return Clock.systemUTC(); } } From 1bd052f1a4630b6c343fa7b01e6a474c45a2f8e9 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Sat, 7 Jun 2025 18:03:50 +0000 Subject: [PATCH 23/49] Add support for space separated expiry strings --- .../databricks/sdk/core/CliTokenSource.java | 26 +++++++++++- .../sdk/core/CliTokenSourceTest.java | 40 ++++++++++++++----- .../src/test/resources/testdata/az | 10 ++--- 3 files changed, 60 insertions(+), 16 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 0e68dc7fc..9313a4b3f 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -9,8 +9,12 @@ import java.io.IOException; import java.io.InputStream; import java.time.Instant; +import java.time.LocalDateTime; import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.format.DateTimeFormatter; import java.time.format.DateTimeParseException; +import java.util.Arrays; import java.util.List; import org.apache.commons.io.IOUtils; @@ -44,7 +48,27 @@ public CliTokenSource( * @throws DateTimeParseException if the input string cannot be parsed */ static Instant parseExpiry(String expiry) { - return OffsetDateTime.parse(expiry).toInstant(); + DateTimeParseException lastException = null; + try { + return OffsetDateTime.parse(expiry).toInstant(); + } catch (DateTimeParseException e) { + lastException = e; + } + + String multiplePrecisionPattern = + "[SSSSSSSSS][SSSSSSSS][SSSSSSS][SSSSSS][SSSSS][SSSS][SSS][SS][S]"; + List datePatterns = + Arrays.asList("yyyy-MM-dd HH:mm:ss", "yyyy-MM-dd HH:mm:ss." + multiplePrecisionPattern); + for (String pattern : datePatterns) { + try { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern(pattern); + LocalDateTime dateTime = LocalDateTime.parse(expiry, formatter); + return dateTime.atZone(ZoneId.systemDefault()).toInstant(); + } catch (DateTimeParseException e) { + lastException = e; + } + } + throw lastException; } private String getProcessStream(InputStream stream) throws IOException { diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index cc177881c..d6c8fc740 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -4,6 +4,8 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import java.time.Instant; +import java.time.LocalDateTime; +import java.time.ZoneId; import java.time.format.DateTimeParseException; import java.util.stream.Stream; import org.junit.jupiter.params.ParameterizedTest; @@ -15,35 +17,55 @@ public class CliTokenSourceTest { private static Stream expiryProvider() { return Stream.of( Arguments.of( - "2023-07-17T09:02:22.330612218Z", "2023-07-17T09:02:22.330612218Z", "9-digit nanos"), + "2023-07-17T09:02:22.330612218Z", + Instant.parse("2023-07-17T09:02:22.330612218Z"), + "9-digit nanos"), Arguments.of( - "2023-07-17T09:02:22.33061221Z", "2023-07-17T09:02:22.330612210Z", "8-digit nanos"), + "2023-07-17T09:02:22.33061221Z", + Instant.parse("2023-07-17T09:02:22.330612210Z"), + "8-digit nanos"), Arguments.of( - "2023-07-17T09:02:22.330612Z", "2023-07-17T09:02:22.330612000Z", "6-digit nanos"), + "2023-07-17T09:02:22.330612Z", + Instant.parse("2023-07-17T09:02:22.330612000Z"), + "6-digit nanos"), Arguments.of( "2023-07-17T10:02:22.330612218+01:00", - "2023-07-17T09:02:22.330612218Z", + Instant.parse("2023-07-17T09:02:22.330612218Z"), "+01:00 offset, 9-digit nanos"), Arguments.of( "2023-07-17T04:02:22.330612218-05:00", - "2023-07-17T09:02:22.330612218Z", + Instant.parse("2023-07-17T09:02:22.330612218Z"), "-05:00 offset, 9-digit nanos"), Arguments.of( "2023-07-17T10:02:22.330612+01:00", - "2023-07-17T09:02:22.330612000Z", + Instant.parse("2023-07-17T09:02:22.330612000Z"), "+01:00 offset, 6-digit nanos"), Arguments.of("2023-07-17T09:02:22.33061221987Z", null, "Invalid: >9 nanos"), - Arguments.of("17-07-2023 09:02:22", null, "Invalid: non-ISO8601 format")); + Arguments.of("17-07-2023 09:02:22", null, "Invalid date format"), + Arguments.of( + "2023-07-17 09:02:22.330612218", + LocalDateTime.parse("2023-07-17T09:02:22.330612218") + .atZone(ZoneId.systemDefault()) + .toInstant(), + "Space separator, 9-digit nanos"), + Arguments.of( + "2023-07-17 09:02:22.330612", + LocalDateTime.parse("2023-07-17T09:02:22.330612") + .atZone(ZoneId.systemDefault()) + .toInstant(), + "Space separator, 6-digit nanos"), + Arguments.of( + "2023-07-17 09:02:22.33061221987", null, "Space separator, Invalid: >9 nanos")); } @ParameterizedTest(name = "{2}") @MethodSource("expiryProvider") - public void testParseExpiry(String input, String expectedInstant, String description) { + public void testParseExpiry(String input, Instant expectedInstant, String description) { if (expectedInstant == null) { assertThrows(DateTimeParseException.class, () -> CliTokenSource.parseExpiry(input)); } else { Instant parsedInstant = CliTokenSource.parseExpiry(input); - assertEquals(Instant.parse(expectedInstant), parsedInstant); + assertEquals(expectedInstant, parsedInstant); } } } diff --git a/databricks-sdk-java/src/test/resources/testdata/az b/databricks-sdk-java/src/test/resources/testdata/az index b24bddea9..00997a4cf 100755 --- a/databricks-sdk-java/src/test/resources/testdata/az +++ b/databricks-sdk-java/src/test/resources/testdata/az @@ -22,16 +22,14 @@ for arg in "$@"; do fi done -# MacOS -EXP="$(/bin/date -v+${EXPIRE:=10S} +'%FT%T%z' 2>/dev/null)" +# Macos +EXP="$(/bin/date -v+${EXPIRE:=10S} +'%F %T' 2>/dev/null)" if [ -z "${EXP}" ]; then # Linux EXPIRE=$(/bin/echo $EXPIRE | /bin/sed 's/S/seconds/') EXPIRE=$(/bin/echo $EXPIRE | /bin/sed 's/M/minutes/') - EXP=$(/bin/date --date=+${EXPIRE:=10seconds} +'%FT%T%z') + EXP=$(/bin/date --date=+${EXPIRE:=10seconds} +'%F %T') fi -# Insert colon in timezone offset for ISO 8601 compliance -EXP="${EXP:0:19}${EXP:19:3}:${EXP:22:2}" if [ -z "${TF_AAD_TOKEN}" ]; then TF_AAD_TOKEN="..." @@ -43,4 +41,4 @@ fi \"subscription\": \"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee\", \"tenant\": \"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee\", \"tokenType\": \"Bearer\" -}" +}" \ No newline at end of file From 408f3b465307952d7bad140f92d8a71ab1b78e2b Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Sat, 7 Jun 2025 18:09:32 +0000 Subject: [PATCH 24/49] revert test data --- databricks-sdk-java/src/test/resources/testdata/az | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databricks-sdk-java/src/test/resources/testdata/az b/databricks-sdk-java/src/test/resources/testdata/az index 00997a4cf..29b824ed7 100755 --- a/databricks-sdk-java/src/test/resources/testdata/az +++ b/databricks-sdk-java/src/test/resources/testdata/az @@ -41,4 +41,4 @@ fi \"subscription\": \"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee\", \"tenant\": \"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee\", \"tokenType\": \"Bearer\" -}" \ No newline at end of file +}" From 14d2a8fa749b4e8fc11f8a7bbff888cfa09e8b68 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 10 Jun 2025 15:04:26 +0000 Subject: [PATCH 25/49] Update CilTokenSource --- .../databricks/sdk/core/CliTokenSource.java | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 0e68dc7fc..9313a4b3f 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -9,8 +9,12 @@ import java.io.IOException; import java.io.InputStream; import java.time.Instant; +import java.time.LocalDateTime; import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.format.DateTimeFormatter; import java.time.format.DateTimeParseException; +import java.util.Arrays; import java.util.List; import org.apache.commons.io.IOUtils; @@ -44,7 +48,27 @@ public CliTokenSource( * @throws DateTimeParseException if the input string cannot be parsed */ static Instant parseExpiry(String expiry) { - return OffsetDateTime.parse(expiry).toInstant(); + DateTimeParseException lastException = null; + try { + return OffsetDateTime.parse(expiry).toInstant(); + } catch (DateTimeParseException e) { + lastException = e; + } + + String multiplePrecisionPattern = + "[SSSSSSSSS][SSSSSSSS][SSSSSSS][SSSSSS][SSSSS][SSSS][SSS][SS][S]"; + List datePatterns = + Arrays.asList("yyyy-MM-dd HH:mm:ss", "yyyy-MM-dd HH:mm:ss." + multiplePrecisionPattern); + for (String pattern : datePatterns) { + try { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern(pattern); + LocalDateTime dateTime = LocalDateTime.parse(expiry, formatter); + return dateTime.atZone(ZoneId.systemDefault()).toInstant(); + } catch (DateTimeParseException e) { + lastException = e; + } + } + throw lastException; } private String getProcessStream(InputStream stream) throws IOException { From 7fccff9481b0b7c3c10135e128a33471c8fffa78 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Wed, 11 Jun 2025 09:50:31 +0000 Subject: [PATCH 26/49] Update exception handling --- .../java/com/databricks/sdk/core/CliTokenSource.java | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 9313a4b3f..5b380a84b 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -48,11 +48,11 @@ public CliTokenSource( * @throws DateTimeParseException if the input string cannot be parsed */ static Instant parseExpiry(String expiry) { - DateTimeParseException lastException = null; + DateTimeParseException parseException; try { return OffsetDateTime.parse(expiry).toInstant(); } catch (DateTimeParseException e) { - lastException = e; + parseException = e; } String multiplePrecisionPattern = @@ -65,10 +65,11 @@ static Instant parseExpiry(String expiry) { LocalDateTime dateTime = LocalDateTime.parse(expiry, formatter); return dateTime.atZone(ZoneId.systemDefault()).toInstant(); } catch (DateTimeParseException e) { - lastException = e; + parseException.addSuppressed(e); } } - throw lastException; + + throw parseException; } private String getProcessStream(InputStream stream) throws IOException { From 64313f862d91811e305c145c54d20a40bdd667bc Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Wed, 11 Jun 2025 13:19:36 +0000 Subject: [PATCH 27/49] Update Javadoc --- .../com/databricks/sdk/core/CliTokenSource.java | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 5b380a84b..ff0740caa 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -40,12 +40,19 @@ public CliTokenSource( } /** - * Parses an expiry string in RFC 3339/ISO 8601 format (with or without offset) and returns the - * corresponding {@link Instant}. Any specified time zone or offset is converted to UTC. + * Parses an expiry time string and returns the corresponding {@link Instant}. The method attempts + * to parse the input in the following order: 1. RFC 3339/ISO 8601 format with offset (e.g. + * "2024-03-20T10:30:00Z") 2. Local date-time format "yyyy-MM-dd HH:mm:ss" (e.g. "2024-03-20 + * 10:30:00") 3. Local date-time format with optional fractional seconds of varying precision + * (e.g. "2024-03-20 10:30:00.123") * - * @param expiry expiry time string in RFC 3339/ISO 8601 format + *

Any specified time zone or offset is converted to UTC. For local date-time formats, the + * system's default time zone is used. + * + * @param expiry expiry time string in one of the supported formats * @return the parsed {@link Instant} - * @throws DateTimeParseException if the input string cannot be parsed + * @throws DateTimeParseException if the input string cannot be parsed in any of the supported + * formats */ static Instant parseExpiry(String expiry) { DateTimeParseException parseException; From 447eae227c0b1369346d658c2039b654e59dfd17 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Wed, 11 Jun 2025 23:38:47 +0000 Subject: [PATCH 28/49] Added more tests to CilTokenSourceTest --- .../sdk/core/CliTokenSourceTest.java | 92 +++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 20e2f6095..f1d230532 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -1,12 +1,104 @@ package com.databricks.sdk.core; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; +import com.databricks.sdk.core.oauth.Token; +import com.databricks.sdk.core.utils.Environment; +import com.databricks.sdk.core.utils.OSUtils; +import com.databricks.sdk.core.utils.OSUtilities; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.time.Duration; import java.time.LocalDateTime; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; import java.time.format.DateTimeParseException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; public class CliTokenSourceTest { + private static final String[] DATE_FORMATS = { + "yyyy-MM-dd HH:mm:ss", + "yyyy-MM-dd HH:mm:ss.SSS", + "yyyy-MM-dd'T'HH:mm:ss.SSSXXX", + "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'" + }; + + String getExpiryStr(String dateFormat, Duration offset) { + ZonedDateTime futureExpiry = ZonedDateTime.now().plus(offset); + return futureExpiry.format(DateTimeFormatter.ofPattern(dateFormat)); + } + + private static Stream provideTestCases() { + return Stream.of( + Arguments.of("Valid: 30min remaining", 30, false), + Arguments.of("Valid: 1hr remaining", 60, false), + Arguments.of("Valid: 2hrs remaining", 120, false), + Arguments.of("Expired: 30min ago", -30, true), + Arguments.of("Expired: 1hr ago", -60, true), + Arguments.of("Expired: 2hrs ago", -120, true) + ); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideTestCases") + public void testRefreshWithExpiry(String testName, int offsetMinutes, boolean shouldBeExpired) throws IOException, InterruptedException { + for (String dateFormat : DATE_FORMATS) { + // Mock environment + Environment env = mock(Environment.class); + Map envMap = new HashMap<>(); + when(env.getEnv()).thenReturn(envMap); + + // Create test command + List cmd = Arrays.asList("test", "command"); + + // Mock OSUtilities + OSUtilities osUtils = mock(OSUtilities.class); + when(osUtils.getCliExecutableCommand(any())).thenReturn(cmd); + + try (MockedStatic mockedOSUtils = mockStatic(OSUtils.class)) { + mockedOSUtils.when(() -> OSUtils.get(any())).thenReturn(osUtils); + + // Create token source + CliTokenSource tokenSource = new CliTokenSource(cmd, "token_type", "access_token", "expiry", env); + + String expiryStr = getExpiryStr(dateFormat, Duration.ofMinutes(offsetMinutes)); + + // Mock process + Process process = mock(Process.class); + when(process.getInputStream()).thenReturn(new ByteArrayInputStream( + String.format("{\"token_type\": \"Bearer\", \"access_token\": \"test-token\", \"expiry\": \"%s\"}", expiryStr).getBytes())); + when(process.getErrorStream()).thenReturn(new ByteArrayInputStream(new byte[0])); + when(process.waitFor()).thenReturn(0); + + // Mock ProcessBuilder constructor + try (MockedConstruction mocked = mockConstruction(ProcessBuilder.class, + (mock, context) -> { + when(mock.start()).thenReturn(process); + })) { + // Test refresh + Token token = tokenSource.refresh(); + assertEquals("Bearer", token.getTokenType()); + assertEquals("test-token", token.getAccessToken()); + assertEquals(shouldBeExpired, token.isExpired()); + } + } + } + } @Test public void testParseExpiryWithoutTruncate() { From 66335a7098af438045f088d1fbd48613502f6abf Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Thu, 12 Jun 2025 12:02:48 +0000 Subject: [PATCH 29/49] Add test to verify perserved behaviour --- .../sdk/core/CliTokenSourceTest.java | 104 ++++++++++++------ 1 file changed, 73 insertions(+), 31 deletions(-) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index f1d230532..a781159b4 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -9,8 +9,8 @@ import com.databricks.sdk.core.oauth.Token; import com.databricks.sdk.core.utils.Environment; -import com.databricks.sdk.core.utils.OSUtils; import com.databricks.sdk.core.utils.OSUtilities; +import com.databricks.sdk.core.utils.OSUtils; import java.io.ByteArrayInputStream; import java.io.IOException; import java.time.Duration; @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.TimeZone; import java.util.stream.Stream; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -32,10 +33,10 @@ public class CliTokenSourceTest { private static final String[] DATE_FORMATS = { - "yyyy-MM-dd HH:mm:ss", - "yyyy-MM-dd HH:mm:ss.SSS", - "yyyy-MM-dd'T'HH:mm:ss.SSSXXX", - "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'" + "yyyy-MM-dd HH:mm:ss", + "yyyy-MM-dd HH:mm:ss.SSS", + "yyyy-MM-dd'T'HH:mm:ss.SSSXXX", + "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'" }; String getExpiryStr(String dateFormat, Duration offset) { @@ -43,20 +44,9 @@ String getExpiryStr(String dateFormat, Duration offset) { return futureExpiry.format(DateTimeFormatter.ofPattern(dateFormat)); } - private static Stream provideTestCases() { - return Stream.of( - Arguments.of("Valid: 30min remaining", 30, false), - Arguments.of("Valid: 1hr remaining", 60, false), - Arguments.of("Valid: 2hrs remaining", 120, false), - Arguments.of("Expired: 30min ago", -30, true), - Arguments.of("Expired: 1hr ago", -60, true), - Arguments.of("Expired: 2hrs ago", -120, true) - ); - } - - @ParameterizedTest(name = "{0}") - @MethodSource("provideTestCases") - public void testRefreshWithExpiry(String testName, int offsetMinutes, boolean shouldBeExpired) throws IOException, InterruptedException { + public void testRefreshWithExpiry( + String testName, int minutesUntilExpiry, boolean shouldBeExpired) + throws IOException, InterruptedException { for (String dateFormat : DATE_FORMATS) { // Mock environment Environment env = mock(Environment.class); @@ -65,31 +55,38 @@ public void testRefreshWithExpiry(String testName, int offsetMinutes, boolean sh // Create test command List cmd = Arrays.asList("test", "command"); - + // Mock OSUtilities OSUtilities osUtils = mock(OSUtilities.class); when(osUtils.getCliExecutableCommand(any())).thenReturn(cmd); - + try (MockedStatic mockedOSUtils = mockStatic(OSUtils.class)) { mockedOSUtils.when(() -> OSUtils.get(any())).thenReturn(osUtils); - - // Create token source - CliTokenSource tokenSource = new CliTokenSource(cmd, "token_type", "access_token", "expiry", env); - String expiryStr = getExpiryStr(dateFormat, Duration.ofMinutes(offsetMinutes)); + CliTokenSource tokenSource = + new CliTokenSource(cmd, "token_type", "access_token", "expiry", env); + + String expiryStr = getExpiryStr(dateFormat, Duration.ofMinutes(minutesUntilExpiry)); // Mock process Process process = mock(Process.class); - when(process.getInputStream()).thenReturn(new ByteArrayInputStream( - String.format("{\"token_type\": \"Bearer\", \"access_token\": \"test-token\", \"expiry\": \"%s\"}", expiryStr).getBytes())); + when(process.getInputStream()) + .thenReturn( + new ByteArrayInputStream( + String.format( + "{\"token_type\": \"Bearer\", \"access_token\": \"test-token\", \"expiry\": \"%s\"}", + expiryStr) + .getBytes())); when(process.getErrorStream()).thenReturn(new ByteArrayInputStream(new byte[0])); when(process.waitFor()).thenReturn(0); // Mock ProcessBuilder constructor - try (MockedConstruction mocked = mockConstruction(ProcessBuilder.class, - (mock, context) -> { - when(mock.start()).thenReturn(process); - })) { + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (mock, context) -> { + when(mock.start()).thenReturn(process); + })) { // Test refresh Token token = tokenSource.refresh(); assertEquals("Bearer", token.getTokenType()); @@ -100,6 +97,51 @@ public void testRefreshWithExpiry(String testName, int offsetMinutes, boolean sh } } + private static Stream provideTimezoneTestCases() { + // Timezones to test + List timezones = Arrays.asList("UTC", "GMT+1", "GMT+8", "GMT-1", "GMT-8"); + + // Time to expiry of tokens (minutes, shouldBeExpired) + List minutesUntilExpiry = + Arrays.asList( + Arguments.of(5, false), // 5 minutes remaining + Arguments.of(30, false), // 30 minutes remaining + Arguments.of(60, false), // 1 hour remaining + Arguments.of(120, false), // 2 hours remaining + Arguments.of(-5, true), // 5 minutes ago + Arguments.of(-30, true), // 30 minutes ago + Arguments.of(-60, true), // 1 hour ago + Arguments.of(-120, true) // 2 hours ago + ); + + // Create cross product of timezones and minutesUntilExpiry cases + return timezones.stream() + .flatMap( + timezone -> + minutesUntilExpiry.stream() + .map( + minutesUntilExpiryCase -> { + Object[] args = minutesUntilExpiryCase.get(); + return Arguments.of(timezone, args[0], args[1]); + })); + } + + @ParameterizedTest(name = "Test in {0} with {1} minutes offset") + @MethodSource("provideTimezoneTestCases") + public void testRefreshWithDifferentTimezone( + String timezone, int minutesUntilExpiry, boolean shouldBeExpired) + throws IOException, InterruptedException { + // Save original timezone + TimeZone originalTimeZone = TimeZone.getDefault(); + try { + TimeZone.setDefault(TimeZone.getTimeZone(timezone)); + testRefreshWithExpiry("Test in " + timezone, minutesUntilExpiry, shouldBeExpired); + } finally { + // Restore original timezone + TimeZone.setDefault(originalTimeZone); + } + } + @Test public void testParseExpiryWithoutTruncate() { LocalDateTime parsedDateTime = CliTokenSource.parseExpiry("2023-07-17T09:02:22.330612218Z"); From 3a824ebaf85a8f8483ff034cc78ff579f6e2339b Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Thu, 12 Jun 2025 12:48:28 +0000 Subject: [PATCH 30/49] Generate all timezones --- .../com/databricks/sdk/core/CliTokenSourceTest.java | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index a781159b4..33c2194e8 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -23,6 +23,8 @@ import java.util.List; import java.util.Map; import java.util.TimeZone; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -68,7 +70,7 @@ public void testRefreshWithExpiry( String expiryStr = getExpiryStr(dateFormat, Duration.ofMinutes(minutesUntilExpiry)); - // Mock process + // Mock process to return the specified expiry string Process process = mock(Process.class); when(process.getInputStream()) .thenReturn( @@ -98,8 +100,11 @@ public void testRefreshWithExpiry( } private static Stream provideTimezoneTestCases() { - // Timezones to test - List timezones = Arrays.asList("UTC", "GMT+1", "GMT+8", "GMT-1", "GMT-8"); + // Generate timezones from GMT-12 to GMT+12 + List timezones = + IntStream.rangeClosed(-12, 12) + .mapToObj(offset -> offset == 0 ? "GMT" : String.format("GMT%+d", offset)) + .collect(Collectors.toList()); // Time to expiry of tokens (minutes, shouldBeExpired) List minutesUntilExpiry = From 4d31c1ea5ce38a6730b3c40918ae9c1412c28346 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Thu, 12 Jun 2025 13:49:57 +0000 Subject: [PATCH 31/49] Merge branch 'emmyzhou-db/test_time' into emmyzhou-db/localdatetime-to-instant --- .../sdk/core/CliTokenSourceTest.java | 137 ++++++++++++++++++ 1 file changed, 137 insertions(+) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index d6c8fc740..082682482 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -2,18 +2,155 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; +import com.databricks.sdk.core.oauth.Token; +import com.databricks.sdk.core.utils.Environment; +import com.databricks.sdk.core.utils.OSUtilities; +import com.databricks.sdk.core.utils.OSUtils; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.time.Duration; import java.time.Instant; import java.time.LocalDateTime; import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; import java.time.format.DateTimeParseException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; public class CliTokenSourceTest { + private static final String[] DATE_FORMATS = { + "yyyy-MM-dd HH:mm:ss", + "yyyy-MM-dd HH:mm:ss.SSS", + TimeZone.getDefault().getID().equals("UTC") + ? "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'" + : "yyyy-MM-dd'T'HH:mm:ss.SSSXXX" + }; + + String getExpiryStr(String dateFormat, Duration offset) { + ZonedDateTime futureExpiry = ZonedDateTime.now().plus(offset); + return futureExpiry.format(DateTimeFormatter.ofPattern(dateFormat)); + } + + public void testRefreshWithExpiry( + String testName, int minutesUntilExpiry, boolean shouldBeExpired) + throws IOException, InterruptedException { + for (String dateFormat : DATE_FORMATS) { + // Mock environment + Environment env = mock(Environment.class); + Map envMap = new HashMap<>(); + when(env.getEnv()).thenReturn(envMap); + + // Create test command + List cmd = Arrays.asList("test", "command"); + + // Mock OSUtilities + OSUtilities osUtils = mock(OSUtilities.class); + when(osUtils.getCliExecutableCommand(any())).thenReturn(cmd); + + try (MockedStatic mockedOSUtils = mockStatic(OSUtils.class)) { + mockedOSUtils.when(() -> OSUtils.get(any())).thenReturn(osUtils); + + CliTokenSource tokenSource = + new CliTokenSource(cmd, "token_type", "access_token", "expiry", env); + + String expiryStr = getExpiryStr(dateFormat, Duration.ofMinutes(minutesUntilExpiry)); + + // Mock process to return the specified expiry string + Process process = mock(Process.class); + when(process.getInputStream()) + .thenReturn( + new ByteArrayInputStream( + String.format( + "{\"token_type\": \"Bearer\", \"access_token\": \"test-token\", \"expiry\": \"%s\"}", + expiryStr) + .getBytes())); + when(process.getErrorStream()).thenReturn(new ByteArrayInputStream(new byte[0])); + when(process.waitFor()).thenReturn(0); + + // Mock ProcessBuilder constructor + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (mock, context) -> { + when(mock.start()).thenReturn(process); + })) { + // Test refresh + Token token = tokenSource.refresh(); + assertEquals("Bearer", token.getTokenType()); + assertEquals("test-token", token.getAccessToken()); + assertEquals(shouldBeExpired, token.isExpired()); + } + } + } + } + + private static Stream provideTimezoneTestCases() { + // Generate timezones from GMT-12 to GMT+12 + List timezones = + IntStream.rangeClosed(-12, 12) + .mapToObj(offset -> offset == 0 ? "GMT" : String.format("GMT%+d", offset)) + .collect(Collectors.toList()); + + // Time to expiry of tokens (minutes, shouldBeExpired) + List minutesUntilExpiry = + Arrays.asList( + Arguments.of(5, false), // 5 minutes remaining + Arguments.of(30, false), // 30 minutes remaining + Arguments.of(60, false), // 1 hour remaining + Arguments.of(120, false), // 2 hours remaining + Arguments.of(-5, true), // 5 minutes ago + Arguments.of(-30, true), // 30 minutes ago + Arguments.of(-60, true), // 1 hour ago + Arguments.of(-120, true) // 2 hours ago + ); + + // Create cross product of timezones and minutesUntilExpiry cases + return timezones.stream() + .flatMap( + timezone -> + minutesUntilExpiry.stream() + .map( + minutesUntilExpiryCase -> { + Object[] args = minutesUntilExpiryCase.get(); + return Arguments.of(timezone, args[0], args[1]); + })); + } + + @ParameterizedTest(name = "Test in {0} with {1} minutes offset") + @MethodSource("provideTimezoneTestCases") + public void testRefreshWithDifferentTimezone( + String timezone, int minutesUntilExpiry, boolean shouldBeExpired) + throws IOException, InterruptedException { + // Save original timezone + TimeZone originalTimeZone = TimeZone.getDefault(); + try { + TimeZone.setDefault(TimeZone.getTimeZone(timezone)); + testRefreshWithExpiry("Test in " + timezone, minutesUntilExpiry, shouldBeExpired); + } finally { + // Restore original timezone + TimeZone.setDefault(originalTimeZone); + } + } + private static Stream expiryProvider() { return Stream.of( Arguments.of( From 95a3c6d5cd1c7ba5493a57a6acc3ce373b52962e Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Thu, 12 Jun 2025 13:52:24 +0000 Subject: [PATCH 32/49] Update test --- .../java/com/databricks/sdk/core/CliTokenSourceTest.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 33c2194e8..54e762c91 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -37,8 +37,9 @@ public class CliTokenSourceTest { private static final String[] DATE_FORMATS = { "yyyy-MM-dd HH:mm:ss", "yyyy-MM-dd HH:mm:ss.SSS", - "yyyy-MM-dd'T'HH:mm:ss.SSSXXX", - "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'" + TimeZone.getDefault().getID().equals("UTC") + ? "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'" + : "yyyy-MM-dd'T'HH:mm:ss.SSSXXX" }; String getExpiryStr(String dateFormat, Duration offset) { From a5c65c5813062be22af809729fec848c5ed62e12 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Thu, 12 Jun 2025 14:24:04 +0000 Subject: [PATCH 33/49] update tests --- .../test/java/com/databricks/sdk/core/CliTokenSourceTest.java | 1 - 1 file changed, 1 deletion(-) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 082682482..451eafdfe 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -36,7 +36,6 @@ import org.mockito.MockedStatic; public class CliTokenSourceTest { - private static final String[] DATE_FORMATS = { "yyyy-MM-dd HH:mm:ss", "yyyy-MM-dd HH:mm:ss.SSS", From a165b72ec9dd7dd009f96cac8ceb08e9f99bdc77 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Thu, 12 Jun 2025 14:35:43 +0000 Subject: [PATCH 34/49] Date formats are generated at run-time --- .../sdk/core/CliTokenSourceTest.java | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 451eafdfe..178390c9c 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -36,13 +36,15 @@ import org.mockito.MockedStatic; public class CliTokenSourceTest { - private static final String[] DATE_FORMATS = { - "yyyy-MM-dd HH:mm:ss", - "yyyy-MM-dd HH:mm:ss.SSS", - TimeZone.getDefault().getID().equals("UTC") - ? "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'" - : "yyyy-MM-dd'T'HH:mm:ss.SSSXXX" - }; + private String[] getDateFormats() { + return new String[] { + "yyyy-MM-dd HH:mm:ss", + "yyyy-MM-dd HH:mm:ss.SSS", + TimeZone.getDefault().getID().equals("UTC") + ? "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'" + : "yyyy-MM-dd'T'HH:mm:ss.SSSXXX" + }; + } String getExpiryStr(String dateFormat, Duration offset) { ZonedDateTime futureExpiry = ZonedDateTime.now().plus(offset); @@ -52,7 +54,7 @@ String getExpiryStr(String dateFormat, Duration offset) { public void testRefreshWithExpiry( String testName, int minutesUntilExpiry, boolean shouldBeExpired) throws IOException, InterruptedException { - for (String dateFormat : DATE_FORMATS) { + for (String dateFormat : getDateFormats()) { // Mock environment Environment env = mock(Environment.class); Map envMap = new HashMap<>(); From d1c1a6c791cedb2d338fbfeb4a82d909488a9847 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Thu, 12 Jun 2025 15:15:37 +0000 Subject: [PATCH 35/49] Generate date formats at run-time --- .../sdk/core/CliTokenSourceTest.java | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 54e762c91..a7202210b 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -34,13 +34,15 @@ import org.mockito.MockedStatic; public class CliTokenSourceTest { - private static final String[] DATE_FORMATS = { - "yyyy-MM-dd HH:mm:ss", - "yyyy-MM-dd HH:mm:ss.SSS", - TimeZone.getDefault().getID().equals("UTC") - ? "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'" - : "yyyy-MM-dd'T'HH:mm:ss.SSSXXX" - }; + private String[] getDateFormats() { + return new String[] { + "yyyy-MM-dd HH:mm:ss", + "yyyy-MM-dd HH:mm:ss.SSS", + TimeZone.getDefault().getID().equals("UTC") + ? "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'" + : "yyyy-MM-dd'T'HH:mm:ss.SSSXXX" + }; + } String getExpiryStr(String dateFormat, Duration offset) { ZonedDateTime futureExpiry = ZonedDateTime.now().plus(offset); @@ -50,7 +52,7 @@ String getExpiryStr(String dateFormat, Duration offset) { public void testRefreshWithExpiry( String testName, int minutesUntilExpiry, boolean shouldBeExpired) throws IOException, InterruptedException { - for (String dateFormat : DATE_FORMATS) { + for (String dateFormat : getDateFormats()) { // Mock environment Environment env = mock(Environment.class); Map envMap = new HashMap<>(); From 785c4009bb1362bfe4534cad7b6836a9f9b19ed7 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 13 Jun 2025 09:13:03 +0000 Subject: [PATCH 36/49] Update stream of test cases --- .../sdk/core/CliTokenSourceTest.java | 125 +++++++++--------- 1 file changed, 65 insertions(+), 60 deletions(-) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index a7202210b..0614d11c7 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -18,6 +18,7 @@ import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeParseException; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -34,25 +35,77 @@ import org.mockito.MockedStatic; public class CliTokenSourceTest { - private String[] getDateFormats() { - return new String[] { - "yyyy-MM-dd HH:mm:ss", - "yyyy-MM-dd HH:mm:ss.SSS", - TimeZone.getDefault().getID().equals("UTC") - ? "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'" - : "yyyy-MM-dd'T'HH:mm:ss.SSSXXX" - }; - } - String getExpiryStr(String dateFormat, Duration offset) { ZonedDateTime futureExpiry = ZonedDateTime.now().plus(offset); return futureExpiry.format(DateTimeFormatter.ofPattern(dateFormat)); } + private static Stream provideTimezoneTestCases() { + // Generate timezones from GMT-12 to GMT+12 + List timezones = + IntStream.rangeClosed(-12, 12) + .mapToObj(offset -> offset == 0 ? "GMT" : String.format("GMT%+d", offset)) + .collect(Collectors.toList()); + + // Time to expiry of tokens (minutes, shouldBeExpired) + List minutesUntilExpiry = + Arrays.asList( + Arguments.of(5, false), // 5 minutes remaining + Arguments.of(30, false), // 30 minutes remaining + Arguments.of(60, false), // 1 hour remaining + Arguments.of(120, false), // 2 hours remaining + Arguments.of(-5, true), // 5 minutes ago + Arguments.of(-30, true), // 30 minutes ago + Arguments.of(-60, true), // 1 hour ago + Arguments.of(-120, true) // 2 hours ago + ); + + // Create cross product of timezones and minutesUntilExpiry case and match the timezone with the + // date formats + return timezones.stream() + .flatMap( + timezone -> { + List dateFormats = + new ArrayList<>( + Arrays.asList( + "yyyy-MM-dd HH:mm:ss", + "yyyy-MM-dd HH:mm:ss.SSS", + "yyyy-MM-dd'T'HH:mm:ss.SSSXXX")); + + if (timezone.equals("GMT")) { + dateFormats.add("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"); + } + + return minutesUntilExpiry.stream() + .map( + minutesUntilExpiryCase -> { + Object[] args = minutesUntilExpiryCase.get(); + return Arguments.of(timezone, args[0], args[1], dateFormats); + }); + }); + } + + @ParameterizedTest(name = "Test in {0} with {1} minutes offset") + @MethodSource("provideTimezoneTestCases") + public void testRefreshWithDifferentTimezone( + String timezone, int minutesUntilExpiry, boolean shouldBeExpired, List dateFormats) + throws IOException, InterruptedException { + // Save original timezone + TimeZone originalTimeZone = TimeZone.getDefault(); + try { + TimeZone.setDefault(TimeZone.getTimeZone(timezone)); + testRefreshWithExpiry( + "Test in " + timezone, minutesUntilExpiry, shouldBeExpired, dateFormats); + } finally { + // Restore original timezone + TimeZone.setDefault(originalTimeZone); + } + } + public void testRefreshWithExpiry( - String testName, int minutesUntilExpiry, boolean shouldBeExpired) + String testName, int minutesUntilExpiry, boolean shouldBeExpired, List dateFormats) throws IOException, InterruptedException { - for (String dateFormat : getDateFormats()) { + for (String dateFormat : dateFormats) { // Mock environment Environment env = mock(Environment.class); Map envMap = new HashMap<>(); @@ -102,54 +155,6 @@ public void testRefreshWithExpiry( } } - private static Stream provideTimezoneTestCases() { - // Generate timezones from GMT-12 to GMT+12 - List timezones = - IntStream.rangeClosed(-12, 12) - .mapToObj(offset -> offset == 0 ? "GMT" : String.format("GMT%+d", offset)) - .collect(Collectors.toList()); - - // Time to expiry of tokens (minutes, shouldBeExpired) - List minutesUntilExpiry = - Arrays.asList( - Arguments.of(5, false), // 5 minutes remaining - Arguments.of(30, false), // 30 minutes remaining - Arguments.of(60, false), // 1 hour remaining - Arguments.of(120, false), // 2 hours remaining - Arguments.of(-5, true), // 5 minutes ago - Arguments.of(-30, true), // 30 minutes ago - Arguments.of(-60, true), // 1 hour ago - Arguments.of(-120, true) // 2 hours ago - ); - - // Create cross product of timezones and minutesUntilExpiry cases - return timezones.stream() - .flatMap( - timezone -> - minutesUntilExpiry.stream() - .map( - minutesUntilExpiryCase -> { - Object[] args = minutesUntilExpiryCase.get(); - return Arguments.of(timezone, args[0], args[1]); - })); - } - - @ParameterizedTest(name = "Test in {0} with {1} minutes offset") - @MethodSource("provideTimezoneTestCases") - public void testRefreshWithDifferentTimezone( - String timezone, int minutesUntilExpiry, boolean shouldBeExpired) - throws IOException, InterruptedException { - // Save original timezone - TimeZone originalTimeZone = TimeZone.getDefault(); - try { - TimeZone.setDefault(TimeZone.getTimeZone(timezone)); - testRefreshWithExpiry("Test in " + timezone, minutesUntilExpiry, shouldBeExpired); - } finally { - // Restore original timezone - TimeZone.setDefault(originalTimeZone); - } - } - @Test public void testParseExpiryWithoutTruncate() { LocalDateTime parsedDateTime = CliTokenSource.parseExpiry("2023-07-17T09:02:22.330612218Z"); From 9d5c85e220fb0e7c7ac26decb32eb259746be1b9 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 13 Jun 2025 11:21:12 +0000 Subject: [PATCH 37/49] Add comment --- .../test/java/com/databricks/sdk/core/CliTokenSourceTest.java | 1 + 1 file changed, 1 insertion(+) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index f7364c306..29229dc5c 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -75,6 +75,7 @@ private static Stream provideTimezoneTestCases() { "yyyy-MM-dd'T'HH:mm:ss.SSSXXX")); if (timezone.equals("GMT")) { + // We only test with this format when timezone is GMT dateFormats.add("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"); } From 2a7cd9544662d2b9cd8e12e02c4e8d6d3c506bc8 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 13 Jun 2025 11:22:35 +0000 Subject: [PATCH 38/49] Add comment --- .../test/java/com/databricks/sdk/core/CliTokenSourceTest.java | 1 + 1 file changed, 1 insertion(+) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 0614d11c7..2e06939e2 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -73,6 +73,7 @@ private static Stream provideTimezoneTestCases() { "yyyy-MM-dd'T'HH:mm:ss.SSSXXX")); if (timezone.equals("GMT")) { + // We only test with this format when timezone is GMT dateFormats.add("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"); } From 257998c649fa181ffa83f38d7f02dc1076d37506 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 13 Jun 2025 12:23:08 +0000 Subject: [PATCH 39/49] Update comment --- .../java/com/databricks/sdk/core/CliTokenSourceTest.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 2e06939e2..b8108dc02 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -73,7 +73,11 @@ private static Stream provideTimezoneTestCases() { "yyyy-MM-dd'T'HH:mm:ss.SSSXXX")); if (timezone.equals("GMT")) { - // We only test with this format when timezone is GMT + /* + * The Databricks CLI outputs timestamps with 'Z' suffix (e.g., + * 2024-03-14T10:30:00.000Z) only when in UTC/GMT+0 timezone. + * Thus, we only test with this format together with the GMT timezone. + */ dateFormats.add("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"); } From ceea58a329db26e705221dd43016778aa3f07fab Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 13 Jun 2025 12:42:29 +0000 Subject: [PATCH 40/49] update stream of tests --- .../sdk/core/CliTokenSourceTest.java | 117 +++++++++--------- 1 file changed, 58 insertions(+), 59 deletions(-) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 89b82c2e3..63940b968 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -62,8 +62,7 @@ private static Stream provideTimezoneTestCases() { Arguments.of(-120, true) // 2 hours ago ); - // Create cross product of timezones and minutesUntilExpiry case and match the timezone with the - // date formats + // Create cross product of timezones and minutesUntilExpiry case return timezones.stream() .flatMap( timezone -> { @@ -83,26 +82,28 @@ private static Stream provideTimezoneTestCases() { dateFormats.add("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"); } - return minutesUntilExpiry.stream() - .map( - minutesUntilExpiryCase -> { - Object[] args = minutesUntilExpiryCase.get(); - return Arguments.of(timezone, args[0], args[1], dateFormats); - }); + return dateFormats.stream() + .flatMap( + dateFormat -> + minutesUntilExpiry.stream() + .map( + minutesUntilExpiryCase -> { + Object[] args = minutesUntilExpiryCase.get(); + return Arguments.of(timezone, args[0], args[1], dateFormat); + })); }); } - @ParameterizedTest(name = "Test in {0} with {1} minutes offset") + @ParameterizedTest(name = "Test in {0} with {1} minutes offset using format {3}") @MethodSource("provideTimezoneTestCases") public void testRefreshWithDifferentTimezone( - String timezone, int minutesUntilExpiry, boolean shouldBeExpired, List dateFormats) + String timezone, int minutesUntilExpiry, boolean shouldBeExpired, String dateFormat) throws IOException, InterruptedException { // Save original timezone TimeZone originalTimeZone = TimeZone.getDefault(); try { TimeZone.setDefault(TimeZone.getTimeZone(timezone)); - testRefreshWithExpiry( - "Test in " + timezone, minutesUntilExpiry, shouldBeExpired, dateFormats); + testRefreshWithExpiry("Test in " + timezone, minutesUntilExpiry, shouldBeExpired, dateFormat); } finally { // Restore original timezone TimeZone.setDefault(originalTimeZone); @@ -110,54 +111,52 @@ public void testRefreshWithDifferentTimezone( } public void testRefreshWithExpiry( - String testName, int minutesUntilExpiry, boolean shouldBeExpired, List dateFormats) + String testName, int minutesUntilExpiry, boolean shouldBeExpired, String dateFormat) throws IOException, InterruptedException { - for (String dateFormat : dateFormats) { - // Mock environment - Environment env = mock(Environment.class); - Map envMap = new HashMap<>(); - when(env.getEnv()).thenReturn(envMap); - - // Create test command - List cmd = Arrays.asList("test", "command"); - - // Mock OSUtilities - OSUtilities osUtils = mock(OSUtilities.class); - when(osUtils.getCliExecutableCommand(any())).thenReturn(cmd); - - try (MockedStatic mockedOSUtils = mockStatic(OSUtils.class)) { - mockedOSUtils.when(() -> OSUtils.get(any())).thenReturn(osUtils); - - CliTokenSource tokenSource = - new CliTokenSource(cmd, "token_type", "access_token", "expiry", env); - - String expiryStr = getExpiryStr(dateFormat, Duration.ofMinutes(minutesUntilExpiry)); - - // Mock process to return the specified expiry string - Process process = mock(Process.class); - when(process.getInputStream()) - .thenReturn( - new ByteArrayInputStream( - String.format( - "{\"token_type\": \"Bearer\", \"access_token\": \"test-token\", \"expiry\": \"%s\"}", - expiryStr) - .getBytes())); - when(process.getErrorStream()).thenReturn(new ByteArrayInputStream(new byte[0])); - when(process.waitFor()).thenReturn(0); - - // Mock ProcessBuilder constructor - try (MockedConstruction mocked = - mockConstruction( - ProcessBuilder.class, - (mock, context) -> { - when(mock.start()).thenReturn(process); - })) { - // Test refresh - Token token = tokenSource.refresh(); - assertEquals("Bearer", token.getTokenType()); - assertEquals("test-token", token.getAccessToken()); - assertEquals(shouldBeExpired, token.isExpired()); - } + // Mock environment + Environment env = mock(Environment.class); + Map envMap = new HashMap<>(); + when(env.getEnv()).thenReturn(envMap); + + // Create test command + List cmd = Arrays.asList("test", "command"); + + // Mock OSUtilities + OSUtilities osUtils = mock(OSUtilities.class); + when(osUtils.getCliExecutableCommand(any())).thenReturn(cmd); + + try (MockedStatic mockedOSUtils = mockStatic(OSUtils.class)) { + mockedOSUtils.when(() -> OSUtils.get(any())).thenReturn(osUtils); + + CliTokenSource tokenSource = + new CliTokenSource(cmd, "token_type", "access_token", "expiry", env); + + String expiryStr = getExpiryStr(dateFormat, Duration.ofMinutes(minutesUntilExpiry)); + + // Mock process to return the specified expiry string + Process process = mock(Process.class); + when(process.getInputStream()) + .thenReturn( + new ByteArrayInputStream( + String.format( + "{\"token_type\": \"Bearer\", \"access_token\": \"test-token\", \"expiry\": \"%s\"}", + expiryStr) + .getBytes())); + when(process.getErrorStream()).thenReturn(new ByteArrayInputStream(new byte[0])); + when(process.waitFor()).thenReturn(0); + + // Mock ProcessBuilder constructor + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (mock, context) -> { + when(mock.start()).thenReturn(process); + })) { + // Test refresh + Token token = tokenSource.refresh(); + assertEquals("Bearer", token.getTokenType()); + assertEquals("test-token", token.getAccessToken()); + assertEquals(shouldBeExpired, token.isExpired()); } } } From 0697faef8af01a647cd82e4e9cedb243b4beb3d2 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 13 Jun 2025 14:09:23 +0000 Subject: [PATCH 41/49] Polish comments --- .../sdk/core/CliTokenSourceTest.java | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 63940b968..aa98730cb 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -43,13 +43,13 @@ String getExpiryStr(String dateFormat, Duration offset) { } private static Stream provideTimezoneTestCases() { - // Generate timezones from GMT-12 to GMT+12 + // Generate timezones from GMT-12 to GMT+12. List timezones = IntStream.rangeClosed(-12, 12) .mapToObj(offset -> offset == 0 ? "GMT" : String.format("GMT%+d", offset)) .collect(Collectors.toList()); - // Time to expiry of tokens (minutes, shouldBeExpired) + // Time to expiry of tokens (minutes, shouldBeExpired). List minutesUntilExpiry = Arrays.asList( Arguments.of(5, false), // 5 minutes remaining @@ -62,7 +62,7 @@ private static Stream provideTimezoneTestCases() { Arguments.of(-120, true) // 2 hours ago ); - // Create cross product of timezones and minutesUntilExpiry case + // Create cross product of timezones and minutesUntilExpiry cases. return timezones.stream() .flatMap( timezone -> { @@ -74,11 +74,9 @@ private static Stream provideTimezoneTestCases() { "yyyy-MM-dd'T'HH:mm:ss.SSSXXX")); if (timezone.equals("GMT")) { - /* - * The Databricks CLI outputs timestamps with 'Z' suffix (e.g., - * 2024-03-14T10:30:00.000Z) only when in UTC/GMT+0 timezone. - * Thus, we only test with this format together with the GMT timezone. - */ + // The Databricks CLI outputs timestamps with 'Z' suffix (e.g., + // 2024-03-14T10:30:00.000Z) only when in UTC/GMT+0 timezone. + // Thus, we only test with this format together with the GMT timezone. dateFormats.add("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"); } @@ -99,13 +97,13 @@ private static Stream provideTimezoneTestCases() { public void testRefreshWithDifferentTimezone( String timezone, int minutesUntilExpiry, boolean shouldBeExpired, String dateFormat) throws IOException, InterruptedException { - // Save original timezone + // Save original timezone. TimeZone originalTimeZone = TimeZone.getDefault(); try { TimeZone.setDefault(TimeZone.getTimeZone(timezone)); testRefreshWithExpiry("Test in " + timezone, minutesUntilExpiry, shouldBeExpired, dateFormat); } finally { - // Restore original timezone + // Restore original timezone. TimeZone.setDefault(originalTimeZone); } } @@ -113,15 +111,15 @@ public void testRefreshWithDifferentTimezone( public void testRefreshWithExpiry( String testName, int minutesUntilExpiry, boolean shouldBeExpired, String dateFormat) throws IOException, InterruptedException { - // Mock environment + // Mock environment. Environment env = mock(Environment.class); Map envMap = new HashMap<>(); when(env.getEnv()).thenReturn(envMap); - // Create test command + // Create test command. List cmd = Arrays.asList("test", "command"); - // Mock OSUtilities + // Mock OSUtilities. OSUtilities osUtils = mock(OSUtilities.class); when(osUtils.getCliExecutableCommand(any())).thenReturn(cmd); @@ -133,7 +131,7 @@ public void testRefreshWithExpiry( String expiryStr = getExpiryStr(dateFormat, Duration.ofMinutes(minutesUntilExpiry)); - // Mock process to return the specified expiry string + // Mock process to return the specified expiry string. Process process = mock(Process.class); when(process.getInputStream()) .thenReturn( @@ -145,14 +143,14 @@ public void testRefreshWithExpiry( when(process.getErrorStream()).thenReturn(new ByteArrayInputStream(new byte[0])); when(process.waitFor()).thenReturn(0); - // Mock ProcessBuilder constructor + // Mock ProcessBuilder constructor. try (MockedConstruction mocked = mockConstruction( ProcessBuilder.class, (mock, context) -> { when(mock.start()).thenReturn(process); })) { - // Test refresh + // Test refresh. Token token = tokenSource.refresh(); assertEquals("Bearer", token.getTokenType()); assertEquals("test-token", token.getAccessToken()); From 490091c47ff06552ebb53a39a190e6859e34a298 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Mon, 16 Jun 2025 12:22:30 +0000 Subject: [PATCH 42/49] Small fixes --- .../com/databricks/sdk/core/oauth/RefreshableTokenSource.java | 2 +- .../src/main/java/com/databricks/sdk/core/oauth/Token.java | 2 +- .../com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java | 1 + .../java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java | 1 + 4 files changed, 4 insertions(+), 2 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index 2d98fdc37..d421764fc 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -213,7 +213,7 @@ protected Token getTokenAsync() { * Trigger an asynchronous refresh of the token if not already in progress and last refresh * succeeded. */ - protected synchronized void triggerAsyncRefresh() { + private synchronized void triggerAsyncRefresh() { // Check token state to avoid triggering a refresh if another thread has already refreshed it if (!refreshInProgress && lastRefreshSucceeded && getTokenState(token) != TokenState.FRESH) { refreshInProgress = true; diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java index 18ef99880..4a3b42a7e 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java @@ -70,7 +70,7 @@ public String getAccessToken() { } /** - * Returns the expiry time of the token as a LocalDateTime. + * Returns the expiry time of the token as a Instant. * * @return the expiry time */ diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java index f3eb2ed94..bac591212 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java @@ -194,6 +194,7 @@ void testDataPlaneTokenSource( assertEquals(expectedToken.getAccessToken(), token.getAccessToken()); assertEquals(expectedToken.getTokenType(), token.getTokenType()); assertEquals(expectedToken.getRefreshToken(), token.getRefreshToken()); + assertTrue(expectedToken.getExpiry().isAfter(Instant.now())); } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java index d08ddf583..710fbf1b8 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java @@ -54,6 +54,7 @@ void testSaveAndLoadToken() { assertEquals("access-token", loadedToken.getAccessToken()); assertEquals("Bearer", loadedToken.getTokenType()); assertEquals("refresh-token", loadedToken.getRefreshToken()); + assertEquals(expiry, loadedToken.getExpiry()); } @Test From 1b88b29e22b60f8aeab50e42c311a76a18eb840c Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Mon, 16 Jun 2025 12:39:22 +0000 Subject: [PATCH 43/49] More small fixes --- .../com/databricks/sdk/core/oauth/RefreshableTokenSource.java | 4 ++-- .../test/java/com/databricks/sdk/core/oauth/TokenTest.java | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index d421764fc..caf10e0db 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -214,7 +214,7 @@ protected Token getTokenAsync() { * succeeded. */ private synchronized void triggerAsyncRefresh() { - // Check token state to avoid triggering a refresh if another thread has already refreshed it + // Check token state again to avoid triggering a refresh if another thread updated the token if (!refreshInProgress && lastRefreshSucceeded && getTokenState(token) != TokenState.FRESH) { refreshInProgress = true; CompletableFuture.runAsync( @@ -230,7 +230,7 @@ private synchronized void triggerAsyncRefresh() { synchronized (this) { lastRefreshSucceeded = false; refreshInProgress = false; - logger.error("Async token refresh failed", e); + logger.error("Asynchronous token refresh failed", e); } } }); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java index 23c1a418f..a0173cbf8 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenTest.java @@ -18,6 +18,7 @@ void createNonRefreshableToken() { assertEquals(accessToken, token.getAccessToken()); assertEquals(tokenType, token.getTokenType()); assertNull(token.getRefreshToken()); + assertEquals(currentInstant.plusSeconds(300), token.getExpiry()); } @Test @@ -26,5 +27,6 @@ void createRefreshableToken() { assertEquals(accessToken, token.getAccessToken()); assertEquals(tokenType, token.getTokenType()); assertEquals(refreshToken, token.getRefreshToken()); + assertEquals(currentInstant.plusSeconds(300), token.getExpiry()); } } From c07669be8cf97372d2e5aa887c3b65865e8ae62f Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Mon, 16 Jun 2025 13:16:05 +0000 Subject: [PATCH 44/49] Update comment --- .../com/databricks/sdk/core/oauth/RefreshableTokenSource.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index caf10e0db..596116ef6 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -214,7 +214,8 @@ protected Token getTokenAsync() { * succeeded. */ private synchronized void triggerAsyncRefresh() { - // Check token state again to avoid triggering a refresh if another thread updated the token + // Check token state again inside the synchronized block to avoid triggering a refresh if + // another thread updated the token in the meantime. if (!refreshInProgress && lastRefreshSucceeded && getTokenState(token) != TokenState.FRESH) { refreshInProgress = true; CompletableFuture.runAsync( From d8db181112e329b7e76b53ac940d8a0f38c65322 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Mon, 16 Jun 2025 15:02:23 +0000 Subject: [PATCH 45/49] Small fix --- .../databricks/sdk/core/oauth/RefreshableTokenSource.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index 596116ef6..9b91b9f01 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -39,6 +39,8 @@ protected enum TokenState { private static final Logger logger = LoggerFactory.getLogger(RefreshableTokenSource.class); // Default duration before expiry to consider a token as 'stale'. private static final Duration DEFAULT_STALE_DURATION = Duration.ofMinutes(3); + // Default additional buffer before expiry to consider a token as expired. + private static final Duration DEFAULT_EXPIRY_BUFFER = Duration.ofSeconds(40); // The current OAuth token. May be null if not yet fetched. protected volatile Token token; @@ -47,12 +49,12 @@ protected enum TokenState { // Duration before expiry to consider a token as 'stale'. private Duration staleDuration = DEFAULT_STALE_DURATION; // Additional buffer before expiry to consider a token as expired. - private Duration expiryBuffer = Duration.ofSeconds(40); + private Duration expiryBuffer = DEFAULT_EXPIRY_BUFFER; // Whether a refresh is currently in progress (for async refresh). private boolean refreshInProgress = false; // Whether the last refresh attempt succeeded. private boolean lastRefreshSucceeded = true; - // Clock supplier for current time, for testing purposes. + // Clock supplier for current time, can be overridden for testing purposes. private ClockSupplier clockSupplier = new SystemClockSupplier(); /** Constructs a new {@code RefreshableTokenSource} with no initial token. */ From 87800e60a61d59365e87bb725b33ba136c6736a9 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Mon, 16 Jun 2025 15:45:24 +0000 Subject: [PATCH 46/49] Rename SystemClockSupplier --- .../databricks/sdk/core/oauth/RefreshableTokenSource.java | 6 +++--- .../{SystemClockSupplier.java => UtcClockSupplier.java} | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) rename databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/{SystemClockSupplier.java => UtcClockSupplier.java} (70%) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index 9b91b9f01..342f1d9ae 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -6,7 +6,7 @@ import com.databricks.sdk.core.http.HttpClient; import com.databricks.sdk.core.http.Request; import com.databricks.sdk.core.utils.ClockSupplier; -import com.databricks.sdk.core.utils.SystemClockSupplier; +import com.databricks.sdk.core.utils.UtcClockSupplier; import java.time.Duration; import java.time.Instant; import java.util.Base64; @@ -54,8 +54,8 @@ protected enum TokenState { private boolean refreshInProgress = false; // Whether the last refresh attempt succeeded. private boolean lastRefreshSucceeded = true; - // Clock supplier for current time, can be overridden for testing purposes. - private ClockSupplier clockSupplier = new SystemClockSupplier(); + // Clock supplier for current time. + private ClockSupplier clockSupplier = new UtcClockSupplier(); /** Constructs a new {@code RefreshableTokenSource} with no initial token. */ public RefreshableTokenSource() {} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/SystemClockSupplier.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/UtcClockSupplier.java similarity index 70% rename from databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/SystemClockSupplier.java rename to databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/UtcClockSupplier.java index c79e6884d..73e6d2bc4 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/SystemClockSupplier.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/UtcClockSupplier.java @@ -2,7 +2,7 @@ import java.time.Clock; -public class SystemClockSupplier implements ClockSupplier { +public class UtcClockSupplier implements ClockSupplier { @Override public Clock getClock() { return Clock.systemUTC(); From 99ba2b14f90045c0103720798893ce76bb00b436 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Mon, 16 Jun 2025 16:05:29 +0000 Subject: [PATCH 47/49] Update test to use TestClockSupplier instead --- .../oauth/RefreshableTokenSourceTest.java | 27 +++++++++++++------ .../sdk/core/utils/FakeClockSupplier.java | 18 ------------- .../sdk/core/utils/TestClockSupplier.java | 23 ++++++++++++++++ 3 files changed, 42 insertions(+), 26 deletions(-) delete mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/utils/FakeClockSupplier.java create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/utils/TestClockSupplier.java diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java index 34ad00026..194c3a2ec 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java @@ -2,6 +2,7 @@ import static org.junit.jupiter.api.Assertions.*; +import com.databricks.sdk.core.utils.TestClockSupplier; import java.time.Duration; import java.time.Instant; import java.util.concurrent.CountDownLatch; @@ -80,8 +81,16 @@ protected Token refresh() { */ @Test void testAsyncRefreshFailureFallback() throws Exception { + // Create a test clock starting at current time + TestClockSupplier clockSupplier = new TestClockSupplier(Instant.now()); + + // Create a token that expires in 2 minutes from the initial clock time Token staleToken = - new Token(INITIAL_TOKEN, TOKEN_TYPE, null, Instant.now().plus(Duration.ofMinutes(2))); + new Token( + INITIAL_TOKEN, + TOKEN_TYPE, + null, + Instant.now(clockSupplier.getClock()).plus(Duration.ofMinutes(2))); class TestSource extends RefreshableTokenSource { int refreshCallCount = 0; @@ -99,12 +108,16 @@ protected Token refresh() { throw new RuntimeException("Simulated async failure"); } return new Token( - REFRESH_TOKEN, TOKEN_TYPE, null, Instant.now().plus(Duration.ofMinutes(10))); + REFRESH_TOKEN, + TOKEN_TYPE, + null, + Instant.now(clockSupplier.getClock()).plus(Duration.ofMinutes(10))); } } TestSource source = new TestSource(staleToken); source.withAsyncRefresh(true); + source.withClockSupplier(clockSupplier); // First call triggers async refresh, which fails source.getToken(); @@ -121,9 +134,8 @@ protected Token refresh() { source.refreshCallCount, "refresh() should NOT be called again while stale after async failure"); - // Advance the clock so the token is now expired - source.token = - new Token(INITIAL_TOKEN, TOKEN_TYPE, null, Instant.now().minus(Duration.ofMinutes(1))); + // Advance the clock by 3 minutes to make the token expired + clockSupplier.advanceTime(Duration.ofMinutes(3)); // Now getToken() should call refresh synchronously and return the refreshed token Token token = source.getToken(); @@ -134,9 +146,8 @@ protected Token refresh() { assertEquals( 2, source.refreshCallCount, "refresh() should have been called synchronously after expiry"); - // Make the token stale again and trigger async refresh since the last sync refresh succeeded - source.token = - new Token(INITIAL_TOKEN, TOKEN_TYPE, null, Instant.now().plus(Duration.ofMinutes(2))); + // Advance time by 8 minutes to make the token stale again + clockSupplier.advanceTime(Duration.ofMinutes(8)); source.getToken(); Thread.sleep(300); assertEquals( diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/utils/FakeClockSupplier.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/utils/FakeClockSupplier.java deleted file mode 100644 index 78f8df076..000000000 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/utils/FakeClockSupplier.java +++ /dev/null @@ -1,18 +0,0 @@ -package com.databricks.sdk.core.utils; - -import java.time.Clock; -import java.time.Instant; -import java.time.ZoneId; - -public class FakeClockSupplier implements ClockSupplier { - private final Clock clock; - - public FakeClockSupplier(Instant fixedInstant, ZoneId zoneId) { - clock = Clock.fixed(fixedInstant, zoneId); - } - - @Override - public Clock getClock() { - return clock; - } -} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/utils/TestClockSupplier.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/utils/TestClockSupplier.java new file mode 100644 index 000000000..fb97149dc --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/utils/TestClockSupplier.java @@ -0,0 +1,23 @@ +package com.databricks.sdk.core.utils; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneId; + +public class TestClockSupplier implements ClockSupplier { + private Clock clock; + + public TestClockSupplier(Instant fixedInstant) { + clock = Clock.fixed(fixedInstant, ZoneId.of("UTC")); + } + + public void advanceTime(Duration duration) { + clock = Clock.offset(clock, duration); + } + + @Override + public Clock getClock() { + return clock; + } +} From 536a283198877d14e33107bee773a390835ba236 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 17 Jun 2025 08:57:10 +0000 Subject: [PATCH 48/49] Small refactor of getToken --- .../databricks/sdk/core/oauth/RefreshableTokenSource.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index 342f1d9ae..d84f6de96 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -129,10 +129,10 @@ public RefreshableTokenSource withExpiryBuffer(Duration buffer) { * @return The current valid token */ public Token getToken() { - if (!asyncEnabled) { - return getTokenBlocking(); + if (asyncEnabled) { + return getTokenAsync(); } - return getTokenAsync(); + return getTokenBlocking(); } /** From f6f0dd0f2a26f95ac267a27ad3f6aa27f2c7b4ba Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 17 Jun 2025 16:02:48 +0000 Subject: [PATCH 49/49] Trigger CI: empty commit