diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 00496463..25609efe 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -19,6 +19,14 @@ import HiveDriverError from './errors/HiveDriverError'; import { buildUserAgentString, definedOrError } from './utils'; import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication'; import DatabricksOAuth, { OAuthFlow } from './connection/auth/DatabricksOAuth'; +import { + TokenProviderAuthenticator, + StaticTokenProvider, + ExternalTokenProvider, + CachedTokenProvider, + FederationProvider, + ITokenProvider, +} from './connection/auth/tokenProvider'; import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger'; import DBSQLLogger from './DBSQLLogger'; import CloseableCollection from './utils/CloseableCollection'; @@ -143,10 +151,63 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I }); case 'custom': return options.provider; + case 'token-provider': + return new TokenProviderAuthenticator( + this.wrapTokenProvider( + options.tokenProvider, + options.host, + options.enableTokenFederation, + options.federationClientId, + ), + this, + ); + case 'external-token': + return new TokenProviderAuthenticator( + this.wrapTokenProvider( + new ExternalTokenProvider(options.getToken), + options.host, + options.enableTokenFederation, + options.federationClientId, + ), + this, + ); + case 'static-token': + return new TokenProviderAuthenticator( + this.wrapTokenProvider( + StaticTokenProvider.fromJWT(options.staticToken), + options.host, + options.enableTokenFederation, + options.federationClientId, + ), + this, + ); // no default } } + /** + * Wraps a token provider with caching and optional federation. + * Caching is always enabled by default. Federation is opt-in. + */ + private wrapTokenProvider( + provider: ITokenProvider, + host: string, + enableFederation?: boolean, + federationClientId?: string, + ): ITokenProvider { + // Always wrap with caching first + let wrapped: ITokenProvider = new CachedTokenProvider(provider); + + // Optionally wrap with federation + if (enableFederation) { + wrapped = new FederationProvider(wrapped, host, { + clientId: federationClientId, + }); + } + + return wrapped; + } + private createConnectionProvider(options: ConnectionOptions): IConnectionProvider { return new HttpConnection(this.getConnectionOptions(options), this); } diff --git a/lib/connection/auth/tokenProvider/CachedTokenProvider.ts b/lib/connection/auth/tokenProvider/CachedTokenProvider.ts new file mode 100644 index 00000000..7172ea0b --- /dev/null +++ b/lib/connection/auth/tokenProvider/CachedTokenProvider.ts @@ -0,0 +1,98 @@ +import ITokenProvider from './ITokenProvider'; +import Token from './Token'; + +/** + * Default refresh threshold in milliseconds (5 minutes). + * Tokens will be refreshed when they are within this threshold of expiring. + */ +const DEFAULT_REFRESH_THRESHOLD_MS = 5 * 60 * 1000; + +/** + * A token provider that wraps another provider with automatic caching. + * Tokens are cached and reused until they are close to expiring. + */ +export default class CachedTokenProvider implements ITokenProvider { + private readonly baseProvider: ITokenProvider; + + private readonly refreshThresholdMs: number; + + private cache: Token | null = null; + + private refreshPromise: Promise | null = null; + + /** + * Creates a new CachedTokenProvider. + * @param baseProvider - The underlying token provider to cache + * @param options - Optional configuration + * @param options.refreshThresholdMs - Refresh tokens this many ms before expiry (default: 5 minutes) + */ + constructor( + baseProvider: ITokenProvider, + options?: { + refreshThresholdMs?: number; + }, + ) { + this.baseProvider = baseProvider; + this.refreshThresholdMs = options?.refreshThresholdMs ?? DEFAULT_REFRESH_THRESHOLD_MS; + } + + async getToken(): Promise { + // Return cached token if it's still valid + if (this.cache && !this.shouldRefresh(this.cache)) { + return this.cache; + } + + // If already refreshing, wait for that to complete + if (this.refreshPromise) { + return this.refreshPromise; + } + + // Start refresh + this.refreshPromise = this.refreshToken(); + + try { + const token = await this.refreshPromise; + return token; + } finally { + this.refreshPromise = null; + } + } + + getName(): string { + return `cached[${this.baseProvider.getName()}]`; + } + + /** + * Clears the cached token, forcing a refresh on the next getToken() call. + */ + clearCache(): void { + this.cache = null; + } + + /** + * Determines if the token should be refreshed. + * @param token - The token to check + * @returns true if the token should be refreshed + */ + private shouldRefresh(token: Token): boolean { + // If no expiration is known, don't refresh proactively + if (!token.expiresAt) { + return false; + } + + const now = Date.now(); + const expiresAtMs = token.expiresAt.getTime(); + const refreshAtMs = expiresAtMs - this.refreshThresholdMs; + + return now >= refreshAtMs; + } + + /** + * Fetches a new token from the base provider and caches it. + */ + private async refreshToken(): Promise { + const token = await this.baseProvider.getToken(); + this.cache = token; + return token; + } +} diff --git a/lib/connection/auth/tokenProvider/ExternalTokenProvider.ts b/lib/connection/auth/tokenProvider/ExternalTokenProvider.ts new file mode 100644 index 00000000..ada48038 --- /dev/null +++ b/lib/connection/auth/tokenProvider/ExternalTokenProvider.ts @@ -0,0 +1,52 @@ +import ITokenProvider from './ITokenProvider'; +import Token from './Token'; + +/** + * Type for the callback function that retrieves tokens from external sources. + */ +export type TokenCallback = () => Promise; + +/** + * A token provider that delegates token retrieval to an external callback function. + * Useful for integrating with secret managers, vaults, or other token sources. + */ +export default class ExternalTokenProvider implements ITokenProvider { + private readonly getTokenCallback: TokenCallback; + + private readonly parseJWT: boolean; + + private readonly providerName: string; + + /** + * Creates a new ExternalTokenProvider. + * @param getToken - Callback function that returns the access token string + * @param options - Optional configuration + * @param options.parseJWT - If true, attempt to extract expiration from JWT payload (default: true) + * @param options.name - Custom name for this provider (default: "ExternalTokenProvider") + */ + constructor( + getToken: TokenCallback, + options?: { + parseJWT?: boolean; + name?: string; + }, + ) { + this.getTokenCallback = getToken; + this.parseJWT = options?.parseJWT ?? true; + this.providerName = options?.name ?? 'ExternalTokenProvider'; + } + + async getToken(): Promise { + const accessToken = await this.getTokenCallback(); + + if (this.parseJWT) { + return Token.fromJWT(accessToken); + } + + return new Token(accessToken); + } + + getName(): string { + return this.providerName; + } +} diff --git a/lib/connection/auth/tokenProvider/FederationProvider.ts b/lib/connection/auth/tokenProvider/FederationProvider.ts new file mode 100644 index 00000000..e95b415e --- /dev/null +++ b/lib/connection/auth/tokenProvider/FederationProvider.ts @@ -0,0 +1,192 @@ +import fetch from 'node-fetch'; +import ITokenProvider from './ITokenProvider'; +import Token from './Token'; +import { getJWTIssuer, isSameHost } from './utils'; + +/** + * Token exchange endpoint path for Databricks OIDC. + */ +const TOKEN_EXCHANGE_ENDPOINT = '/oidc/v1/token'; + +/** + * Grant type for RFC 8693 token exchange. + */ +const TOKEN_EXCHANGE_GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:token-exchange'; + +/** + * Subject token type for JWT tokens. + */ +const SUBJECT_TOKEN_TYPE = 'urn:ietf:params:oauth:token-type:jwt'; + +/** + * Default scope for SQL operations. + */ +const DEFAULT_SCOPE = 'sql'; + +/** + * Timeout for token exchange requests in milliseconds. + */ +const REQUEST_TIMEOUT_MS = 30000; + +/** + * A token provider that wraps another provider with automatic token federation. + * When the base provider returns a token from a different issuer, this provider + * exchanges it for a Databricks-compatible token using RFC 8693. + */ +export default class FederationProvider implements ITokenProvider { + private readonly baseProvider: ITokenProvider; + + private readonly databricksHost: string; + + private readonly clientId?: string; + + private readonly returnOriginalTokenOnFailure: boolean; + + /** + * Creates a new FederationProvider. + * @param baseProvider - The underlying token provider + * @param databricksHost - The Databricks workspace host URL + * @param options - Optional configuration + * @param options.clientId - Client ID for M2M/service principal federation + * @param options.returnOriginalTokenOnFailure - Return original token if exchange fails (default: true) + */ + constructor( + baseProvider: ITokenProvider, + databricksHost: string, + options?: { + clientId?: string; + returnOriginalTokenOnFailure?: boolean; + }, + ) { + this.baseProvider = baseProvider; + this.databricksHost = databricksHost; + this.clientId = options?.clientId; + this.returnOriginalTokenOnFailure = options?.returnOriginalTokenOnFailure ?? true; + } + + async getToken(): Promise { + const token = await this.baseProvider.getToken(); + + // Check if token needs exchange + if (!this.needsTokenExchange(token)) { + return token; + } + + // Attempt token exchange + try { + return await this.exchangeToken(token); + } catch (error) { + if (this.returnOriginalTokenOnFailure) { + // Fall back to original token + return token; + } + throw error; + } + } + + getName(): string { + return `federated[${this.baseProvider.getName()}]`; + } + + /** + * Determines if the token needs to be exchanged. + * @param token - The token to check + * @returns true if the token should be exchanged + */ + private needsTokenExchange(token: Token): boolean { + const issuer = getJWTIssuer(token.accessToken); + + // If we can't extract the issuer, don't exchange (might not be a JWT) + if (!issuer) { + return false; + } + + // If the issuer is the same as Databricks host, no exchange needed + if (isSameHost(issuer, this.databricksHost)) { + return false; + } + + return true; + } + + /** + * Exchanges the token for a Databricks-compatible token using RFC 8693. + * @param token - The token to exchange + * @returns The exchanged token + */ + private async exchangeToken(token: Token): Promise { + const url = this.buildExchangeUrl(); + + const params = new URLSearchParams({ + grant_type: TOKEN_EXCHANGE_GRANT_TYPE, + subject_token_type: SUBJECT_TOKEN_TYPE, + subject_token: token.accessToken, + scope: DEFAULT_SCOPE, + }); + + if (this.clientId) { + params.append('client_id', this.clientId); + } + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), REQUEST_TIMEOUT_MS); + + try { + const response = await fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: params.toString(), + signal: controller.signal, + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error(`Token exchange failed: ${response.status} ${response.statusText} - ${errorText}`); + } + + const data = (await response.json()) as { + access_token?: string; + token_type?: string; + expires_in?: number; + }; + + if (!data.access_token) { + throw new Error('Token exchange response missing access_token'); + } + + // Calculate expiration from expires_in + let expiresAt: Date | undefined; + if (typeof data.expires_in === 'number') { + expiresAt = new Date(Date.now() + data.expires_in * 1000); + } + + return new Token(data.access_token, { + tokenType: data.token_type ?? 'Bearer', + expiresAt, + }); + } finally { + clearTimeout(timeoutId); + } + } + + /** + * Builds the token exchange URL. + */ + private buildExchangeUrl(): string { + let host = this.databricksHost; + + // Ensure host has a protocol + if (!host.includes('://')) { + host = `https://${host}`; + } + + // Remove trailing slash + if (host.endsWith('/')) { + host = host.slice(0, -1); + } + + return `${host}${TOKEN_EXCHANGE_ENDPOINT}`; + } +} diff --git a/lib/connection/auth/tokenProvider/ITokenProvider.ts b/lib/connection/auth/tokenProvider/ITokenProvider.ts new file mode 100644 index 00000000..a7cd23dc --- /dev/null +++ b/lib/connection/auth/tokenProvider/ITokenProvider.ts @@ -0,0 +1,19 @@ +import Token from './Token'; + +/** + * Interface for token providers that supply access tokens for authentication. + * Token providers can be wrapped with caching and federation decorators. + */ +export default interface ITokenProvider { + /** + * Retrieves an access token for authentication. + * @returns A Promise that resolves to a Token object containing the access token + */ + getToken(): Promise; + + /** + * Returns the name of this token provider for logging and debugging purposes. + * @returns The provider name + */ + getName(): string; +} diff --git a/lib/connection/auth/tokenProvider/StaticTokenProvider.ts b/lib/connection/auth/tokenProvider/StaticTokenProvider.ts new file mode 100644 index 00000000..0a4acead --- /dev/null +++ b/lib/connection/auth/tokenProvider/StaticTokenProvider.ts @@ -0,0 +1,58 @@ +import ITokenProvider from './ITokenProvider'; +import Token from './Token'; + +/** + * A token provider that returns a static token. + * Useful for testing or when the token is obtained through external means. + */ +export default class StaticTokenProvider implements ITokenProvider { + private readonly token: Token; + + /** + * Creates a new StaticTokenProvider. + * @param accessToken - The access token string + * @param options - Optional token configuration (tokenType, expiresAt, refreshToken, scopes) + */ + constructor( + accessToken: string, + options?: { + tokenType?: string; + expiresAt?: Date; + refreshToken?: string; + scopes?: string[]; + }, + ) { + this.token = new Token(accessToken, options); + } + + /** + * Creates a StaticTokenProvider from a JWT string. + * The expiration time will be extracted from the JWT payload. + * @param jwt - The JWT token string + * @param options - Optional token configuration + */ + static fromJWT( + jwt: string, + options?: { + tokenType?: string; + refreshToken?: string; + scopes?: string[]; + }, + ): StaticTokenProvider { + const token = Token.fromJWT(jwt, options); + return new StaticTokenProvider(token.accessToken, { + tokenType: token.tokenType, + expiresAt: token.expiresAt, + refreshToken: token.refreshToken, + scopes: token.scopes, + }); + } + + async getToken(): Promise { + return this.token; + } + + getName(): string { + return 'StaticTokenProvider'; + } +} diff --git a/lib/connection/auth/tokenProvider/Token.ts b/lib/connection/auth/tokenProvider/Token.ts new file mode 100644 index 00000000..911b2bdd --- /dev/null +++ b/lib/connection/auth/tokenProvider/Token.ts @@ -0,0 +1,151 @@ +import { HeadersInit } from 'node-fetch'; + +/** + * Safety buffer in seconds to consider a token expired before its actual expiration time. + * This prevents using tokens that are about to expire during in-flight requests. + */ +const EXPIRATION_BUFFER_SECONDS = 30; + +/** + * Represents an access token with optional metadata and lifecycle management. + */ +export default class Token { + private readonly _accessToken: string; + + private readonly _tokenType: string; + + private readonly _expiresAt?: Date; + + private readonly _refreshToken?: string; + + private readonly _scopes?: string[]; + + constructor( + accessToken: string, + options?: { + tokenType?: string; + expiresAt?: Date; + refreshToken?: string; + scopes?: string[]; + }, + ) { + this._accessToken = accessToken; + this._tokenType = options?.tokenType ?? 'Bearer'; + this._expiresAt = options?.expiresAt; + this._refreshToken = options?.refreshToken; + this._scopes = options?.scopes; + } + + /** + * The access token string. + */ + get accessToken(): string { + return this._accessToken; + } + + /** + * The token type (e.g., "Bearer"). + */ + get tokenType(): string { + return this._tokenType; + } + + /** + * The expiration time of the token, if known. + */ + get expiresAt(): Date | undefined { + return this._expiresAt; + } + + /** + * The refresh token, if available. + */ + get refreshToken(): string | undefined { + return this._refreshToken; + } + + /** + * The scopes associated with this token. + */ + get scopes(): string[] | undefined { + return this._scopes; + } + + /** + * Checks if the token has expired, including a safety buffer. + * Returns false if expiration time is unknown. + */ + isExpired(): boolean { + if (!this._expiresAt) { + return false; + } + const now = new Date(); + const bufferMs = EXPIRATION_BUFFER_SECONDS * 1000; + return this._expiresAt.getTime() - bufferMs <= now.getTime(); + } + + /** + * Sets the Authorization header on the provided headers object. + * @param headers - The headers object to modify + * @returns The modified headers object with Authorization set + */ + setAuthHeader(headers: HeadersInit): HeadersInit { + return { + ...headers, + Authorization: `${this._tokenType} ${this._accessToken}`, + }; + } + + /** + * Creates a Token from a JWT string, extracting the expiration time from the payload. + * If the JWT cannot be decoded, the token is created without expiration info. + * The server will validate the token anyway, so decoding failures are handled gracefully. + * @param jwt - The JWT token string + * @param options - Additional token options (tokenType, refreshToken, scopes) + * @returns A new Token instance with expiration extracted from the JWT (if available) + */ + static fromJWT( + jwt: string, + options?: { + tokenType?: string; + refreshToken?: string; + scopes?: string[]; + }, + ): Token { + let expiresAt: Date | undefined; + + try { + const parts = jwt.split('.'); + if (parts.length >= 2) { + const payload = Buffer.from(parts[1], 'base64').toString('utf8'); + const decoded = JSON.parse(payload); + if (typeof decoded.exp === 'number') { + expiresAt = new Date(decoded.exp * 1000); + } + } + } catch { + // If we can't decode the JWT, we'll proceed without expiration info + // The server will validate the token anyway + } + + return new Token(jwt, { + tokenType: options?.tokenType, + expiresAt, + refreshToken: options?.refreshToken, + scopes: options?.scopes, + }); + } + + /** + * Converts the token to a plain object for serialization. + */ + toJSON(): Record { + return { + accessToken: this._accessToken, + tokenType: this._tokenType, + expiresAt: this._expiresAt?.toISOString(), + refreshToken: this._refreshToken, + scopes: this._scopes, + }; + } +} diff --git a/lib/connection/auth/tokenProvider/TokenProviderAuthenticator.ts b/lib/connection/auth/tokenProvider/TokenProviderAuthenticator.ts new file mode 100644 index 00000000..2c77127b --- /dev/null +++ b/lib/connection/auth/tokenProvider/TokenProviderAuthenticator.ts @@ -0,0 +1,44 @@ +import { HeadersInit } from 'node-fetch'; +import IAuthentication from '../../contracts/IAuthentication'; +import ITokenProvider from './ITokenProvider'; +import IClientContext from '../../../contracts/IClientContext'; +import { LogLevel } from '../../../contracts/IDBSQLLogger'; + +/** + * Adapts an ITokenProvider to the IAuthentication interface used by the driver. + * This allows token providers to be used with the existing authentication system. + */ +export default class TokenProviderAuthenticator implements IAuthentication { + private readonly tokenProvider: ITokenProvider; + + private readonly context: IClientContext; + + private readonly headers: HeadersInit; + + /** + * Creates a new TokenProviderAuthenticator. + * @param tokenProvider - The token provider to use for authentication + * @param context - The client context for logging + * @param headers - Additional headers to include with each request + */ + constructor(tokenProvider: ITokenProvider, context: IClientContext, headers?: HeadersInit) { + this.tokenProvider = tokenProvider; + this.context = context; + this.headers = headers ?? {}; + } + + async authenticate(): Promise { + const logger = this.context.getLogger(); + const providerName = this.tokenProvider.getName(); + + logger.log(LogLevel.debug, `TokenProviderAuthenticator: getting token from ${providerName}`); + + const token = await this.tokenProvider.getToken(); + + if (token.isExpired()) { + logger.log(LogLevel.warn, `TokenProviderAuthenticator: token from ${providerName} is expired`); + } + + return token.setAuthHeader(this.headers); + } +} diff --git a/lib/connection/auth/tokenProvider/index.ts b/lib/connection/auth/tokenProvider/index.ts new file mode 100644 index 00000000..e09db00f --- /dev/null +++ b/lib/connection/auth/tokenProvider/index.ts @@ -0,0 +1,8 @@ +export { default as ITokenProvider } from './ITokenProvider'; +export { default as Token } from './Token'; +export { default as StaticTokenProvider } from './StaticTokenProvider'; +export { default as ExternalTokenProvider, TokenCallback } from './ExternalTokenProvider'; +export { default as TokenProviderAuthenticator } from './TokenProviderAuthenticator'; +export { default as CachedTokenProvider } from './CachedTokenProvider'; +export { default as FederationProvider } from './FederationProvider'; +export { decodeJWT, getJWTIssuer, isSameHost } from './utils'; diff --git a/lib/connection/auth/tokenProvider/utils.ts b/lib/connection/auth/tokenProvider/utils.ts new file mode 100644 index 00000000..cc8df0e2 --- /dev/null +++ b/lib/connection/auth/tokenProvider/utils.ts @@ -0,0 +1,79 @@ +/** + * Decodes a JWT token without verifying the signature. + * This is safe because the server will validate the token anyway. + * + * @param token - The JWT token string + * @returns The decoded payload as a record, or null if decoding fails + */ +export function decodeJWT(token: string): Record | null { + try { + const parts = token.split('.'); + if (parts.length < 2) { + return null; + } + const payload = Buffer.from(parts[1], 'base64').toString('utf8'); + return JSON.parse(payload); + } catch { + return null; + } +} + +/** + * Extracts the issuer from a JWT token. + * + * @param token - The JWT token string + * @returns The issuer string, or null if not found + */ +export function getJWTIssuer(token: string): string | null { + const payload = decodeJWT(token); + if (!payload || typeof payload.iss !== 'string') { + return null; + } + return payload.iss; +} + +/** + * Extracts the hostname from a URL or hostname string. + * Handles both full URLs and bare hostnames. + * + * @param urlOrHostname - A URL or hostname string + * @returns The extracted hostname + */ +function extractHostname(urlOrHostname: string): string { + // If it looks like a URL, parse it + if (urlOrHostname.includes('://')) { + const url = new URL(urlOrHostname); + return url.hostname; + } + + // Handle hostname with port (e.g., "databricks.com:443") + const colonIndex = urlOrHostname.indexOf(':'); + if (colonIndex !== -1) { + return urlOrHostname.substring(0, colonIndex); + } + + // Bare hostname + return urlOrHostname; +} + +/** + * Compares two host URLs, ignoring ports. + * Treats "databricks.com" and "databricks.com:443" as equivalent. + * + * @param url1 - First URL or hostname + * @param url2 - Second URL or hostname + * @returns true if the hosts are the same + */ +export function isSameHost(url1: string, url2: string): boolean { + try { + const host1 = extractHostname(url1); + const host2 = extractHostname(url2); + // Empty hostnames are not valid + if (!host1 || !host2) { + return false; + } + return host1.toLowerCase() === host2.toLowerCase(); + } catch { + return false; + } +} diff --git a/lib/contracts/IDBSQLClient.ts b/lib/contracts/IDBSQLClient.ts index 26588031..4b2f39a4 100644 --- a/lib/contracts/IDBSQLClient.ts +++ b/lib/contracts/IDBSQLClient.ts @@ -3,6 +3,8 @@ import IDBSQLSession from './IDBSQLSession'; import IAuthentication from '../connection/contracts/IAuthentication'; import { ProxyOptions } from '../connection/contracts/IConnectionOptions'; import OAuthPersistence from '../connection/auth/DatabricksOAuth/OAuthPersistence'; +import ITokenProvider from '../connection/auth/tokenProvider/ITokenProvider'; +import { TokenCallback } from '../connection/auth/tokenProvider/ExternalTokenProvider'; export interface ClientOptions { logger?: IDBSQLLogger; @@ -24,6 +26,24 @@ type AuthOptions = | { authType: 'custom'; provider: IAuthentication; + } + | { + authType: 'token-provider'; + tokenProvider: ITokenProvider; + enableTokenFederation?: boolean; + federationClientId?: string; + } + | { + authType: 'external-token'; + getToken: TokenCallback; + enableTokenFederation?: boolean; + federationClientId?: string; + } + | { + authType: 'static-token'; + staticToken: string; + enableTokenFederation?: boolean; + federationClientId?: string; }; export type ConnectionOptions = { diff --git a/tests/unit/connection/auth/tokenProvider/CachedTokenProvider.test.ts b/tests/unit/connection/auth/tokenProvider/CachedTokenProvider.test.ts new file mode 100644 index 00000000..5c62a89a --- /dev/null +++ b/tests/unit/connection/auth/tokenProvider/CachedTokenProvider.test.ts @@ -0,0 +1,165 @@ +import { expect } from 'chai'; +import sinon from 'sinon'; +import CachedTokenProvider from '../../../../../lib/connection/auth/tokenProvider/CachedTokenProvider'; +import ITokenProvider from '../../../../../lib/connection/auth/tokenProvider/ITokenProvider'; +import Token from '../../../../../lib/connection/auth/tokenProvider/Token'; + +class MockTokenProvider implements ITokenProvider { + public callCount = 0; + public tokenToReturn: Token; + + constructor(expiresInMs: number = 3600000) { + this.tokenToReturn = new Token(`token-${this.callCount}`, { + expiresAt: new Date(Date.now() + expiresInMs), + }); + } + + async getToken(): Promise { + this.callCount += 1; + this.tokenToReturn = new Token(`token-${this.callCount}`, { + expiresAt: this.tokenToReturn.expiresAt, + }); + return this.tokenToReturn; + } + + getName(): string { + return 'MockTokenProvider'; + } +} + +describe('CachedTokenProvider', () => { + let clock: sinon.SinonFakeTimers; + + beforeEach(() => { + clock = sinon.useFakeTimers(Date.now()); + }); + + afterEach(() => { + clock.restore(); + }); + + describe('getToken', () => { + it('should cache tokens and return the same token on subsequent calls', async () => { + const baseProvider = new MockTokenProvider(3600000); // 1 hour expiry + const cachedProvider = new CachedTokenProvider(baseProvider); + + const token1 = await cachedProvider.getToken(); + const token2 = await cachedProvider.getToken(); + const token3 = await cachedProvider.getToken(); + + expect(token1.accessToken).to.equal(token2.accessToken); + expect(token2.accessToken).to.equal(token3.accessToken); + expect(baseProvider.callCount).to.equal(1); // Only called once + }); + + it('should refresh token when it approaches expiry', async () => { + const expiresInMs = 10 * 60 * 1000; // 10 minutes + const baseProvider = new MockTokenProvider(expiresInMs); + const cachedProvider = new CachedTokenProvider(baseProvider, { + refreshThresholdMs: 5 * 60 * 1000, // 5 minutes threshold + }); + + const token1 = await cachedProvider.getToken(); + expect(baseProvider.callCount).to.equal(1); + + // Advance time to 6 minutes from now (within refresh threshold) + clock.tick(6 * 60 * 1000); + + const token2 = await cachedProvider.getToken(); + expect(baseProvider.callCount).to.equal(2); // Should have refreshed + expect(token1.accessToken).to.not.equal(token2.accessToken); + }); + + it('should not refresh token when not within threshold', async () => { + const expiresInMs = 60 * 60 * 1000; // 1 hour + const baseProvider = new MockTokenProvider(expiresInMs); + const cachedProvider = new CachedTokenProvider(baseProvider, { + refreshThresholdMs: 5 * 60 * 1000, // 5 minutes threshold + }); + + await cachedProvider.getToken(); + expect(baseProvider.callCount).to.equal(1); + + // Advance time by 10 minutes (still 50 minutes until expiry) + clock.tick(10 * 60 * 1000); + + await cachedProvider.getToken(); + expect(baseProvider.callCount).to.equal(1); // Should still use cached + }); + + it('should handle tokens without expiration', async () => { + const baseProvider: ITokenProvider = { + async getToken() { + return new Token('no-expiry-token'); + }, + getName() { + return 'NoExpiryProvider'; + }, + }; + const getTokenSpy = sinon.spy(baseProvider, 'getToken'); + const cachedProvider = new CachedTokenProvider(baseProvider); + + await cachedProvider.getToken(); + await cachedProvider.getToken(); + await cachedProvider.getToken(); + + expect(getTokenSpy.callCount).to.equal(1); // Should cache indefinitely + }); + + it('should handle concurrent getToken calls', async () => { + let resolvePromise: (token: Token) => void; + const slowProvider: ITokenProvider = { + getToken() { + return new Promise((resolve) => { + resolvePromise = resolve; + }); + }, + getName() { + return 'SlowProvider'; + }, + }; + const getTokenSpy = sinon.spy(slowProvider, 'getToken'); + const cachedProvider = new CachedTokenProvider(slowProvider); + + // Start multiple concurrent requests + const promise1 = cachedProvider.getToken(); + const promise2 = cachedProvider.getToken(); + const promise3 = cachedProvider.getToken(); + + // Resolve the single underlying request + resolvePromise!(new Token('concurrent-token')); + + const [token1, token2, token3] = await Promise.all([promise1, promise2, promise3]); + + expect(token1.accessToken).to.equal('concurrent-token'); + expect(token2.accessToken).to.equal('concurrent-token'); + expect(token3.accessToken).to.equal('concurrent-token'); + expect(getTokenSpy.callCount).to.equal(1); // Only one underlying call + }); + }); + + describe('clearCache', () => { + it('should force a refresh on the next getToken call', async () => { + const baseProvider = new MockTokenProvider(3600000); + const cachedProvider = new CachedTokenProvider(baseProvider); + + const token1 = await cachedProvider.getToken(); + expect(baseProvider.callCount).to.equal(1); + + cachedProvider.clearCache(); + + const token2 = await cachedProvider.getToken(); + expect(baseProvider.callCount).to.equal(2); + expect(token1.accessToken).to.not.equal(token2.accessToken); + }); + }); + + describe('getName', () => { + it('should return wrapped name', () => { + const baseProvider = new MockTokenProvider(); + const cachedProvider = new CachedTokenProvider(baseProvider); + + expect(cachedProvider.getName()).to.equal('cached[MockTokenProvider]'); + }); + }); +}); diff --git a/tests/unit/connection/auth/tokenProvider/ExternalTokenProvider.test.ts b/tests/unit/connection/auth/tokenProvider/ExternalTokenProvider.test.ts new file mode 100644 index 00000000..6695040d --- /dev/null +++ b/tests/unit/connection/auth/tokenProvider/ExternalTokenProvider.test.ts @@ -0,0 +1,108 @@ +import { expect } from 'chai'; +import sinon from 'sinon'; +import ExternalTokenProvider from '../../../../../lib/connection/auth/tokenProvider/ExternalTokenProvider'; + +function createJWT(payload: Record): string { + const header = Buffer.from(JSON.stringify({ alg: 'HS256', typ: 'JWT' })).toString('base64'); + const body = Buffer.from(JSON.stringify(payload)).toString('base64'); + return `${header}.${body}.signature`; +} + +describe('ExternalTokenProvider', () => { + describe('constructor', () => { + it('should create provider with callback', async () => { + const callback = sinon.stub().resolves('my-token'); + const provider = new ExternalTokenProvider(callback); + + await provider.getToken(); + + expect(callback.calledOnce).to.be.true; + }); + + it('should use default name', () => { + const provider = new ExternalTokenProvider(async () => 'token'); + expect(provider.getName()).to.equal('ExternalTokenProvider'); + }); + + it('should use custom name', () => { + const provider = new ExternalTokenProvider(async () => 'token', { name: 'MyCustomProvider' }); + expect(provider.getName()).to.equal('MyCustomProvider'); + }); + }); + + describe('getToken', () => { + it('should call callback and return token', async () => { + const callback = sinon.stub().resolves('my-access-token'); + const provider = new ExternalTokenProvider(callback); + + const token = await provider.getToken(); + + expect(token.accessToken).to.equal('my-access-token'); + expect(token.tokenType).to.equal('Bearer'); + }); + + it('should extract expiration from JWT by default', async () => { + const exp = Math.floor(Date.now() / 1000) + 3600; + const jwt = createJWT({ exp, iss: 'test-issuer' }); + const callback = sinon.stub().resolves(jwt); + const provider = new ExternalTokenProvider(callback); + + const token = await provider.getToken(); + + expect(token.accessToken).to.equal(jwt); + expect(token.expiresAt).to.be.instanceOf(Date); + expect(Math.floor(token.expiresAt!.getTime() / 1000)).to.equal(exp); + }); + + it('should not parse JWT when parseJWT is false', async () => { + const jwt = createJWT({ exp: Math.floor(Date.now() / 1000) + 3600 }); + const callback = sinon.stub().resolves(jwt); + const provider = new ExternalTokenProvider(callback, { parseJWT: false }); + + const token = await provider.getToken(); + + expect(token.accessToken).to.equal(jwt); + expect(token.expiresAt).to.be.undefined; + }); + + it('should call callback on each getToken call', async () => { + let callCount = 0; + const callback = async () => { + callCount += 1; + return `token-${callCount}`; + }; + const provider = new ExternalTokenProvider(callback); + + const token1 = await provider.getToken(); + const token2 = await provider.getToken(); + + expect(token1.accessToken).to.equal('token-1'); + expect(token2.accessToken).to.equal('token-2'); + }); + + it('should propagate errors from callback', async () => { + const error = new Error('Failed to get token'); + const callback = sinon.stub().rejects(error); + const provider = new ExternalTokenProvider(callback); + + try { + await provider.getToken(); + expect.fail('Should have thrown an error'); + } catch (e) { + expect(e).to.equal(error); + } + }); + }); + + describe('getName', () => { + it('should return default name', () => { + const provider = new ExternalTokenProvider(async () => 'token'); + expect(provider.getName()).to.equal('ExternalTokenProvider'); + }); + + it('should return custom name', () => { + const provider = new ExternalTokenProvider(async () => 'token', { name: 'VaultTokenProvider' }); + expect(provider.getName()).to.equal('VaultTokenProvider'); + }); + }); +}); diff --git a/tests/unit/connection/auth/tokenProvider/FederationProvider.test.ts b/tests/unit/connection/auth/tokenProvider/FederationProvider.test.ts new file mode 100644 index 00000000..4a7c5465 --- /dev/null +++ b/tests/unit/connection/auth/tokenProvider/FederationProvider.test.ts @@ -0,0 +1,79 @@ +import { expect } from 'chai'; +import sinon from 'sinon'; +import FederationProvider from '../../../../../lib/connection/auth/tokenProvider/FederationProvider'; +import ITokenProvider from '../../../../../lib/connection/auth/tokenProvider/ITokenProvider'; +import Token from '../../../../../lib/connection/auth/tokenProvider/Token'; + +function createJWT(payload: Record): string { + const header = Buffer.from(JSON.stringify({ alg: 'HS256', typ: 'JWT' })).toString('base64'); + const body = Buffer.from(JSON.stringify(payload)).toString('base64'); + return `${header}.${body}.signature`; +} + +class MockTokenProvider implements ITokenProvider { + public tokenToReturn: Token; + + constructor(accessToken: string) { + this.tokenToReturn = new Token(accessToken); + } + + async getToken(): Promise { + return this.tokenToReturn; + } + + getName(): string { + return 'MockTokenProvider'; + } +} + +describe('FederationProvider', () => { + describe('getToken', () => { + it('should pass through token if issuer matches Databricks host', async () => { + const jwt = createJWT({ iss: 'https://my-workspace.cloud.databricks.com' }); + const baseProvider = new MockTokenProvider(jwt); + const federationProvider = new FederationProvider(baseProvider, 'my-workspace.cloud.databricks.com'); + + const token = await federationProvider.getToken(); + + expect(token.accessToken).to.equal(jwt); + }); + + it('should pass through non-JWT tokens', async () => { + const baseProvider = new MockTokenProvider('not-a-jwt-token'); + const federationProvider = new FederationProvider(baseProvider, 'my-workspace.cloud.databricks.com'); + + const token = await federationProvider.getToken(); + + expect(token.accessToken).to.equal('not-a-jwt-token'); + }); + + it('should pass through token when issuer matches (case insensitive)', async () => { + const jwt = createJWT({ iss: 'https://MY-WORKSPACE.CLOUD.DATABRICKS.COM' }); + const baseProvider = new MockTokenProvider(jwt); + const federationProvider = new FederationProvider(baseProvider, 'my-workspace.cloud.databricks.com'); + + const token = await federationProvider.getToken(); + + expect(token.accessToken).to.equal(jwt); + }); + + it('should pass through token when issuer matches (ignoring port)', async () => { + const jwt = createJWT({ iss: 'https://my-workspace.cloud.databricks.com:443' }); + const baseProvider = new MockTokenProvider(jwt); + const federationProvider = new FederationProvider(baseProvider, 'my-workspace.cloud.databricks.com'); + + const token = await federationProvider.getToken(); + + expect(token.accessToken).to.equal(jwt); + }); + }); + + describe('getName', () => { + it('should return wrapped name', () => { + const baseProvider = new MockTokenProvider('token'); + const federationProvider = new FederationProvider(baseProvider, 'host.com'); + + expect(federationProvider.getName()).to.equal('federated[MockTokenProvider]'); + }); + }); +}); diff --git a/tests/unit/connection/auth/tokenProvider/StaticTokenProvider.test.ts b/tests/unit/connection/auth/tokenProvider/StaticTokenProvider.test.ts new file mode 100644 index 00000000..976bf84e --- /dev/null +++ b/tests/unit/connection/auth/tokenProvider/StaticTokenProvider.test.ts @@ -0,0 +1,85 @@ +import { expect } from 'chai'; +import StaticTokenProvider from '../../../../../lib/connection/auth/tokenProvider/StaticTokenProvider'; + +function createJWT(payload: Record): string { + const header = Buffer.from(JSON.stringify({ alg: 'HS256', typ: 'JWT' })).toString('base64'); + const body = Buffer.from(JSON.stringify(payload)).toString('base64'); + return `${header}.${body}.signature`; +} + +describe('StaticTokenProvider', () => { + describe('constructor', () => { + it('should create provider with access token only', async () => { + const provider = new StaticTokenProvider('my-access-token'); + const token = await provider.getToken(); + + expect(token.accessToken).to.equal('my-access-token'); + expect(token.tokenType).to.equal('Bearer'); + }); + + it('should create provider with custom options', async () => { + const expiresAt = new Date('2025-01-01T00:00:00Z'); + const provider = new StaticTokenProvider('my-access-token', { + tokenType: 'CustomType', + expiresAt, + refreshToken: 'refresh-token', + scopes: ['read', 'write'], + }); + const token = await provider.getToken(); + + expect(token.accessToken).to.equal('my-access-token'); + expect(token.tokenType).to.equal('CustomType'); + expect(token.expiresAt).to.deep.equal(expiresAt); + expect(token.refreshToken).to.equal('refresh-token'); + expect(token.scopes).to.deep.equal(['read', 'write']); + }); + }); + + describe('fromJWT', () => { + it('should create provider from JWT and extract expiration', async () => { + const exp = Math.floor(Date.now() / 1000) + 3600; + const jwt = createJWT({ exp, iss: 'test-issuer' }); + + const provider = StaticTokenProvider.fromJWT(jwt); + const token = await provider.getToken(); + + expect(token.accessToken).to.equal(jwt); + expect(token.expiresAt).to.be.instanceOf(Date); + expect(Math.floor(token.expiresAt!.getTime() / 1000)).to.equal(exp); + }); + + it('should create provider from JWT with custom options', async () => { + const jwt = createJWT({ exp: Math.floor(Date.now() / 1000) + 3600 }); + + const provider = StaticTokenProvider.fromJWT(jwt, { + tokenType: 'CustomType', + refreshToken: 'refresh', + scopes: ['sql'], + }); + const token = await provider.getToken(); + + expect(token.tokenType).to.equal('CustomType'); + expect(token.refreshToken).to.equal('refresh'); + expect(token.scopes).to.deep.equal(['sql']); + }); + }); + + describe('getToken', () => { + it('should always return the same token', async () => { + const provider = new StaticTokenProvider('my-token'); + + const token1 = await provider.getToken(); + const token2 = await provider.getToken(); + + expect(token1).to.equal(token2); + expect(token1.accessToken).to.equal('my-token'); + }); + }); + + describe('getName', () => { + it('should return provider name', () => { + const provider = new StaticTokenProvider('my-token'); + expect(provider.getName()).to.equal('StaticTokenProvider'); + }); + }); +}); diff --git a/tests/unit/connection/auth/tokenProvider/Token.test.ts b/tests/unit/connection/auth/tokenProvider/Token.test.ts new file mode 100644 index 00000000..febaf712 --- /dev/null +++ b/tests/unit/connection/auth/tokenProvider/Token.test.ts @@ -0,0 +1,162 @@ +import { expect } from 'chai'; +import Token from '../../../../../lib/connection/auth/tokenProvider/Token'; + +function createJWT(payload: Record): string { + const header = Buffer.from(JSON.stringify({ alg: 'HS256', typ: 'JWT' })).toString('base64'); + const body = Buffer.from(JSON.stringify(payload)).toString('base64'); + return `${header}.${body}.signature`; +} + +describe('Token', () => { + describe('constructor', () => { + it('should create token with access token only', () => { + const token = new Token('test-access-token'); + expect(token.accessToken).to.equal('test-access-token'); + expect(token.tokenType).to.equal('Bearer'); + expect(token.expiresAt).to.be.undefined; + expect(token.refreshToken).to.be.undefined; + expect(token.scopes).to.be.undefined; + }); + + it('should create token with all options', () => { + const expiresAt = new Date('2025-01-01T00:00:00Z'); + const token = new Token('test-access-token', { + tokenType: 'CustomType', + expiresAt, + refreshToken: 'refresh-token', + scopes: ['read', 'write'], + }); + expect(token.accessToken).to.equal('test-access-token'); + expect(token.tokenType).to.equal('CustomType'); + expect(token.expiresAt).to.deep.equal(expiresAt); + expect(token.refreshToken).to.equal('refresh-token'); + expect(token.scopes).to.deep.equal(['read', 'write']); + }); + }); + + describe('isExpired', () => { + it('should return false when expiration is not set', () => { + const token = new Token('test-token'); + expect(token.isExpired()).to.be.false; + }); + + it('should return true when token is expired', () => { + const expiresAt = new Date(Date.now() - 60000); // 1 minute ago + const token = new Token('test-token', { expiresAt }); + expect(token.isExpired()).to.be.true; + }); + + it('should return false when token is not expired', () => { + const expiresAt = new Date(Date.now() + 300000); // 5 minutes from now + const token = new Token('test-token', { expiresAt }); + expect(token.isExpired()).to.be.false; + }); + + it('should return true when within 30 second safety buffer', () => { + const expiresAt = new Date(Date.now() + 20000); // 20 seconds from now + const token = new Token('test-token', { expiresAt }); + expect(token.isExpired()).to.be.true; + }); + }); + + describe('setAuthHeader', () => { + it('should set Authorization header with default Bearer type', () => { + const token = new Token('my-token'); + const headers = token.setAuthHeader({}); + expect(headers).to.deep.equal({ Authorization: 'Bearer my-token' }); + }); + + it('should set Authorization header with custom type', () => { + const token = new Token('my-token', { tokenType: 'Basic' }); + const headers = token.setAuthHeader({}); + expect(headers).to.deep.equal({ Authorization: 'Basic my-token' }); + }); + + it('should preserve existing headers', () => { + const token = new Token('my-token'); + const headers = token.setAuthHeader({ 'Content-Type': 'application/json' }); + expect(headers).to.deep.equal({ + 'Content-Type': 'application/json', + Authorization: 'Bearer my-token', + }); + }); + }); + + describe('fromJWT', () => { + it('should extract expiration from JWT payload', () => { + const exp = Math.floor(Date.now() / 1000) + 3600; // 1 hour from now + const jwt = createJWT({ exp, iss: 'test-issuer' }); + const token = Token.fromJWT(jwt); + + expect(token.accessToken).to.equal(jwt); + expect(token.tokenType).to.equal('Bearer'); + expect(token.expiresAt).to.be.instanceOf(Date); + expect(Math.floor(token.expiresAt!.getTime() / 1000)).to.equal(exp); + }); + + it('should handle JWT without expiration', () => { + const jwt = createJWT({ iss: 'test-issuer' }); + const token = Token.fromJWT(jwt); + + expect(token.accessToken).to.equal(jwt); + expect(token.expiresAt).to.be.undefined; + }); + + it('should handle malformed JWT gracefully', () => { + const token = Token.fromJWT('not-a-valid-jwt'); + expect(token.accessToken).to.equal('not-a-valid-jwt'); + expect(token.expiresAt).to.be.undefined; + }); + + it('should handle JWT with invalid base64 payload', () => { + const token = Token.fromJWT('header.!!!invalid-base64!!!.signature'); + expect(token.accessToken).to.equal('header.!!!invalid-base64!!!.signature'); + expect(token.expiresAt).to.be.undefined; + }); + + it('should apply custom options', () => { + const jwt = createJWT({ exp: Math.floor(Date.now() / 1000) + 3600 }); + const token = Token.fromJWT(jwt, { + tokenType: 'CustomType', + refreshToken: 'refresh', + scopes: ['sql'], + }); + + expect(token.tokenType).to.equal('CustomType'); + expect(token.refreshToken).to.equal('refresh'); + expect(token.scopes).to.deep.equal(['sql']); + }); + }); + + describe('toJSON', () => { + it('should serialize token to JSON', () => { + const expiresAt = new Date('2025-01-01T00:00:00Z'); + const token = new Token('test-token', { + tokenType: 'Bearer', + expiresAt, + refreshToken: 'refresh', + scopes: ['read'], + }); + + const json = token.toJSON(); + expect(json).to.deep.equal({ + accessToken: 'test-token', + tokenType: 'Bearer', + expiresAt: '2025-01-01T00:00:00.000Z', + refreshToken: 'refresh', + scopes: ['read'], + }); + }); + + it('should handle undefined optional fields', () => { + const token = new Token('test-token'); + const json = token.toJSON(); + + expect(json.accessToken).to.equal('test-token'); + expect(json.tokenType).to.equal('Bearer'); + expect(json.expiresAt).to.be.undefined; + expect(json.refreshToken).to.be.undefined; + expect(json.scopes).to.be.undefined; + }); + }); +}); diff --git a/tests/unit/connection/auth/tokenProvider/TokenProviderAuthenticator.test.ts b/tests/unit/connection/auth/tokenProvider/TokenProviderAuthenticator.test.ts new file mode 100644 index 00000000..a5a3963e --- /dev/null +++ b/tests/unit/connection/auth/tokenProvider/TokenProviderAuthenticator.test.ts @@ -0,0 +1,108 @@ +import { expect } from 'chai'; +import sinon from 'sinon'; +import TokenProviderAuthenticator from '../../../../../lib/connection/auth/tokenProvider/TokenProviderAuthenticator'; +import ITokenProvider from '../../../../../lib/connection/auth/tokenProvider/ITokenProvider'; +import Token from '../../../../../lib/connection/auth/tokenProvider/Token'; +import ClientContextStub from '../../../.stubs/ClientContextStub'; + +class MockTokenProvider implements ITokenProvider { + private token: Token; + + private name: string; + + constructor(accessToken: string, name: string = 'MockTokenProvider') { + this.token = new Token(accessToken); + this.name = name; + } + + async getToken(): Promise { + return this.token; + } + + getName(): string { + return this.name; + } + + setToken(token: Token): void { + this.token = token; + } +} + +describe('TokenProviderAuthenticator', () => { + let context: ClientContextStub; + + beforeEach(() => { + context = new ClientContextStub(); + }); + + describe('authenticate', () => { + it('should return headers with Authorization', async () => { + const provider = new MockTokenProvider('my-access-token'); + const authenticator = new TokenProviderAuthenticator(provider, context); + + const headers = await authenticator.authenticate(); + + expect(headers).to.deep.equal({ + Authorization: 'Bearer my-access-token', + }); + }); + + it('should include additional headers', async () => { + const provider = new MockTokenProvider('my-access-token'); + const authenticator = new TokenProviderAuthenticator(provider, context, { + 'Content-Type': 'application/json', + 'X-Custom-Header': 'custom-value', + }); + + const headers = await authenticator.authenticate(); + + expect(headers).to.deep.equal({ + 'Content-Type': 'application/json', + 'X-Custom-Header': 'custom-value', + Authorization: 'Bearer my-access-token', + }); + }); + + it('should use token type from token', async () => { + const provider = new MockTokenProvider('my-access-token'); + provider.setToken(new Token('my-token', { tokenType: 'Basic' })); + const authenticator = new TokenProviderAuthenticator(provider, context); + + const headers = await authenticator.authenticate(); + + expect(headers).to.deep.equal({ + Authorization: 'Basic my-token', + }); + }); + + it('should call provider getToken', async () => { + const provider = new MockTokenProvider('my-access-token'); + const getTokenSpy = sinon.spy(provider, 'getToken'); + const authenticator = new TokenProviderAuthenticator(provider, context); + + await authenticator.authenticate(); + + expect(getTokenSpy.calledOnce).to.be.true; + }); + + it('should propagate errors from provider', async () => { + const error = new Error('Failed to get token'); + const provider: ITokenProvider = { + async getToken() { + throw error; + }, + getName() { + return 'ErrorProvider'; + }, + }; + const authenticator = new TokenProviderAuthenticator(provider, context); + + try { + await authenticator.authenticate(); + expect.fail('Should have thrown an error'); + } catch (e) { + expect(e).to.equal(error); + } + }); + }); +}); diff --git a/tests/unit/connection/auth/tokenProvider/utils.test.ts b/tests/unit/connection/auth/tokenProvider/utils.test.ts new file mode 100644 index 00000000..80a91f85 --- /dev/null +++ b/tests/unit/connection/auth/tokenProvider/utils.test.ts @@ -0,0 +1,90 @@ +import { expect } from 'chai'; +import { decodeJWT, getJWTIssuer, isSameHost } from '../../../../../lib/connection/auth/tokenProvider/utils'; + +function createJWT(payload: Record): string { + const header = Buffer.from(JSON.stringify({ alg: 'HS256', typ: 'JWT' })).toString('base64'); + const body = Buffer.from(JSON.stringify(payload)).toString('base64'); + return `${header}.${body}.signature`; +} + +describe('Token Provider Utils', () => { + describe('decodeJWT', () => { + it('should decode valid JWT payload', () => { + const payload = { iss: 'test-issuer', sub: 'user123', exp: 1234567890 }; + const jwt = createJWT(payload); + + const decoded = decodeJWT(jwt); + + expect(decoded).to.deep.equal(payload); + }); + + it('should return null for malformed JWT', () => { + expect(decodeJWT('not-a-jwt')).to.be.null; + expect(decodeJWT('')).to.be.null; + }); + + it('should return null for JWT with invalid base64 payload', () => { + expect(decodeJWT('header.!!!invalid!!!.signature')).to.be.null; + }); + + it('should return null for JWT with non-JSON payload', () => { + const header = Buffer.from('{}').toString('base64'); + const body = Buffer.from('not json').toString('base64'); + expect(decodeJWT(`${header}.${body}.sig`)).to.be.null; + }); + }); + + describe('getJWTIssuer', () => { + it('should extract issuer from JWT', () => { + const jwt = createJWT({ iss: 'https://my-issuer.com', sub: 'user' }); + expect(getJWTIssuer(jwt)).to.equal('https://my-issuer.com'); + }); + + it('should return null if no issuer claim', () => { + const jwt = createJWT({ sub: 'user' }); + expect(getJWTIssuer(jwt)).to.be.null; + }); + + it('should return null if issuer is not a string', () => { + const jwt = createJWT({ iss: 123 }); + expect(getJWTIssuer(jwt)).to.be.null; + }); + + it('should return null for invalid JWT', () => { + expect(getJWTIssuer('not-a-jwt')).to.be.null; + }); + }); + + describe('isSameHost', () => { + it('should match identical hosts', () => { + expect(isSameHost('example.com', 'example.com')).to.be.true; + }); + + it('should match hosts with different protocols', () => { + expect(isSameHost('https://example.com', 'http://example.com')).to.be.true; + }); + + it('should match hosts ignoring ports', () => { + expect(isSameHost('example.com', 'example.com:443')).to.be.true; + expect(isSameHost('https://example.com:443', 'example.com')).to.be.true; + }); + + it('should match hosts case-insensitively', () => { + expect(isSameHost('Example.COM', 'example.com')).to.be.true; + }); + + it('should not match different hosts', () => { + expect(isSameHost('example.com', 'other.com')).to.be.false; + expect(isSameHost('sub.example.com', 'example.com')).to.be.false; + }); + + it('should handle full URLs', () => { + expect(isSameHost('https://my-workspace.cloud.databricks.com/path', 'my-workspace.cloud.databricks.com')).to.be + .true; + }); + + it('should return false for invalid inputs', () => { + expect(isSameHost('', '')).to.be.false; + }); + }); +});