diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java index 750aae967..d84f6de96 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java @@ -5,29 +5,245 @@ import com.databricks.sdk.core.http.FormRequest; import com.databricks.sdk.core.http.HttpClient; import com.databricks.sdk.core.http.Request; +import com.databricks.sdk.core.utils.ClockSupplier; +import com.databricks.sdk.core.utils.UtcClockSupplier; +import java.time.Duration; import java.time.Instant; import java.util.Base64; import java.util.Map; +import java.util.concurrent.CompletableFuture; import org.apache.http.HttpHeaders; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * An OAuth TokenSource which can be refreshed. * - *
Calls to getToken() will first check if the token is still valid (currently defined by having - * at least 10 seconds until expiry). If not, refresh() is called first to refresh the token. + *
This class supports both synchronous and asynchronous token refresh. When async is enabled, + * stale tokens will trigger a background refresh, while expired tokens will block until a new token + * is fetched. */ public abstract class RefreshableTokenSource implements TokenSource { - protected Token token; + /** + * Enum representing the state of the token. FRESH: Token is valid and not close to expiry. STALE: + * Token is valid but will expire soon - an async refresh will be triggered if enabled. EXPIRED: + * Token has expired and must be refreshed using a blocking call. + */ + protected enum TokenState { + FRESH, + STALE, + EXPIRED + } + + private static final Logger logger = LoggerFactory.getLogger(RefreshableTokenSource.class); + // Default duration before expiry to consider a token as 'stale'. + private static final Duration DEFAULT_STALE_DURATION = Duration.ofMinutes(3); + // Default additional buffer before expiry to consider a token as expired. + private static final Duration DEFAULT_EXPIRY_BUFFER = Duration.ofSeconds(40); + + // The current OAuth token. May be null if not yet fetched. + protected volatile Token token; + // Whether asynchronous refresh is enabled. + private boolean asyncEnabled = false; + // Duration before expiry to consider a token as 'stale'. + private Duration staleDuration = DEFAULT_STALE_DURATION; + // Additional buffer before expiry to consider a token as expired. + private Duration expiryBuffer = DEFAULT_EXPIRY_BUFFER; + // Whether a refresh is currently in progress (for async refresh). + private boolean refreshInProgress = false; + // Whether the last refresh attempt succeeded. + private boolean lastRefreshSucceeded = true; + // Clock supplier for current time. + private ClockSupplier clockSupplier = new UtcClockSupplier(); + + /** Constructs a new {@code RefreshableTokenSource} with no initial token. */ public RefreshableTokenSource() {} + /** + * Constructor with initial token. + * + * @param token The initial token to use. + */ public RefreshableTokenSource(Token token) { this.token = token; } + /** + * Set the clock supplier for current time. + * + *
Experimental: This method may change or be removed in future releases. + * + * @param clockSupplier The clock supplier to use. + * @return this instance for chaining + */ + public RefreshableTokenSource withClockSupplier(ClockSupplier clockSupplier) { + this.clockSupplier = clockSupplier; + return this; + } + + /** + * Enable or disable asynchronous token refresh. + * + *
Experimental: This method may change or be removed in future releases. + * + * @param enabled true to enable async refresh, false to disable + * @return this instance for chaining + */ + public RefreshableTokenSource withAsyncRefresh(boolean enabled) { + this.asyncEnabled = enabled; + return this; + } + + /** + * Set the expiry buffer. If the token's lifetime is less than this buffer, it is considered + * expired. + * + *
Experimental: This method may change or be removed in future releases. + * + * @param buffer the expiry buffer duration + * @return this instance for chaining + */ + public RefreshableTokenSource withExpiryBuffer(Duration buffer) { + this.expiryBuffer = buffer; + return this; + } + + /** + * Refresh the OAuth token. Subclasses must implement this to define how the token is refreshed. + * + *
This method may throw an exception if the token cannot be refreshed. The specific exception + * type depends on the implementation. + * + * @return The newly refreshed Token. + */ + protected abstract Token refresh(); + + /** + * Gets the current token, refreshing if necessary. If async refresh is enabled, may return a + * stale token while a refresh is in progress. + * + *
This method may throw an exception if the token cannot be refreshed, depending on the + * implementation of {@link #refresh()}. + * + * @return The current valid token + */ + public Token getToken() { + if (asyncEnabled) { + return getTokenAsync(); + } + return getTokenBlocking(); + } + + /** + * Determine the state of the current token (fresh, stale, or expired). + * + * @return The token state + */ + protected TokenState getTokenState(Token t) { + if (t == null) { + return TokenState.EXPIRED; + } + Duration lifeTime = Duration.between(Instant.now(clockSupplier.getClock()), t.getExpiry()); + if (lifeTime.compareTo(expiryBuffer) <= 0) { + return TokenState.EXPIRED; + } + if (lifeTime.compareTo(staleDuration) <= 0) { + return TokenState.STALE; + } + return TokenState.FRESH; + } + + /** + * Get the current token, blocking to refresh if expired. + * + *
This method may throw an exception if the token cannot be refreshed, depending on the + * implementation of {@link #refresh()}. + * + * @return The current valid token + */ + protected Token getTokenBlocking() { + // Use double-checked locking to minimize synchronization overhead on reads: + // 1. Check if the token is expired without locking. + // 2. If expired, synchronize and check again (another thread may have refreshed it). + // 3. If still expired, perform the refresh. + if (getTokenState(token) != TokenState.EXPIRED) { + return token; + } + synchronized (this) { + if (getTokenState(token) != TokenState.EXPIRED) { + return token; + } + lastRefreshSucceeded = false; + try { + token = refresh(); + } catch (Exception e) { + logger.error("Failed to refresh token synchronously", e); + throw e; + } + lastRefreshSucceeded = true; + return token; + } + } + + /** + * Get the current token, possibly triggering an async refresh if stale. If the token is expired, + * blocks to refresh. + * + *
This method may throw an exception if the token cannot be refreshed, depending on the
+ * implementation of {@link #refresh()}.
+ *
+ * @return The current valid or stale token
+ */
+ protected Token getTokenAsync() {
+ Token currentToken = token;
+
+ switch (getTokenState(currentToken)) {
+ case FRESH:
+ return currentToken;
+ case STALE:
+ triggerAsyncRefresh();
+ return currentToken;
+ case EXPIRED:
+ return getTokenBlocking();
+ default:
+ throw new IllegalStateException("Invalid token state.");
+ }
+ }
+
+ /**
+ * Trigger an asynchronous refresh of the token if not already in progress and last refresh
+ * succeeded.
+ */
+ private synchronized void triggerAsyncRefresh() {
+ // Check token state again inside the synchronized block to avoid triggering a refresh if
+ // another thread updated the token in the meantime.
+ if (!refreshInProgress && lastRefreshSucceeded && getTokenState(token) != TokenState.FRESH) {
+ refreshInProgress = true;
+ CompletableFuture.runAsync(
+ () -> {
+ try {
+ // Attempt to refresh the token in the background
+ Token newToken = refresh();
+ synchronized (this) {
+ token = newToken;
+ refreshInProgress = false;
+ }
+ } catch (Exception e) {
+ synchronized (this) {
+ lastRefreshSucceeded = false;
+ refreshInProgress = false;
+ logger.error("Asynchronous token refresh failed", e);
+ }
+ }
+ });
+ }
+ }
+
/**
* Helper method implementing OAuth token refresh.
*
+ * @param hc The HTTP client to use for the request.
* @param clientId The client ID to authenticate with.
* @param clientSecret The client secret to authenticate with.
* @param tokenUrl The authorization URL for fetching tokens.
@@ -35,6 +251,8 @@ public RefreshableTokenSource(Token token) {
* @param headers Additional headers.
* @param position The position of the authentication parameters in the request.
* @return The newly fetched Token.
+ * @throws DatabricksException if the refresh fails
+ * @throws IllegalArgumentException if the OAuth response contains an error
*/
protected static Token retrieveToken(
HttpClient hc,
@@ -75,13 +293,4 @@ protected static Token retrieveToken(
throw new DatabricksException("Failed to refresh credentials: " + e.getMessage(), e);
}
}
-
- protected abstract Token refresh();
-
- public synchronized Token getToken() {
- if (token == null || !token.isValid()) {
- token = refresh();
- }
- return token;
- }
}
diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java
index ac6fbc3ac..4a3b42a7e 100644
--- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java
+++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java
@@ -1,7 +1,5 @@
package com.databricks.sdk.core.oauth;
-import com.databricks.sdk.core.utils.ClockSupplier;
-import com.databricks.sdk.core.utils.SystemClockSupplier;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.time.Instant;
@@ -23,16 +21,9 @@ public class Token {
*/
@JsonProperty private Instant expiry;
- private final ClockSupplier clockSupplier;
-
/** Constructor for non-refreshable tokens (e.g. M2M). */
public Token(String accessToken, String tokenType, Instant expiry) {
- this(accessToken, tokenType, null, expiry, new SystemClockSupplier());
- }
-
- /** Constructor for non-refreshable tokens (e.g. M2M) with ClockSupplier */
- public Token(String accessToken, String tokenType, Instant expiry, ClockSupplier clockSupplier) {
- this(accessToken, tokenType, null, expiry, clockSupplier);
+ this(accessToken, tokenType, null, expiry);
}
/** Constructor for refreshable tokens. */
@@ -42,51 +33,48 @@ public Token(
@JsonProperty("tokenType") String tokenType,
@JsonProperty("refreshToken") String refreshToken,
@JsonProperty("expiry") Instant expiry) {
- this(accessToken, tokenType, refreshToken, expiry, new SystemClockSupplier());
- }
-
- /** Constructor for refreshable tokens with ClockSupplier. */
- public Token(
- String accessToken,
- String tokenType,
- String refreshToken,
- Instant expiry,
- ClockSupplier clockSupplier) {
Objects.requireNonNull(accessToken, "accessToken must be defined");
Objects.requireNonNull(tokenType, "tokenType must be defined");
Objects.requireNonNull(expiry, "expiry must be defined");
- Objects.requireNonNull(clockSupplier, "clockSupplier must be defined");
this.accessToken = accessToken;
this.tokenType = tokenType;
this.refreshToken = refreshToken;
this.expiry = expiry;
- this.clockSupplier = clockSupplier;
- }
-
- public boolean isExpired() {
- if (expiry == null) {
- return false;
- }
- // Azure Databricks rejects tokens that expire in 30 seconds or less,
- // so we refresh the token 40 seconds before it expires.
- Instant potentiallyExpired = expiry.minusSeconds(40);
- Instant now = Instant.now(clockSupplier.getClock());
- return potentiallyExpired.isBefore(now);
- }
-
- public boolean isValid() {
- return accessToken != null && !isExpired();
}
+ /**
+ * Returns the type of the token (e.g., "Bearer").
+ *
+ * @return the token type
+ */
public String getTokenType() {
return tokenType;
}
+ /**
+ * Returns the refresh token, if available. May be null for non-refreshable tokens.
+ *
+ * @return the refresh token or null
+ */
public String getRefreshToken() {
return refreshToken;
}
+ /**
+ * Returns the access token string.
+ *
+ * @return the access token
+ */
public String getAccessToken() {
return accessToken;
}
+
+ /**
+ * Returns the expiry time of the token as a Instant.
+ *
+ * @return the expiry time
+ */
+ public Instant getExpiry() {
+ return this.expiry;
+ }
}
diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/SystemClockSupplier.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/UtcClockSupplier.java
similarity index 70%
rename from databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/SystemClockSupplier.java
rename to databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/UtcClockSupplier.java
index c79e6884d..73e6d2bc4 100644
--- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/SystemClockSupplier.java
+++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/UtcClockSupplier.java
@@ -2,7 +2,7 @@
import java.time.Clock;
-public class SystemClockSupplier implements ClockSupplier {
+public class UtcClockSupplier implements ClockSupplier {
@Override
public Clock getClock() {
return Clock.systemUTC();
diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java
index aa98730cb..05ffd805d 100644
--- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java
+++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java
@@ -154,7 +154,7 @@ public void testRefreshWithExpiry(
Token token = tokenSource.refresh();
assertEquals("Bearer", token.getTokenType());
assertEquals("test-token", token.getAccessToken());
- assertEquals(shouldBeExpired, token.isExpired());
+ assertEquals(shouldBeExpired, token.getExpiry().isBefore(Instant.now()));
}
}
}
diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java
index 1fb96a559..bac591212 100644
--- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java
+++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java
@@ -194,7 +194,7 @@ void testDataPlaneTokenSource(
assertEquals(expectedToken.getAccessToken(), token.getAccessToken());
assertEquals(expectedToken.getTokenType(), token.getTokenType());
assertEquals(expectedToken.getRefreshToken(), token.getRefreshToken());
- assertTrue(token.isValid());
+ assertTrue(expectedToken.getExpiry().isAfter(Instant.now()));
}
}
diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java
index 8217179f2..ee226cd42 100644
--- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java
+++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java
@@ -330,7 +330,6 @@ void testTokenSource(TestCase testCase) {
assertEquals(TOKEN, token.getAccessToken());
assertEquals(TOKEN_TYPE, token.getTokenType());
assertEquals(REFRESH_TOKEN, token.getRefreshToken());
- assertFalse(token.isExpired());
// Verify correct audience was used
verify(testCase.idTokenSource, atLeastOnce()).getIDToken(testCase.expectedAudience);
diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java
index 303f6de66..710fbf1b8 100644
--- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java
+++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java
@@ -54,24 +54,7 @@ void testSaveAndLoadToken() {
assertEquals("access-token", loadedToken.getAccessToken());
assertEquals("Bearer", loadedToken.getTokenType());
assertEquals("refresh-token", loadedToken.getRefreshToken());
- assertFalse(loadedToken.isExpired(), "Token should not be expired");
- }
-
- @Test
- void testTokenExpiry() {
- // Create an expired token
- Instant pastTime = Instant.now().minusSeconds(3600);
- Token expiredToken = new Token("access-token", "Bearer", "refresh-token", pastTime);
-
- // Verify it's marked as expired
- assertTrue(expiredToken.isExpired(), "Token should be expired");
-
- // Create a valid token
- Instant futureTime = Instant.now().plusSeconds(1800);
- Token validToken = new Token("access-token", "Bearer", "refresh-token", futureTime);
-
- // Verify it's not marked as expired
- assertFalse(validToken.isExpired(), "Token should not be expired");
+ assertEquals(expiry, loadedToken.getExpiry());
}
@Test
diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java
new file mode 100644
index 000000000..194c3a2ec
--- /dev/null
+++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java
@@ -0,0 +1,158 @@
+package com.databricks.sdk.core.oauth;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import com.databricks.sdk.core.utils.TestClockSupplier;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Stream;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+public class RefreshableTokenSourceTest {
+ private static final String TOKEN_TYPE = "Bearer";
+ private static final String INITIAL_TOKEN = "initial-token";
+ private static final String REFRESH_TOKEN = "refreshed-token";
+ private static final long FRESH_MINUTES = 10;
+ private static final long STALE_MINUTES = 1;
+ private static final long EXPIRED_MINUTES = -1;
+
+ private static Stream