diff --git a/examples/tokenFederation/customTokenProvider.d.ts b/examples/tokenFederation/customTokenProvider.d.ts new file mode 100644 index 00000000..6934004f --- /dev/null +++ b/examples/tokenFederation/customTokenProvider.d.ts @@ -0,0 +1,9 @@ +/** + * Example: Custom Token Provider Implementation + * + * This example demonstrates how to create a custom token provider by + * implementing the ITokenProvider interface. This gives you full control + * over token management, including custom caching, refresh logic, and + * error handling. + */ +export {}; diff --git a/examples/tokenFederation/customTokenProvider.js b/examples/tokenFederation/customTokenProvider.js new file mode 100644 index 00000000..0fd798c8 --- /dev/null +++ b/examples/tokenFederation/customTokenProvider.js @@ -0,0 +1,144 @@ +'use strict'; +/** + * Example: Custom Token Provider Implementation + * + * This example demonstrates how to create a custom token provider by + * implementing the ITokenProvider interface. This gives you full control + * over token management, including custom caching, refresh logic, and + * error handling. + */ +var __createBinding = + (this && this.__createBinding) || + (Object.create + ? function (o, m, k, k2) { + if (k2 === undefined) k2 = k; + var desc = Object.getOwnPropertyDescriptor(m, k); + if (!desc || ('get' in desc ? !m.__esModule : desc.writable || desc.configurable)) { + desc = { + enumerable: true, + get: function () { + return m[k]; + }, + }; + } + Object.defineProperty(o, k2, desc); + } + : function (o, m, k, k2) { + if (k2 === undefined) k2 = k; + o[k2] = m[k]; + }); +var __setModuleDefault = + (this && this.__setModuleDefault) || + (Object.create + ? function (o, v) { + Object.defineProperty(o, 'default', { enumerable: true, value: v }); + } + : function (o, v) { + o['default'] = v; + }); +var __importStar = + (this && this.__importStar) || + function (mod) { + if (mod && mod.__esModule) return mod; + var result = {}; + if (mod != null) + for (var k in mod) + if (k !== 'default' && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k); + __setModuleDefault(result, mod); + return result; + }; +Object.defineProperty(exports, '__esModule', { value: true }); +const sql_1 = require('@databricks/sql'); +const tokenProvider_1 = require('../../lib/connection/auth/tokenProvider'); +/** + * Custom token provider that refreshes tokens from a custom OAuth server. + */ +class CustomOAuthTokenProvider { + constructor(oauthServerUrl, clientId, clientSecret) { + this.oauthServerUrl = oauthServerUrl; + this.clientId = clientId; + this.clientSecret = clientSecret; + } + async getToken() { + var _a; + console.log('Fetching token from custom OAuth server...'); + // Example: Fetch token using client credentials grant + const response = await fetch(`${this.oauthServerUrl}/oauth/token`, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: new URLSearchParams({ + grant_type: 'client_credentials', + client_id: this.clientId, + client_secret: this.clientSecret, + scope: 'sql', + }).toString(), + }); + if (!response.ok) { + throw new Error(`OAuth token request failed: ${response.status}`); + } + const data = await response.json(); + // Calculate expiration + let expiresAt; + if (typeof data.expires_in === 'number') { + expiresAt = new Date(Date.now() + data.expires_in * 1000); + } + return new tokenProvider_1.Token(data.access_token, { + tokenType: (_a = data.token_type) !== null && _a !== void 0 ? _a : 'Bearer', + expiresAt, + }); + } + getName() { + return 'CustomOAuthTokenProvider'; + } +} +/** + * Simple token provider that reads from a file (for development/testing). + */ +class FileTokenProvider { + constructor(filePath) { + this.filePath = filePath; + } + async getToken() { + const fs = await Promise.resolve().then(() => __importStar(require('fs/promises'))); + const tokenData = await fs.readFile(this.filePath, 'utf-8'); + const parsed = JSON.parse(tokenData); + return tokenProvider_1.Token.fromJWT(parsed.access_token, { + refreshToken: parsed.refresh_token, + }); + } + getName() { + return 'FileTokenProvider'; + } +} +async function main() { + const host = process.env.DATABRICKS_HOST; + const path = process.env.DATABRICKS_HTTP_PATH; + const client = new sql_1.DBSQLClient(); + // Option 1: Use a custom OAuth token provider + const oauthProvider = new CustomOAuthTokenProvider( + process.env.OAUTH_SERVER_URL, + process.env.OAUTH_CLIENT_ID, + process.env.OAUTH_CLIENT_SECRET, + ); + await client.connect({ + host, + path, + authType: 'token-provider', + tokenProvider: oauthProvider, + // Optionally enable federation if your OAuth server issues non-Databricks tokens + enableTokenFederation: true, + }); + console.log('Connected successfully with custom token provider'); + // Open a session and run a query + const session = await client.openSession(); + const operation = await session.executeStatement('SELECT 1 AS result'); + const result = await operation.fetchAll(); + console.log('Query result:', result); + await operation.close(); + await session.close(); + await client.close(); +} +main().catch(console.error); +//# sourceMappingURL=customTokenProvider.js.map diff --git a/examples/tokenFederation/customTokenProvider.js.map b/examples/tokenFederation/customTokenProvider.js.map new file mode 100644 index 00000000..0a822b74 --- /dev/null +++ b/examples/tokenFederation/customTokenProvider.js.map @@ -0,0 +1 @@ +{"version":3,"file":"customTokenProvider.js","sourceRoot":"","sources":["customTokenProvider.ts"],"names":[],"mappings":";AAAA;;;;;;;GAOG;;;;;;;;;;;;;;;;;;;;;;;;;AAEH,yCAA8C;AAC9C,2EAGiD;AAEjD;;GAEG;AACH,MAAM,wBAAwB;IAK5B,YAAY,cAAsB,EAAE,QAAgB,EAAE,YAAoB;QACxE,IAAI,CAAC,cAAc,GAAG,cAAc,CAAC;QACrC,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,YAAY,GAAG,YAAY,CAAC;IACnC,CAAC;IAED,KAAK,CAAC,QAAQ;;QACZ,OAAO,CAAC,GAAG,CAAC,4CAA4C,CAAC,CAAC;QAE1D,sDAAsD;QACtD,MAAM,QAAQ,GAAG,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,cAAc,cAAc,EAAE;YACjE,MAAM,EAAE,MAAM;YACd,OAAO,EAAE;gBACP,cAAc,EAAE,mCAAmC;aACpD;YACD,IAAI,EAAE,IAAI,eAAe,CAAC;gBACxB,UAAU,EAAE,oBAAoB;gBAChC,SAAS,EAAE,IAAI,CAAC,QAAQ;gBACxB,aAAa,EAAE,IAAI,CAAC,YAAY;gBAChC,KAAK,EAAE,KAAK;aACb,CAAC,CAAC,QAAQ,EAAE;SACd,CAAC,CAAC;QAEH,IAAI,CAAC,QAAQ,CAAC,EAAE,EAAE;YAChB,MAAM,IAAI,KAAK,CAAC,+BAA+B,QAAQ,CAAC,MAAM,EAAE,CAAC,CAAC;SACnE;QAED,MAAM,IAAI,GAAG,MAAM,QAAQ,CAAC,IAAI,EAI/B,CAAC;QAEF,uBAAuB;QACvB,IAAI,SAA2B,CAAC;QAChC,IAAI,OAAO,IAAI,CAAC,UAAU,KAAK,QAAQ,EAAE;YACvC,SAAS,GAAG,IAAI,IAAI,CAAC,IAAI,CAAC,GAAG,EAAE,GAAG,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,CAAC;SAC3D;QAED,OAAO,IAAI,qBAAK,CAAC,IAAI,CAAC,YAAY,EAAE;YAClC,SAAS,EAAE,MAAA,IAAI,CAAC,UAAU,mCAAI,QAAQ;YACtC,SAAS;SACV,CAAC,CAAC;IACL,CAAC;IAED,OAAO;QACL,OAAO,0BAA0B,CAAC;IACpC,CAAC;CACF;AAED;;GAEG;AACH,MAAM,iBAAiB;IAGrB,YAAY,QAAgB;QAC1B,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;IAC3B,CAAC;IAED,KAAK,CAAC,QAAQ;QACZ,MAAM,EAAE,GAAG,wDAAa,aAAa,GAAC,CAAC;QACvC,MAAM,SAAS,GAAG,MAAM,EAAE,CAAC,QAAQ,CAAC,IAAI,CAAC,QAAQ,EAAE,OAAO,CAAC,CAAC;QAC5D,MAAM,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC;QAErC,OAAO,qBAAK,CAAC,OAAO,CAAC,MAAM,CAAC,YAAY,EAAE;YACxC,YAAY,EAAE,MAAM,CAAC,aAAa;SACnC,CAAC,CAAC;IACL,CAAC;IAED,OAAO;QACL,OAAO,mBAAmB,CAAC;IAC7B,CAAC;CACF;AAED,KAAK,UAAU,IAAI;IACjB,MAAM,IAAI,GAAG,OAAO,CAAC,GAAG,CAAC,eAAgB,CAAC;IAC1C,MAAM,IAAI,GAAG,OAAO,CAAC,GAAG,CAAC,oBAAqB,CAAC;IAE/C,MAAM,MAAM,GAAG,IAAI,iBAAW,EAAE,CAAC;IAEjC,8CAA8C;IAC9C,MAAM,aAAa,GAAG,IAAI,wBAAwB,CAChD,OAAO,CAAC,GAAG,CAAC,gBAAiB,EAC7B,OAAO,CAAC,GAAG,CAAC,eAAgB,EAC5B,OAAO,CAAC,GAAG,CAAC,mBAAoB,CACjC,CAAC;IAEF,MAAM,MAAM,CAAC,OAAO,CAAC;QACnB,IAAI;QACJ,IAAI;QACJ,QAAQ,EAAE,gBAAgB;QAC1B,aAAa,EAAE,aAAa;QAC5B,iFAAiF;QACjF,qBAAqB,EAAE,IAAI;KAC5B,CAAC,CAAC;IAEH,OAAO,CAAC,GAAG,CAAC,mDAAmD,CAAC,CAAC;IAEjE,iCAAiC;IACjC,MAAM,OAAO,GAAG,MAAM,MAAM,CAAC,WAAW,EAAE,CAAC;IAC3C,MAAM,SAAS,GAAG,MAAM,OAAO,CAAC,gBAAgB,CAAC,oBAAoB,CAAC,CAAC;IACvE,MAAM,MAAM,GAAG,MAAM,SAAS,CAAC,QAAQ,EAAE,CAAC;IAE1C,OAAO,CAAC,GAAG,CAAC,eAAe,EAAE,MAAM,CAAC,CAAC;IAErC,MAAM,SAAS,CAAC,KAAK,EAAE,CAAC;IACxB,MAAM,OAAO,CAAC,KAAK,EAAE,CAAC;IACtB,MAAM,MAAM,CAAC,KAAK,EAAE,CAAC;AACvB,CAAC;AAED,IAAI,EAAE,CAAC,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC"} \ No newline at end of file diff --git a/examples/tokenFederation/customTokenProvider.ts b/examples/tokenFederation/customTokenProvider.ts new file mode 100644 index 00000000..139e5f6f --- /dev/null +++ b/examples/tokenFederation/customTokenProvider.ts @@ -0,0 +1,132 @@ +/** + * Example: Custom Token Provider Implementation + * + * This example demonstrates how to create a custom token provider by + * implementing the ITokenProvider interface. This gives you full control + * over token management, including custom caching, refresh logic, and + * error handling. + */ + +import { DBSQLClient } from '@databricks/sql'; +import { ITokenProvider, Token } from '../../lib/connection/auth/tokenProvider'; + +/** + * Custom token provider that refreshes tokens from a custom OAuth server. + */ +class CustomOAuthTokenProvider implements ITokenProvider { + private readonly oauthServerUrl: string; + private readonly clientId: string; + private readonly clientSecret: string; + + constructor(oauthServerUrl: string, clientId: string, clientSecret: string) { + this.oauthServerUrl = oauthServerUrl; + this.clientId = clientId; + this.clientSecret = clientSecret; + } + + async getToken(): Promise { + console.log('Fetching token from custom OAuth server...'); + + // Example: Fetch token using client credentials grant + const response = await fetch(`${this.oauthServerUrl}/oauth/token`, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: new URLSearchParams({ + grant_type: 'client_credentials', + client_id: this.clientId, + client_secret: this.clientSecret, + scope: 'sql', + }).toString(), + }); + + if (!response.ok) { + throw new Error(`OAuth token request failed: ${response.status}`); + } + + const data = (await response.json()) as { + access_token: string; + token_type?: string; + expires_in?: number; + }; + + // Calculate expiration + let expiresAt: Date | undefined; + if (typeof data.expires_in === 'number') { + expiresAt = new Date(Date.now() + data.expires_in * 1000); + } + + return new Token(data.access_token, { + tokenType: data.token_type ?? 'Bearer', + expiresAt, + }); + } + + getName(): string { + return 'CustomOAuthTokenProvider'; + } +} + +/** + * Simple token provider that reads from a file (for development/testing). + */ +class FileTokenProvider implements ITokenProvider { + private readonly filePath: string; + + constructor(filePath: string) { + this.filePath = filePath; + } + + async getToken(): Promise { + const fs = await import('fs/promises'); + const tokenData = await fs.readFile(this.filePath, 'utf-8'); + const parsed = JSON.parse(tokenData); + + return Token.fromJWT(parsed.access_token, { + refreshToken: parsed.refresh_token, + }); + } + + getName(): string { + return 'FileTokenProvider'; + } +} + +async function main() { + const host = process.env.DATABRICKS_HOST!; + const path = process.env.DATABRICKS_HTTP_PATH!; + + const client = new DBSQLClient(); + + // Option 1: Use a custom OAuth token provider + const oauthProvider = new CustomOAuthTokenProvider( + process.env.OAUTH_SERVER_URL!, + process.env.OAUTH_CLIENT_ID!, + process.env.OAUTH_CLIENT_SECRET!, + ); + + await client.connect({ + host, + path, + authType: 'token-provider', + tokenProvider: oauthProvider, + // Optionally enable federation if your OAuth server issues non-Databricks tokens + enableTokenFederation: true, + }); + + console.log('Connected successfully with custom token provider'); + + // Open a session and run a query + const session = await client.openSession(); + const operation = await session.executeStatement('SELECT 1 AS result'); + const result = await operation.fetchAll(); + + console.log('Query result:', result); + + await operation.close(); + await session.close(); + await client.close(); +} + +main().catch(console.error); diff --git a/examples/tokenFederation/externalToken.d.ts b/examples/tokenFederation/externalToken.d.ts new file mode 100644 index 00000000..ed8d3efb --- /dev/null +++ b/examples/tokenFederation/externalToken.d.ts @@ -0,0 +1,8 @@ +/** + * Example: Using an external token provider + * + * This example demonstrates how to use a callback function to provide + * tokens dynamically. This is useful for integrating with secret managers, + * vaults, or other token sources that may refresh tokens. + */ +export {}; diff --git a/examples/tokenFederation/externalToken.js b/examples/tokenFederation/externalToken.js new file mode 100644 index 00000000..8db53859 --- /dev/null +++ b/examples/tokenFederation/externalToken.js @@ -0,0 +1,45 @@ +'use strict'; +/** + * Example: Using an external token provider + * + * This example demonstrates how to use a callback function to provide + * tokens dynamically. This is useful for integrating with secret managers, + * vaults, or other token sources that may refresh tokens. + */ +Object.defineProperty(exports, '__esModule', { value: true }); +const sql_1 = require('@databricks/sql'); +// Simulate fetching a token from a secret manager or vault +async function fetchTokenFromVault() { + // In a real application, this would fetch from AWS Secrets Manager, + // Azure Key Vault, HashiCorp Vault, or another secret manager + console.log('Fetching token from vault...'); + // Simulated token - replace with actual vault integration + const token = process.env.DATABRICKS_TOKEN; + return token; +} +async function main() { + const host = process.env.DATABRICKS_HOST; + const path = process.env.DATABRICKS_HTTP_PATH; + const client = new sql_1.DBSQLClient(); + // Connect using an external token provider + // The callback will be called each time a new token is needed + // Note: The token is automatically cached, so the callback won't be + // called on every request + await client.connect({ + host, + path, + authType: 'external-token', + getToken: fetchTokenFromVault, + }); + console.log('Connected successfully with external token provider'); + // Open a session and run a query + const session = await client.openSession(); + const operation = await session.executeStatement('SELECT current_user() AS user'); + const result = await operation.fetchAll(); + console.log('Query result:', result); + await operation.close(); + await session.close(); + await client.close(); +} +main().catch(console.error); +//# sourceMappingURL=externalToken.js.map diff --git a/examples/tokenFederation/externalToken.js.map b/examples/tokenFederation/externalToken.js.map new file mode 100644 index 00000000..a652beb2 --- /dev/null +++ b/examples/tokenFederation/externalToken.js.map @@ -0,0 +1 @@ +{"version":3,"file":"externalToken.js","sourceRoot":"","sources":["externalToken.ts"],"names":[],"mappings":";AAAA;;;;;;GAMG;;AAEH,yCAA8C;AAE9C,2DAA2D;AAC3D,KAAK,UAAU,mBAAmB;IAChC,oEAAoE;IACpE,8DAA8D;IAC9D,OAAO,CAAC,GAAG,CAAC,8BAA8B,CAAC,CAAC;IAE5C,0DAA0D;IAC1D,MAAM,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,gBAAiB,CAAC;IAC5C,OAAO,KAAK,CAAC;AACf,CAAC;AAED,KAAK,UAAU,IAAI;IACjB,MAAM,IAAI,GAAG,OAAO,CAAC,GAAG,CAAC,eAAgB,CAAC;IAC1C,MAAM,IAAI,GAAG,OAAO,CAAC,GAAG,CAAC,oBAAqB,CAAC;IAE/C,MAAM,MAAM,GAAG,IAAI,iBAAW,EAAE,CAAC;IAEjC,2CAA2C;IAC3C,8DAA8D;IAC9D,oEAAoE;IACpE,0BAA0B;IAC1B,MAAM,MAAM,CAAC,OAAO,CAAC;QACnB,IAAI;QACJ,IAAI;QACJ,QAAQ,EAAE,gBAAgB;QAC1B,QAAQ,EAAE,mBAAmB;KAC9B,CAAC,CAAC;IAEH,OAAO,CAAC,GAAG,CAAC,qDAAqD,CAAC,CAAC;IAEnE,iCAAiC;IACjC,MAAM,OAAO,GAAG,MAAM,MAAM,CAAC,WAAW,EAAE,CAAC;IAC3C,MAAM,SAAS,GAAG,MAAM,OAAO,CAAC,gBAAgB,CAAC,+BAA+B,CAAC,CAAC;IAClF,MAAM,MAAM,GAAG,MAAM,SAAS,CAAC,QAAQ,EAAE,CAAC;IAE1C,OAAO,CAAC,GAAG,CAAC,eAAe,EAAE,MAAM,CAAC,CAAC;IAErC,MAAM,SAAS,CAAC,KAAK,EAAE,CAAC;IACxB,MAAM,OAAO,CAAC,KAAK,EAAE,CAAC;IACtB,MAAM,MAAM,CAAC,KAAK,EAAE,CAAC;AACvB,CAAC;AAED,IAAI,EAAE,CAAC,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC"} \ No newline at end of file diff --git a/examples/tokenFederation/externalToken.ts b/examples/tokenFederation/externalToken.ts new file mode 100644 index 00000000..224da6de --- /dev/null +++ b/examples/tokenFederation/externalToken.ts @@ -0,0 +1,53 @@ +/** + * Example: Using an external token provider + * + * This example demonstrates how to use a callback function to provide + * tokens dynamically. This is useful for integrating with secret managers, + * vaults, or other token sources that may refresh tokens. + */ + +import { DBSQLClient } from '@databricks/sql'; + +// Simulate fetching a token from a secret manager or vault +async function fetchTokenFromVault(): Promise { + // In a real application, this would fetch from AWS Secrets Manager, + // Azure Key Vault, HashiCorp Vault, or another secret manager + console.log('Fetching token from vault...'); + + // Simulated token - replace with actual vault integration + const token = process.env.DATABRICKS_TOKEN!; + return token; +} + +async function main() { + const host = process.env.DATABRICKS_HOST!; + const path = process.env.DATABRICKS_HTTP_PATH!; + + const client = new DBSQLClient(); + + // Connect using an external token provider + // The callback will be called each time a new token is needed + // Note: The token is automatically cached, so the callback won't be + // called on every request + await client.connect({ + host, + path, + authType: 'external-token', + getToken: fetchTokenFromVault, + }); + + console.log('Connected successfully with external token provider'); + + // Open a session and run a query + const session = await client.openSession(); + const operation = await session.executeStatement('SELECT current_user() AS user'); + const result = await operation.fetchAll(); + + console.log('Query result:', result); + + await operation.close(); + await session.close(); + await client.close(); +} + +main().catch(console.error); diff --git a/examples/tokenFederation/federation.d.ts b/examples/tokenFederation/federation.d.ts new file mode 100644 index 00000000..326bae78 --- /dev/null +++ b/examples/tokenFederation/federation.d.ts @@ -0,0 +1,11 @@ +/** + * Example: Token Federation with an External Identity Provider + * + * This example demonstrates how to use token federation to automatically + * exchange tokens from external identity providers (Azure AD, Google, Okta, + * Auth0, AWS Cognito, GitHub) for Databricks-compatible tokens. + * + * Token federation uses RFC 8693 (OAuth 2.0 Token Exchange) to exchange + * the external JWT token for a Databricks access token. + */ +export {}; diff --git a/examples/tokenFederation/federation.js b/examples/tokenFederation/federation.js new file mode 100644 index 00000000..a33ef2cb --- /dev/null +++ b/examples/tokenFederation/federation.js @@ -0,0 +1,69 @@ +'use strict'; +/** + * Example: Token Federation with an External Identity Provider + * + * This example demonstrates how to use token federation to automatically + * exchange tokens from external identity providers (Azure AD, Google, Okta, + * Auth0, AWS Cognito, GitHub) for Databricks-compatible tokens. + * + * Token federation uses RFC 8693 (OAuth 2.0 Token Exchange) to exchange + * the external JWT token for a Databricks access token. + */ +Object.defineProperty(exports, '__esModule', { value: true }); +const sql_1 = require('@databricks/sql'); +// Example: Fetch a token from Azure AD +// In a real application, you would use the Azure SDK or similar +async function getAzureADToken() { + // Example using @azure/identity: + // + // import { DefaultAzureCredential } from '@azure/identity'; + // const credential = new DefaultAzureCredential(); + // const token = await credential.getToken('https://your-scope/.default'); + // return token.token; + // For this example, we use an environment variable + const token = process.env.AZURE_AD_TOKEN; + console.log('Fetched token from Azure AD'); + return token; +} +// Example: Fetch a token from Google +async function getGoogleToken() { + // Example using google-auth-library: + // + // import { GoogleAuth } from 'google-auth-library'; + // const auth = new GoogleAuth(); + // const client = await auth.getClient(); + // const token = await client.getAccessToken(); + // return token.token; + const token = process.env.GOOGLE_TOKEN; + console.log('Fetched token from Google'); + return token; +} +async function main() { + const host = process.env.DATABRICKS_HOST; + const path = process.env.DATABRICKS_HTTP_PATH; + const client = new sql_1.DBSQLClient(); + // Connect using token federation + // The driver will automatically: + // 1. Get the token from the callback + // 2. Check if the token's issuer matches the Databricks host + // 3. If not, exchange the token for a Databricks token via RFC 8693 + // 4. Cache the result for subsequent requests + await client.connect({ + host, + path, + authType: 'external-token', + getToken: getAzureADToken, + enableTokenFederation: true, + }); + console.log('Connected successfully with token federation'); + // Open a session and run a query + const session = await client.openSession(); + const operation = await session.executeStatement('SELECT current_user() AS user'); + const result = await operation.fetchAll(); + console.log('Query result:', result); + await operation.close(); + await session.close(); + await client.close(); +} +main().catch(console.error); +//# sourceMappingURL=federation.js.map diff --git a/examples/tokenFederation/federation.js.map b/examples/tokenFederation/federation.js.map new file mode 100644 index 00000000..2c516ffb --- /dev/null +++ b/examples/tokenFederation/federation.js.map @@ -0,0 +1 @@ +{"version":3,"file":"federation.js","sourceRoot":"","sources":["federation.ts"],"names":[],"mappings":";AAAA;;;;;;;;;GASG;;AAEH,yCAA8C;AAE9C,uCAAuC;AACvC,gEAAgE;AAChE,KAAK,UAAU,eAAe;IAC5B,iCAAiC;IACjC,EAAE;IACF,4DAA4D;IAC5D,mDAAmD;IACnD,0EAA0E;IAC1E,sBAAsB;IAEtB,mDAAmD;IACnD,MAAM,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,cAAe,CAAC;IAC1C,OAAO,CAAC,GAAG,CAAC,6BAA6B,CAAC,CAAC;IAC3C,OAAO,KAAK,CAAC;AACf,CAAC;AAED,qCAAqC;AACrC,KAAK,UAAU,cAAc;IAC3B,qCAAqC;IACrC,EAAE;IACF,oDAAoD;IACpD,iCAAiC;IACjC,yCAAyC;IACzC,+CAA+C;IAC/C,sBAAsB;IAEtB,MAAM,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,YAAa,CAAC;IACxC,OAAO,CAAC,GAAG,CAAC,2BAA2B,CAAC,CAAC;IACzC,OAAO,KAAK,CAAC;AACf,CAAC;AAED,KAAK,UAAU,IAAI;IACjB,MAAM,IAAI,GAAG,OAAO,CAAC,GAAG,CAAC,eAAgB,CAAC;IAC1C,MAAM,IAAI,GAAG,OAAO,CAAC,GAAG,CAAC,oBAAqB,CAAC;IAE/C,MAAM,MAAM,GAAG,IAAI,iBAAW,EAAE,CAAC;IAEjC,iCAAiC;IACjC,iCAAiC;IACjC,qCAAqC;IACrC,6DAA6D;IAC7D,oEAAoE;IACpE,8CAA8C;IAC9C,MAAM,MAAM,CAAC,OAAO,CAAC;QACnB,IAAI;QACJ,IAAI;QACJ,QAAQ,EAAE,gBAAgB;QAC1B,QAAQ,EAAE,eAAe;QACzB,qBAAqB,EAAE,IAAI;KAC5B,CAAC,CAAC;IAEH,OAAO,CAAC,GAAG,CAAC,8CAA8C,CAAC,CAAC;IAE5D,iCAAiC;IACjC,MAAM,OAAO,GAAG,MAAM,MAAM,CAAC,WAAW,EAAE,CAAC;IAC3C,MAAM,SAAS,GAAG,MAAM,OAAO,CAAC,gBAAgB,CAAC,+BAA+B,CAAC,CAAC;IAClF,MAAM,MAAM,GAAG,MAAM,SAAS,CAAC,QAAQ,EAAE,CAAC;IAE1C,OAAO,CAAC,GAAG,CAAC,eAAe,EAAE,MAAM,CAAC,CAAC;IAErC,MAAM,SAAS,CAAC,KAAK,EAAE,CAAC;IACxB,MAAM,OAAO,CAAC,KAAK,EAAE,CAAC;IACtB,MAAM,MAAM,CAAC,KAAK,EAAE,CAAC;AACvB,CAAC;AAED,IAAI,EAAE,CAAC,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC"} \ No newline at end of file diff --git a/examples/tokenFederation/federation.ts b/examples/tokenFederation/federation.ts new file mode 100644 index 00000000..69021052 --- /dev/null +++ b/examples/tokenFederation/federation.ts @@ -0,0 +1,79 @@ +/** + * Example: Token Federation with an External Identity Provider + * + * This example demonstrates how to use token federation to automatically + * exchange tokens from external identity providers (Azure AD, Google, Okta, + * Auth0, AWS Cognito, GitHub) for Databricks-compatible tokens. + * + * Token federation uses RFC 8693 (OAuth 2.0 Token Exchange) to exchange + * the external JWT token for a Databricks access token. + */ + +import { DBSQLClient } from '@databricks/sql'; + +// Example: Fetch a token from Azure AD +// In a real application, you would use the Azure SDK or similar +async function getAzureADToken(): Promise { + // Example using @azure/identity: + // + // import { DefaultAzureCredential } from '@azure/identity'; + // const credential = new DefaultAzureCredential(); + // const token = await credential.getToken('https://your-scope/.default'); + // return token.token; + + // For this example, we use an environment variable + const token = process.env.AZURE_AD_TOKEN!; + console.log('Fetched token from Azure AD'); + return token; +} + +// Example: Fetch a token from Google +async function getGoogleToken(): Promise { + // Example using google-auth-library: + // + // import { GoogleAuth } from 'google-auth-library'; + // const auth = new GoogleAuth(); + // const client = await auth.getClient(); + // const token = await client.getAccessToken(); + // return token.token; + + const token = process.env.GOOGLE_TOKEN!; + console.log('Fetched token from Google'); + return token; +} + +async function main() { + const host = process.env.DATABRICKS_HOST!; + const path = process.env.DATABRICKS_HTTP_PATH!; + + const client = new DBSQLClient(); + + // Connect using token federation + // The driver will automatically: + // 1. Get the token from the callback + // 2. Check if the token's issuer matches the Databricks host + // 3. If not, exchange the token for a Databricks token via RFC 8693 + // 4. Cache the result for subsequent requests + await client.connect({ + host, + path, + authType: 'external-token', + getToken: getAzureADToken, // or getGoogleToken, etc. + enableTokenFederation: true, + }); + + console.log('Connected successfully with token federation'); + + // Open a session and run a query + const session = await client.openSession(); + const operation = await session.executeStatement('SELECT current_user() AS user'); + const result = await operation.fetchAll(); + + console.log('Query result:', result); + + await operation.close(); + await session.close(); + await client.close(); +} + +main().catch(console.error); diff --git a/examples/tokenFederation/m2mFederation.d.ts b/examples/tokenFederation/m2mFederation.d.ts new file mode 100644 index 00000000..b7e82302 --- /dev/null +++ b/examples/tokenFederation/m2mFederation.d.ts @@ -0,0 +1,11 @@ +/** + * Example: Machine-to-Machine (M2M) Token Federation with Service Principal + * + * This example demonstrates how to use token federation with a service + * principal or machine identity. This is useful for server-to-server + * authentication where there is no interactive user. + * + * When using M2M federation, you typically need to provide a client_id + * to identify the service principal to Databricks. + */ +export {}; diff --git a/examples/tokenFederation/m2mFederation.js b/examples/tokenFederation/m2mFederation.js new file mode 100644 index 00000000..0107b775 --- /dev/null +++ b/examples/tokenFederation/m2mFederation.js @@ -0,0 +1,57 @@ +'use strict'; +/** + * Example: Machine-to-Machine (M2M) Token Federation with Service Principal + * + * This example demonstrates how to use token federation with a service + * principal or machine identity. This is useful for server-to-server + * authentication where there is no interactive user. + * + * When using M2M federation, you typically need to provide a client_id + * to identify the service principal to Databricks. + */ +Object.defineProperty(exports, '__esModule', { value: true }); +const sql_1 = require('@databricks/sql'); +// Example: Fetch a service account token from your identity provider +async function getServiceAccountToken() { + // Example for Azure service principal: + // + // import { ClientSecretCredential } from '@azure/identity'; + // const credential = new ClientSecretCredential( + // process.env.AZURE_TENANT_ID!, + // process.env.AZURE_CLIENT_ID!, + // process.env.AZURE_CLIENT_SECRET! + // ); + // const token = await credential.getToken('https://your-scope/.default'); + // return token.token; + // For this example, we use an environment variable + const token = process.env.SERVICE_ACCOUNT_TOKEN; + console.log('Fetched service account token'); + return token; +} +async function main() { + const host = process.env.DATABRICKS_HOST; + const path = process.env.DATABRICKS_HTTP_PATH; + const clientId = process.env.DATABRICKS_CLIENT_ID; + const client = new sql_1.DBSQLClient(); + // Connect using M2M token federation + // The federationClientId identifies your service principal to Databricks + await client.connect({ + host, + path, + authType: 'external-token', + getToken: getServiceAccountToken, + enableTokenFederation: true, + federationClientId: clientId, // Required for M2M/SP federation + }); + console.log('Connected successfully with M2M token federation'); + // Open a session and run a query + const session = await client.openSession(); + const operation = await session.executeStatement('SELECT current_user() AS user'); + const result = await operation.fetchAll(); + console.log('Query result:', result); + await operation.close(); + await session.close(); + await client.close(); +} +main().catch(console.error); +//# sourceMappingURL=m2mFederation.js.map diff --git a/examples/tokenFederation/m2mFederation.js.map b/examples/tokenFederation/m2mFederation.js.map new file mode 100644 index 00000000..6132d904 --- /dev/null +++ b/examples/tokenFederation/m2mFederation.js.map @@ -0,0 +1 @@ +{"version":3,"file":"m2mFederation.js","sourceRoot":"","sources":["m2mFederation.ts"],"names":[],"mappings":";AAAA;;;;;;;;;GASG;;AAEH,yCAA8C;AAE9C,qEAAqE;AACrE,KAAK,UAAU,sBAAsB;IACnC,uCAAuC;IACvC,EAAE;IACF,4DAA4D;IAC5D,iDAAiD;IACjD,kCAAkC;IAClC,kCAAkC;IAClC,qCAAqC;IACrC,KAAK;IACL,0EAA0E;IAC1E,sBAAsB;IAEtB,mDAAmD;IACnD,MAAM,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,qBAAsB,CAAC;IACjD,OAAO,CAAC,GAAG,CAAC,+BAA+B,CAAC,CAAC;IAC7C,OAAO,KAAK,CAAC;AACf,CAAC;AAED,KAAK,UAAU,IAAI;IACjB,MAAM,IAAI,GAAG,OAAO,CAAC,GAAG,CAAC,eAAgB,CAAC;IAC1C,MAAM,IAAI,GAAG,OAAO,CAAC,GAAG,CAAC,oBAAqB,CAAC;IAC/C,MAAM,QAAQ,GAAG,OAAO,CAAC,GAAG,CAAC,oBAAqB,CAAC;IAEnD,MAAM,MAAM,GAAG,IAAI,iBAAW,EAAE,CAAC;IAEjC,qCAAqC;IACrC,yEAAyE;IACzE,MAAM,MAAM,CAAC,OAAO,CAAC;QACnB,IAAI;QACJ,IAAI;QACJ,QAAQ,EAAE,gBAAgB;QAC1B,QAAQ,EAAE,sBAAsB;QAChC,qBAAqB,EAAE,IAAI;QAC3B,kBAAkB,EAAE,QAAQ,EAAE,iCAAiC;KAChE,CAAC,CAAC;IAEH,OAAO,CAAC,GAAG,CAAC,kDAAkD,CAAC,CAAC;IAEhE,iCAAiC;IACjC,MAAM,OAAO,GAAG,MAAM,MAAM,CAAC,WAAW,EAAE,CAAC;IAC3C,MAAM,SAAS,GAAG,MAAM,OAAO,CAAC,gBAAgB,CAAC,+BAA+B,CAAC,CAAC;IAClF,MAAM,MAAM,GAAG,MAAM,SAAS,CAAC,QAAQ,EAAE,CAAC;IAE1C,OAAO,CAAC,GAAG,CAAC,eAAe,EAAE,MAAM,CAAC,CAAC;IAErC,MAAM,SAAS,CAAC,KAAK,EAAE,CAAC;IACxB,MAAM,OAAO,CAAC,KAAK,EAAE,CAAC;IACtB,MAAM,MAAM,CAAC,KAAK,EAAE,CAAC;AACvB,CAAC;AAED,IAAI,EAAE,CAAC,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC"} \ No newline at end of file diff --git a/examples/tokenFederation/m2mFederation.ts b/examples/tokenFederation/m2mFederation.ts new file mode 100644 index 00000000..e4c22f4f --- /dev/null +++ b/examples/tokenFederation/m2mFederation.ts @@ -0,0 +1,65 @@ +/** + * Example: Machine-to-Machine (M2M) Token Federation with Service Principal + * + * This example demonstrates how to use token federation with a service + * principal or machine identity. This is useful for server-to-server + * authentication where there is no interactive user. + * + * When using M2M federation, you typically need to provide a client_id + * to identify the service principal to Databricks. + */ + +import { DBSQLClient } from '@databricks/sql'; + +// Example: Fetch a service account token from your identity provider +async function getServiceAccountToken(): Promise { + // Example for Azure service principal: + // + // import { ClientSecretCredential } from '@azure/identity'; + // const credential = new ClientSecretCredential( + // process.env.AZURE_TENANT_ID!, + // process.env.AZURE_CLIENT_ID!, + // process.env.AZURE_CLIENT_SECRET! + // ); + // const token = await credential.getToken('https://your-scope/.default'); + // return token.token; + + // For this example, we use an environment variable + const token = process.env.SERVICE_ACCOUNT_TOKEN!; + console.log('Fetched service account token'); + return token; +} + +async function main() { + const host = process.env.DATABRICKS_HOST!; + const path = process.env.DATABRICKS_HTTP_PATH!; + const clientId = process.env.DATABRICKS_CLIENT_ID!; + + const client = new DBSQLClient(); + + // Connect using M2M token federation + // The federationClientId identifies your service principal to Databricks + await client.connect({ + host, + path, + authType: 'external-token', + getToken: getServiceAccountToken, + enableTokenFederation: true, + federationClientId: clientId, // Required for M2M/SP federation + }); + + console.log('Connected successfully with M2M token federation'); + + // Open a session and run a query + const session = await client.openSession(); + const operation = await session.executeStatement('SELECT current_user() AS user'); + const result = await operation.fetchAll(); + + console.log('Query result:', result); + + await operation.close(); + await session.close(); + await client.close(); +} + +main().catch(console.error); diff --git a/examples/tokenFederation/staticToken.d.ts b/examples/tokenFederation/staticToken.d.ts new file mode 100644 index 00000000..9e63ffd3 --- /dev/null +++ b/examples/tokenFederation/staticToken.d.ts @@ -0,0 +1,8 @@ +/** + * Example: Using a static token with the token provider system + * + * This example demonstrates how to use a static access token with the + * token provider infrastructure. This is useful when you have a token + * that doesn't change during the lifetime of your application. + */ +export {}; diff --git a/examples/tokenFederation/staticToken.js b/examples/tokenFederation/staticToken.js new file mode 100644 index 00000000..109a03c0 --- /dev/null +++ b/examples/tokenFederation/staticToken.js @@ -0,0 +1,34 @@ +'use strict'; +/** + * Example: Using a static token with the token provider system + * + * This example demonstrates how to use a static access token with the + * token provider infrastructure. This is useful when you have a token + * that doesn't change during the lifetime of your application. + */ +Object.defineProperty(exports, '__esModule', { value: true }); +const sql_1 = require('@databricks/sql'); +async function main() { + const host = process.env.DATABRICKS_HOST; + const path = process.env.DATABRICKS_HTTP_PATH; + const token = process.env.DATABRICKS_TOKEN; + const client = new sql_1.DBSQLClient(); + // Connect using a static token + await client.connect({ + host, + path, + authType: 'static-token', + staticToken: token, + }); + console.log('Connected successfully with static token'); + // Open a session and run a query + const session = await client.openSession(); + const operation = await session.executeStatement('SELECT 1 AS result'); + const result = await operation.fetchAll(); + console.log('Query result:', result); + await operation.close(); + await session.close(); + await client.close(); +} +main().catch(console.error); +//# sourceMappingURL=staticToken.js.map diff --git a/examples/tokenFederation/staticToken.js.map b/examples/tokenFederation/staticToken.js.map new file mode 100644 index 00000000..92990151 --- /dev/null +++ b/examples/tokenFederation/staticToken.js.map @@ -0,0 +1 @@ +{"version":3,"file":"staticToken.js","sourceRoot":"","sources":["staticToken.ts"],"names":[],"mappings":";AAAA;;;;;;GAMG;;AAEH,yCAA8C;AAE9C,KAAK,UAAU,IAAI;IACjB,MAAM,IAAI,GAAG,OAAO,CAAC,GAAG,CAAC,eAAgB,CAAC;IAC1C,MAAM,IAAI,GAAG,OAAO,CAAC,GAAG,CAAC,oBAAqB,CAAC;IAC/C,MAAM,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,gBAAiB,CAAC;IAE5C,MAAM,MAAM,GAAG,IAAI,iBAAW,EAAE,CAAC;IAEjC,+BAA+B;IAC/B,MAAM,MAAM,CAAC,OAAO,CAAC;QACnB,IAAI;QACJ,IAAI;QACJ,QAAQ,EAAE,cAAc;QACxB,WAAW,EAAE,KAAK;KACnB,CAAC,CAAC;IAEH,OAAO,CAAC,GAAG,CAAC,0CAA0C,CAAC,CAAC;IAExD,iCAAiC;IACjC,MAAM,OAAO,GAAG,MAAM,MAAM,CAAC,WAAW,EAAE,CAAC;IAC3C,MAAM,SAAS,GAAG,MAAM,OAAO,CAAC,gBAAgB,CAAC,oBAAoB,CAAC,CAAC;IACvE,MAAM,MAAM,GAAG,MAAM,SAAS,CAAC,QAAQ,EAAE,CAAC;IAE1C,OAAO,CAAC,GAAG,CAAC,eAAe,EAAE,MAAM,CAAC,CAAC;IAErC,MAAM,SAAS,CAAC,KAAK,EAAE,CAAC;IACxB,MAAM,OAAO,CAAC,KAAK,EAAE,CAAC;IACtB,MAAM,MAAM,CAAC,KAAK,EAAE,CAAC;AACvB,CAAC;AAED,IAAI,EAAE,CAAC,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC"} \ No newline at end of file diff --git a/examples/tokenFederation/staticToken.ts b/examples/tokenFederation/staticToken.ts new file mode 100644 index 00000000..d6cec8df --- /dev/null +++ b/examples/tokenFederation/staticToken.ts @@ -0,0 +1,40 @@ +/** + * Example: Using a static token with the token provider system + * + * This example demonstrates how to use a static access token with the + * token provider infrastructure. This is useful when you have a token + * that doesn't change during the lifetime of your application. + */ + +import { DBSQLClient } from '@databricks/sql'; + +async function main() { + const host = process.env.DATABRICKS_HOST!; + const path = process.env.DATABRICKS_HTTP_PATH!; + const token = process.env.DATABRICKS_TOKEN!; + + const client = new DBSQLClient(); + + // Connect using a static token + await client.connect({ + host, + path, + authType: 'static-token', + staticToken: token, + }); + + console.log('Connected successfully with static token'); + + // Open a session and run a query + const session = await client.openSession(); + const operation = await session.executeStatement('SELECT 1 AS result'); + const result = await operation.fetchAll(); + + console.log('Query result:', result); + + await operation.close(); + await session.close(); + await client.close(); +} + +main().catch(console.error); diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 00496463..25609efe 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -19,6 +19,14 @@ import HiveDriverError from './errors/HiveDriverError'; import { buildUserAgentString, definedOrError } from './utils'; import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication'; import DatabricksOAuth, { OAuthFlow } from './connection/auth/DatabricksOAuth'; +import { + TokenProviderAuthenticator, + StaticTokenProvider, + ExternalTokenProvider, + CachedTokenProvider, + FederationProvider, + ITokenProvider, +} from './connection/auth/tokenProvider'; import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger'; import DBSQLLogger from './DBSQLLogger'; import CloseableCollection from './utils/CloseableCollection'; @@ -143,10 +151,63 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I }); case 'custom': return options.provider; + case 'token-provider': + return new TokenProviderAuthenticator( + this.wrapTokenProvider( + options.tokenProvider, + options.host, + options.enableTokenFederation, + options.federationClientId, + ), + this, + ); + case 'external-token': + return new TokenProviderAuthenticator( + this.wrapTokenProvider( + new ExternalTokenProvider(options.getToken), + options.host, + options.enableTokenFederation, + options.federationClientId, + ), + this, + ); + case 'static-token': + return new TokenProviderAuthenticator( + this.wrapTokenProvider( + StaticTokenProvider.fromJWT(options.staticToken), + options.host, + options.enableTokenFederation, + options.federationClientId, + ), + this, + ); // no default } } + /** + * Wraps a token provider with caching and optional federation. + * Caching is always enabled by default. Federation is opt-in. + */ + private wrapTokenProvider( + provider: ITokenProvider, + host: string, + enableFederation?: boolean, + federationClientId?: string, + ): ITokenProvider { + // Always wrap with caching first + let wrapped: ITokenProvider = new CachedTokenProvider(provider); + + // Optionally wrap with federation + if (enableFederation) { + wrapped = new FederationProvider(wrapped, host, { + clientId: federationClientId, + }); + } + + return wrapped; + } + private createConnectionProvider(options: ConnectionOptions): IConnectionProvider { return new HttpConnection(this.getConnectionOptions(options), this); } diff --git a/lib/connection/auth/tokenProvider/CachedTokenProvider.ts b/lib/connection/auth/tokenProvider/CachedTokenProvider.ts new file mode 100644 index 00000000..7172ea0b --- /dev/null +++ b/lib/connection/auth/tokenProvider/CachedTokenProvider.ts @@ -0,0 +1,98 @@ +import ITokenProvider from './ITokenProvider'; +import Token from './Token'; + +/** + * Default refresh threshold in milliseconds (5 minutes). + * Tokens will be refreshed when they are within this threshold of expiring. + */ +const DEFAULT_REFRESH_THRESHOLD_MS = 5 * 60 * 1000; + +/** + * A token provider that wraps another provider with automatic caching. + * Tokens are cached and reused until they are close to expiring. + */ +export default class CachedTokenProvider implements ITokenProvider { + private readonly baseProvider: ITokenProvider; + + private readonly refreshThresholdMs: number; + + private cache: Token | null = null; + + private refreshPromise: Promise | null = null; + + /** + * Creates a new CachedTokenProvider. + * @param baseProvider - The underlying token provider to cache + * @param options - Optional configuration + * @param options.refreshThresholdMs - Refresh tokens this many ms before expiry (default: 5 minutes) + */ + constructor( + baseProvider: ITokenProvider, + options?: { + refreshThresholdMs?: number; + }, + ) { + this.baseProvider = baseProvider; + this.refreshThresholdMs = options?.refreshThresholdMs ?? DEFAULT_REFRESH_THRESHOLD_MS; + } + + async getToken(): Promise { + // Return cached token if it's still valid + if (this.cache && !this.shouldRefresh(this.cache)) { + return this.cache; + } + + // If already refreshing, wait for that to complete + if (this.refreshPromise) { + return this.refreshPromise; + } + + // Start refresh + this.refreshPromise = this.refreshToken(); + + try { + const token = await this.refreshPromise; + return token; + } finally { + this.refreshPromise = null; + } + } + + getName(): string { + return `cached[${this.baseProvider.getName()}]`; + } + + /** + * Clears the cached token, forcing a refresh on the next getToken() call. + */ + clearCache(): void { + this.cache = null; + } + + /** + * Determines if the token should be refreshed. + * @param token - The token to check + * @returns true if the token should be refreshed + */ + private shouldRefresh(token: Token): boolean { + // If no expiration is known, don't refresh proactively + if (!token.expiresAt) { + return false; + } + + const now = Date.now(); + const expiresAtMs = token.expiresAt.getTime(); + const refreshAtMs = expiresAtMs - this.refreshThresholdMs; + + return now >= refreshAtMs; + } + + /** + * Fetches a new token from the base provider and caches it. + */ + private async refreshToken(): Promise { + const token = await this.baseProvider.getToken(); + this.cache = token; + return token; + } +} diff --git a/lib/connection/auth/tokenProvider/ExternalTokenProvider.ts b/lib/connection/auth/tokenProvider/ExternalTokenProvider.ts new file mode 100644 index 00000000..ada48038 --- /dev/null +++ b/lib/connection/auth/tokenProvider/ExternalTokenProvider.ts @@ -0,0 +1,52 @@ +import ITokenProvider from './ITokenProvider'; +import Token from './Token'; + +/** + * Type for the callback function that retrieves tokens from external sources. + */ +export type TokenCallback = () => Promise; + +/** + * A token provider that delegates token retrieval to an external callback function. + * Useful for integrating with secret managers, vaults, or other token sources. + */ +export default class ExternalTokenProvider implements ITokenProvider { + private readonly getTokenCallback: TokenCallback; + + private readonly parseJWT: boolean; + + private readonly providerName: string; + + /** + * Creates a new ExternalTokenProvider. + * @param getToken - Callback function that returns the access token string + * @param options - Optional configuration + * @param options.parseJWT - If true, attempt to extract expiration from JWT payload (default: true) + * @param options.name - Custom name for this provider (default: "ExternalTokenProvider") + */ + constructor( + getToken: TokenCallback, + options?: { + parseJWT?: boolean; + name?: string; + }, + ) { + this.getTokenCallback = getToken; + this.parseJWT = options?.parseJWT ?? true; + this.providerName = options?.name ?? 'ExternalTokenProvider'; + } + + async getToken(): Promise { + const accessToken = await this.getTokenCallback(); + + if (this.parseJWT) { + return Token.fromJWT(accessToken); + } + + return new Token(accessToken); + } + + getName(): string { + return this.providerName; + } +} diff --git a/lib/connection/auth/tokenProvider/FederationProvider.ts b/lib/connection/auth/tokenProvider/FederationProvider.ts new file mode 100644 index 00000000..e95b415e --- /dev/null +++ b/lib/connection/auth/tokenProvider/FederationProvider.ts @@ -0,0 +1,192 @@ +import fetch from 'node-fetch'; +import ITokenProvider from './ITokenProvider'; +import Token from './Token'; +import { getJWTIssuer, isSameHost } from './utils'; + +/** + * Token exchange endpoint path for Databricks OIDC. + */ +const TOKEN_EXCHANGE_ENDPOINT = '/oidc/v1/token'; + +/** + * Grant type for RFC 8693 token exchange. + */ +const TOKEN_EXCHANGE_GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:token-exchange'; + +/** + * Subject token type for JWT tokens. + */ +const SUBJECT_TOKEN_TYPE = 'urn:ietf:params:oauth:token-type:jwt'; + +/** + * Default scope for SQL operations. + */ +const DEFAULT_SCOPE = 'sql'; + +/** + * Timeout for token exchange requests in milliseconds. + */ +const REQUEST_TIMEOUT_MS = 30000; + +/** + * A token provider that wraps another provider with automatic token federation. + * When the base provider returns a token from a different issuer, this provider + * exchanges it for a Databricks-compatible token using RFC 8693. + */ +export default class FederationProvider implements ITokenProvider { + private readonly baseProvider: ITokenProvider; + + private readonly databricksHost: string; + + private readonly clientId?: string; + + private readonly returnOriginalTokenOnFailure: boolean; + + /** + * Creates a new FederationProvider. + * @param baseProvider - The underlying token provider + * @param databricksHost - The Databricks workspace host URL + * @param options - Optional configuration + * @param options.clientId - Client ID for M2M/service principal federation + * @param options.returnOriginalTokenOnFailure - Return original token if exchange fails (default: true) + */ + constructor( + baseProvider: ITokenProvider, + databricksHost: string, + options?: { + clientId?: string; + returnOriginalTokenOnFailure?: boolean; + }, + ) { + this.baseProvider = baseProvider; + this.databricksHost = databricksHost; + this.clientId = options?.clientId; + this.returnOriginalTokenOnFailure = options?.returnOriginalTokenOnFailure ?? true; + } + + async getToken(): Promise { + const token = await this.baseProvider.getToken(); + + // Check if token needs exchange + if (!this.needsTokenExchange(token)) { + return token; + } + + // Attempt token exchange + try { + return await this.exchangeToken(token); + } catch (error) { + if (this.returnOriginalTokenOnFailure) { + // Fall back to original token + return token; + } + throw error; + } + } + + getName(): string { + return `federated[${this.baseProvider.getName()}]`; + } + + /** + * Determines if the token needs to be exchanged. + * @param token - The token to check + * @returns true if the token should be exchanged + */ + private needsTokenExchange(token: Token): boolean { + const issuer = getJWTIssuer(token.accessToken); + + // If we can't extract the issuer, don't exchange (might not be a JWT) + if (!issuer) { + return false; + } + + // If the issuer is the same as Databricks host, no exchange needed + if (isSameHost(issuer, this.databricksHost)) { + return false; + } + + return true; + } + + /** + * Exchanges the token for a Databricks-compatible token using RFC 8693. + * @param token - The token to exchange + * @returns The exchanged token + */ + private async exchangeToken(token: Token): Promise { + const url = this.buildExchangeUrl(); + + const params = new URLSearchParams({ + grant_type: TOKEN_EXCHANGE_GRANT_TYPE, + subject_token_type: SUBJECT_TOKEN_TYPE, + subject_token: token.accessToken, + scope: DEFAULT_SCOPE, + }); + + if (this.clientId) { + params.append('client_id', this.clientId); + } + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), REQUEST_TIMEOUT_MS); + + try { + const response = await fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: params.toString(), + signal: controller.signal, + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error(`Token exchange failed: ${response.status} ${response.statusText} - ${errorText}`); + } + + const data = (await response.json()) as { + access_token?: string; + token_type?: string; + expires_in?: number; + }; + + if (!data.access_token) { + throw new Error('Token exchange response missing access_token'); + } + + // Calculate expiration from expires_in + let expiresAt: Date | undefined; + if (typeof data.expires_in === 'number') { + expiresAt = new Date(Date.now() + data.expires_in * 1000); + } + + return new Token(data.access_token, { + tokenType: data.token_type ?? 'Bearer', + expiresAt, + }); + } finally { + clearTimeout(timeoutId); + } + } + + /** + * Builds the token exchange URL. + */ + private buildExchangeUrl(): string { + let host = this.databricksHost; + + // Ensure host has a protocol + if (!host.includes('://')) { + host = `https://${host}`; + } + + // Remove trailing slash + if (host.endsWith('/')) { + host = host.slice(0, -1); + } + + return `${host}${TOKEN_EXCHANGE_ENDPOINT}`; + } +} diff --git a/lib/connection/auth/tokenProvider/ITokenProvider.ts b/lib/connection/auth/tokenProvider/ITokenProvider.ts new file mode 100644 index 00000000..a7cd23dc --- /dev/null +++ b/lib/connection/auth/tokenProvider/ITokenProvider.ts @@ -0,0 +1,19 @@ +import Token from './Token'; + +/** + * Interface for token providers that supply access tokens for authentication. + * Token providers can be wrapped with caching and federation decorators. + */ +export default interface ITokenProvider { + /** + * Retrieves an access token for authentication. + * @returns A Promise that resolves to a Token object containing the access token + */ + getToken(): Promise; + + /** + * Returns the name of this token provider for logging and debugging purposes. + * @returns The provider name + */ + getName(): string; +} diff --git a/lib/connection/auth/tokenProvider/StaticTokenProvider.ts b/lib/connection/auth/tokenProvider/StaticTokenProvider.ts new file mode 100644 index 00000000..0a4acead --- /dev/null +++ b/lib/connection/auth/tokenProvider/StaticTokenProvider.ts @@ -0,0 +1,58 @@ +import ITokenProvider from './ITokenProvider'; +import Token from './Token'; + +/** + * A token provider that returns a static token. + * Useful for testing or when the token is obtained through external means. + */ +export default class StaticTokenProvider implements ITokenProvider { + private readonly token: Token; + + /** + * Creates a new StaticTokenProvider. + * @param accessToken - The access token string + * @param options - Optional token configuration (tokenType, expiresAt, refreshToken, scopes) + */ + constructor( + accessToken: string, + options?: { + tokenType?: string; + expiresAt?: Date; + refreshToken?: string; + scopes?: string[]; + }, + ) { + this.token = new Token(accessToken, options); + } + + /** + * Creates a StaticTokenProvider from a JWT string. + * The expiration time will be extracted from the JWT payload. + * @param jwt - The JWT token string + * @param options - Optional token configuration + */ + static fromJWT( + jwt: string, + options?: { + tokenType?: string; + refreshToken?: string; + scopes?: string[]; + }, + ): StaticTokenProvider { + const token = Token.fromJWT(jwt, options); + return new StaticTokenProvider(token.accessToken, { + tokenType: token.tokenType, + expiresAt: token.expiresAt, + refreshToken: token.refreshToken, + scopes: token.scopes, + }); + } + + async getToken(): Promise { + return this.token; + } + + getName(): string { + return 'StaticTokenProvider'; + } +} diff --git a/lib/connection/auth/tokenProvider/Token.ts b/lib/connection/auth/tokenProvider/Token.ts new file mode 100644 index 00000000..911b2bdd --- /dev/null +++ b/lib/connection/auth/tokenProvider/Token.ts @@ -0,0 +1,151 @@ +import { HeadersInit } from 'node-fetch'; + +/** + * Safety buffer in seconds to consider a token expired before its actual expiration time. + * This prevents using tokens that are about to expire during in-flight requests. + */ +const EXPIRATION_BUFFER_SECONDS = 30; + +/** + * Represents an access token with optional metadata and lifecycle management. + */ +export default class Token { + private readonly _accessToken: string; + + private readonly _tokenType: string; + + private readonly _expiresAt?: Date; + + private readonly _refreshToken?: string; + + private readonly _scopes?: string[]; + + constructor( + accessToken: string, + options?: { + tokenType?: string; + expiresAt?: Date; + refreshToken?: string; + scopes?: string[]; + }, + ) { + this._accessToken = accessToken; + this._tokenType = options?.tokenType ?? 'Bearer'; + this._expiresAt = options?.expiresAt; + this._refreshToken = options?.refreshToken; + this._scopes = options?.scopes; + } + + /** + * The access token string. + */ + get accessToken(): string { + return this._accessToken; + } + + /** + * The token type (e.g., "Bearer"). + */ + get tokenType(): string { + return this._tokenType; + } + + /** + * The expiration time of the token, if known. + */ + get expiresAt(): Date | undefined { + return this._expiresAt; + } + + /** + * The refresh token, if available. + */ + get refreshToken(): string | undefined { + return this._refreshToken; + } + + /** + * The scopes associated with this token. + */ + get scopes(): string[] | undefined { + return this._scopes; + } + + /** + * Checks if the token has expired, including a safety buffer. + * Returns false if expiration time is unknown. + */ + isExpired(): boolean { + if (!this._expiresAt) { + return false; + } + const now = new Date(); + const bufferMs = EXPIRATION_BUFFER_SECONDS * 1000; + return this._expiresAt.getTime() - bufferMs <= now.getTime(); + } + + /** + * Sets the Authorization header on the provided headers object. + * @param headers - The headers object to modify + * @returns The modified headers object with Authorization set + */ + setAuthHeader(headers: HeadersInit): HeadersInit { + return { + ...headers, + Authorization: `${this._tokenType} ${this._accessToken}`, + }; + } + + /** + * Creates a Token from a JWT string, extracting the expiration time from the payload. + * If the JWT cannot be decoded, the token is created without expiration info. + * The server will validate the token anyway, so decoding failures are handled gracefully. + * @param jwt - The JWT token string + * @param options - Additional token options (tokenType, refreshToken, scopes) + * @returns A new Token instance with expiration extracted from the JWT (if available) + */ + static fromJWT( + jwt: string, + options?: { + tokenType?: string; + refreshToken?: string; + scopes?: string[]; + }, + ): Token { + let expiresAt: Date | undefined; + + try { + const parts = jwt.split('.'); + if (parts.length >= 2) { + const payload = Buffer.from(parts[1], 'base64').toString('utf8'); + const decoded = JSON.parse(payload); + if (typeof decoded.exp === 'number') { + expiresAt = new Date(decoded.exp * 1000); + } + } + } catch { + // If we can't decode the JWT, we'll proceed without expiration info + // The server will validate the token anyway + } + + return new Token(jwt, { + tokenType: options?.tokenType, + expiresAt, + refreshToken: options?.refreshToken, + scopes: options?.scopes, + }); + } + + /** + * Converts the token to a plain object for serialization. + */ + toJSON(): Record { + return { + accessToken: this._accessToken, + tokenType: this._tokenType, + expiresAt: this._expiresAt?.toISOString(), + refreshToken: this._refreshToken, + scopes: this._scopes, + }; + } +} diff --git a/lib/connection/auth/tokenProvider/TokenProviderAuthenticator.ts b/lib/connection/auth/tokenProvider/TokenProviderAuthenticator.ts new file mode 100644 index 00000000..2c77127b --- /dev/null +++ b/lib/connection/auth/tokenProvider/TokenProviderAuthenticator.ts @@ -0,0 +1,44 @@ +import { HeadersInit } from 'node-fetch'; +import IAuthentication from '../../contracts/IAuthentication'; +import ITokenProvider from './ITokenProvider'; +import IClientContext from '../../../contracts/IClientContext'; +import { LogLevel } from '../../../contracts/IDBSQLLogger'; + +/** + * Adapts an ITokenProvider to the IAuthentication interface used by the driver. + * This allows token providers to be used with the existing authentication system. + */ +export default class TokenProviderAuthenticator implements IAuthentication { + private readonly tokenProvider: ITokenProvider; + + private readonly context: IClientContext; + + private readonly headers: HeadersInit; + + /** + * Creates a new TokenProviderAuthenticator. + * @param tokenProvider - The token provider to use for authentication + * @param context - The client context for logging + * @param headers - Additional headers to include with each request + */ + constructor(tokenProvider: ITokenProvider, context: IClientContext, headers?: HeadersInit) { + this.tokenProvider = tokenProvider; + this.context = context; + this.headers = headers ?? {}; + } + + async authenticate(): Promise { + const logger = this.context.getLogger(); + const providerName = this.tokenProvider.getName(); + + logger.log(LogLevel.debug, `TokenProviderAuthenticator: getting token from ${providerName}`); + + const token = await this.tokenProvider.getToken(); + + if (token.isExpired()) { + logger.log(LogLevel.warn, `TokenProviderAuthenticator: token from ${providerName} is expired`); + } + + return token.setAuthHeader(this.headers); + } +} diff --git a/lib/connection/auth/tokenProvider/index.ts b/lib/connection/auth/tokenProvider/index.ts new file mode 100644 index 00000000..e09db00f --- /dev/null +++ b/lib/connection/auth/tokenProvider/index.ts @@ -0,0 +1,8 @@ +export { default as ITokenProvider } from './ITokenProvider'; +export { default as Token } from './Token'; +export { default as StaticTokenProvider } from './StaticTokenProvider'; +export { default as ExternalTokenProvider, TokenCallback } from './ExternalTokenProvider'; +export { default as TokenProviderAuthenticator } from './TokenProviderAuthenticator'; +export { default as CachedTokenProvider } from './CachedTokenProvider'; +export { default as FederationProvider } from './FederationProvider'; +export { decodeJWT, getJWTIssuer, isSameHost } from './utils'; diff --git a/lib/connection/auth/tokenProvider/utils.ts b/lib/connection/auth/tokenProvider/utils.ts new file mode 100644 index 00000000..cc8df0e2 --- /dev/null +++ b/lib/connection/auth/tokenProvider/utils.ts @@ -0,0 +1,79 @@ +/** + * Decodes a JWT token without verifying the signature. + * This is safe because the server will validate the token anyway. + * + * @param token - The JWT token string + * @returns The decoded payload as a record, or null if decoding fails + */ +export function decodeJWT(token: string): Record | null { + try { + const parts = token.split('.'); + if (parts.length < 2) { + return null; + } + const payload = Buffer.from(parts[1], 'base64').toString('utf8'); + return JSON.parse(payload); + } catch { + return null; + } +} + +/** + * Extracts the issuer from a JWT token. + * + * @param token - The JWT token string + * @returns The issuer string, or null if not found + */ +export function getJWTIssuer(token: string): string | null { + const payload = decodeJWT(token); + if (!payload || typeof payload.iss !== 'string') { + return null; + } + return payload.iss; +} + +/** + * Extracts the hostname from a URL or hostname string. + * Handles both full URLs and bare hostnames. + * + * @param urlOrHostname - A URL or hostname string + * @returns The extracted hostname + */ +function extractHostname(urlOrHostname: string): string { + // If it looks like a URL, parse it + if (urlOrHostname.includes('://')) { + const url = new URL(urlOrHostname); + return url.hostname; + } + + // Handle hostname with port (e.g., "databricks.com:443") + const colonIndex = urlOrHostname.indexOf(':'); + if (colonIndex !== -1) { + return urlOrHostname.substring(0, colonIndex); + } + + // Bare hostname + return urlOrHostname; +} + +/** + * Compares two host URLs, ignoring ports. + * Treats "databricks.com" and "databricks.com:443" as equivalent. + * + * @param url1 - First URL or hostname + * @param url2 - Second URL or hostname + * @returns true if the hosts are the same + */ +export function isSameHost(url1: string, url2: string): boolean { + try { + const host1 = extractHostname(url1); + const host2 = extractHostname(url2); + // Empty hostnames are not valid + if (!host1 || !host2) { + return false; + } + return host1.toLowerCase() === host2.toLowerCase(); + } catch { + return false; + } +} diff --git a/lib/contracts/IDBSQLClient.ts b/lib/contracts/IDBSQLClient.ts index 26588031..4b2f39a4 100644 --- a/lib/contracts/IDBSQLClient.ts +++ b/lib/contracts/IDBSQLClient.ts @@ -3,6 +3,8 @@ import IDBSQLSession from './IDBSQLSession'; import IAuthentication from '../connection/contracts/IAuthentication'; import { ProxyOptions } from '../connection/contracts/IConnectionOptions'; import OAuthPersistence from '../connection/auth/DatabricksOAuth/OAuthPersistence'; +import ITokenProvider from '../connection/auth/tokenProvider/ITokenProvider'; +import { TokenCallback } from '../connection/auth/tokenProvider/ExternalTokenProvider'; export interface ClientOptions { logger?: IDBSQLLogger; @@ -24,6 +26,24 @@ type AuthOptions = | { authType: 'custom'; provider: IAuthentication; + } + | { + authType: 'token-provider'; + tokenProvider: ITokenProvider; + enableTokenFederation?: boolean; + federationClientId?: string; + } + | { + authType: 'external-token'; + getToken: TokenCallback; + enableTokenFederation?: boolean; + federationClientId?: string; + } + | { + authType: 'static-token'; + staticToken: string; + enableTokenFederation?: boolean; + federationClientId?: string; }; export type ConnectionOptions = { diff --git a/lib/index.ts b/lib/index.ts index 710a036d..adf14f36 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -9,12 +9,28 @@ import DBSQLSession from './DBSQLSession'; import { DBSQLParameter, DBSQLParameterType } from './DBSQLParameter'; import DBSQLLogger from './DBSQLLogger'; import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication'; +import { + Token, + StaticTokenProvider, + ExternalTokenProvider, + CachedTokenProvider, + FederationProvider, +} from './connection/auth/tokenProvider'; import HttpConnection from './connection/connections/HttpConnection'; import { formatProgress } from './utils'; import { LogLevel } from './contracts/IDBSQLLogger'; +// Re-export types for TypeScript users +export type { default as ITokenProvider } from './connection/auth/tokenProvider/ITokenProvider'; + export const auth = { PlainHttpAuthentication, + // Token provider classes for custom authentication + Token, + StaticTokenProvider, + ExternalTokenProvider, + CachedTokenProvider, + FederationProvider, }; const { TException, TApplicationException, TApplicationExceptionType, TProtocolException, TProtocolExceptionType } = diff --git a/tests/unit/connection/auth/tokenProvider/CachedTokenProvider.test.ts b/tests/unit/connection/auth/tokenProvider/CachedTokenProvider.test.ts new file mode 100644 index 00000000..5c62a89a --- /dev/null +++ b/tests/unit/connection/auth/tokenProvider/CachedTokenProvider.test.ts @@ -0,0 +1,165 @@ +import { expect } from 'chai'; +import sinon from 'sinon'; +import CachedTokenProvider from '../../../../../lib/connection/auth/tokenProvider/CachedTokenProvider'; +import ITokenProvider from '../../../../../lib/connection/auth/tokenProvider/ITokenProvider'; +import Token from '../../../../../lib/connection/auth/tokenProvider/Token'; + +class MockTokenProvider implements ITokenProvider { + public callCount = 0; + public tokenToReturn: Token; + + constructor(expiresInMs: number = 3600000) { + this.tokenToReturn = new Token(`token-${this.callCount}`, { + expiresAt: new Date(Date.now() + expiresInMs), + }); + } + + async getToken(): Promise { + this.callCount += 1; + this.tokenToReturn = new Token(`token-${this.callCount}`, { + expiresAt: this.tokenToReturn.expiresAt, + }); + return this.tokenToReturn; + } + + getName(): string { + return 'MockTokenProvider'; + } +} + +describe('CachedTokenProvider', () => { + let clock: sinon.SinonFakeTimers; + + beforeEach(() => { + clock = sinon.useFakeTimers(Date.now()); + }); + + afterEach(() => { + clock.restore(); + }); + + describe('getToken', () => { + it('should cache tokens and return the same token on subsequent calls', async () => { + const baseProvider = new MockTokenProvider(3600000); // 1 hour expiry + const cachedProvider = new CachedTokenProvider(baseProvider); + + const token1 = await cachedProvider.getToken(); + const token2 = await cachedProvider.getToken(); + const token3 = await cachedProvider.getToken(); + + expect(token1.accessToken).to.equal(token2.accessToken); + expect(token2.accessToken).to.equal(token3.accessToken); + expect(baseProvider.callCount).to.equal(1); // Only called once + }); + + it('should refresh token when it approaches expiry', async () => { + const expiresInMs = 10 * 60 * 1000; // 10 minutes + const baseProvider = new MockTokenProvider(expiresInMs); + const cachedProvider = new CachedTokenProvider(baseProvider, { + refreshThresholdMs: 5 * 60 * 1000, // 5 minutes threshold + }); + + const token1 = await cachedProvider.getToken(); + expect(baseProvider.callCount).to.equal(1); + + // Advance time to 6 minutes from now (within refresh threshold) + clock.tick(6 * 60 * 1000); + + const token2 = await cachedProvider.getToken(); + expect(baseProvider.callCount).to.equal(2); // Should have refreshed + expect(token1.accessToken).to.not.equal(token2.accessToken); + }); + + it('should not refresh token when not within threshold', async () => { + const expiresInMs = 60 * 60 * 1000; // 1 hour + const baseProvider = new MockTokenProvider(expiresInMs); + const cachedProvider = new CachedTokenProvider(baseProvider, { + refreshThresholdMs: 5 * 60 * 1000, // 5 minutes threshold + }); + + await cachedProvider.getToken(); + expect(baseProvider.callCount).to.equal(1); + + // Advance time by 10 minutes (still 50 minutes until expiry) + clock.tick(10 * 60 * 1000); + + await cachedProvider.getToken(); + expect(baseProvider.callCount).to.equal(1); // Should still use cached + }); + + it('should handle tokens without expiration', async () => { + const baseProvider: ITokenProvider = { + async getToken() { + return new Token('no-expiry-token'); + }, + getName() { + return 'NoExpiryProvider'; + }, + }; + const getTokenSpy = sinon.spy(baseProvider, 'getToken'); + const cachedProvider = new CachedTokenProvider(baseProvider); + + await cachedProvider.getToken(); + await cachedProvider.getToken(); + await cachedProvider.getToken(); + + expect(getTokenSpy.callCount).to.equal(1); // Should cache indefinitely + }); + + it('should handle concurrent getToken calls', async () => { + let resolvePromise: (token: Token) => void; + const slowProvider: ITokenProvider = { + getToken() { + return new Promise((resolve) => { + resolvePromise = resolve; + }); + }, + getName() { + return 'SlowProvider'; + }, + }; + const getTokenSpy = sinon.spy(slowProvider, 'getToken'); + const cachedProvider = new CachedTokenProvider(slowProvider); + + // Start multiple concurrent requests + const promise1 = cachedProvider.getToken(); + const promise2 = cachedProvider.getToken(); + const promise3 = cachedProvider.getToken(); + + // Resolve the single underlying request + resolvePromise!(new Token('concurrent-token')); + + const [token1, token2, token3] = await Promise.all([promise1, promise2, promise3]); + + expect(token1.accessToken).to.equal('concurrent-token'); + expect(token2.accessToken).to.equal('concurrent-token'); + expect(token3.accessToken).to.equal('concurrent-token'); + expect(getTokenSpy.callCount).to.equal(1); // Only one underlying call + }); + }); + + describe('clearCache', () => { + it('should force a refresh on the next getToken call', async () => { + const baseProvider = new MockTokenProvider(3600000); + const cachedProvider = new CachedTokenProvider(baseProvider); + + const token1 = await cachedProvider.getToken(); + expect(baseProvider.callCount).to.equal(1); + + cachedProvider.clearCache(); + + const token2 = await cachedProvider.getToken(); + expect(baseProvider.callCount).to.equal(2); + expect(token1.accessToken).to.not.equal(token2.accessToken); + }); + }); + + describe('getName', () => { + it('should return wrapped name', () => { + const baseProvider = new MockTokenProvider(); + const cachedProvider = new CachedTokenProvider(baseProvider); + + expect(cachedProvider.getName()).to.equal('cached[MockTokenProvider]'); + }); + }); +}); diff --git a/tests/unit/connection/auth/tokenProvider/ExternalTokenProvider.test.ts b/tests/unit/connection/auth/tokenProvider/ExternalTokenProvider.test.ts new file mode 100644 index 00000000..6695040d --- /dev/null +++ b/tests/unit/connection/auth/tokenProvider/ExternalTokenProvider.test.ts @@ -0,0 +1,108 @@ +import { expect } from 'chai'; +import sinon from 'sinon'; +import ExternalTokenProvider from '../../../../../lib/connection/auth/tokenProvider/ExternalTokenProvider'; + +function createJWT(payload: Record): string { + const header = Buffer.from(JSON.stringify({ alg: 'HS256', typ: 'JWT' })).toString('base64'); + const body = Buffer.from(JSON.stringify(payload)).toString('base64'); + return `${header}.${body}.signature`; +} + +describe('ExternalTokenProvider', () => { + describe('constructor', () => { + it('should create provider with callback', async () => { + const callback = sinon.stub().resolves('my-token'); + const provider = new ExternalTokenProvider(callback); + + await provider.getToken(); + + expect(callback.calledOnce).to.be.true; + }); + + it('should use default name', () => { + const provider = new ExternalTokenProvider(async () => 'token'); + expect(provider.getName()).to.equal('ExternalTokenProvider'); + }); + + it('should use custom name', () => { + const provider = new ExternalTokenProvider(async () => 'token', { name: 'MyCustomProvider' }); + expect(provider.getName()).to.equal('MyCustomProvider'); + }); + }); + + describe('getToken', () => { + it('should call callback and return token', async () => { + const callback = sinon.stub().resolves('my-access-token'); + const provider = new ExternalTokenProvider(callback); + + const token = await provider.getToken(); + + expect(token.accessToken).to.equal('my-access-token'); + expect(token.tokenType).to.equal('Bearer'); + }); + + it('should extract expiration from JWT by default', async () => { + const exp = Math.floor(Date.now() / 1000) + 3600; + const jwt = createJWT({ exp, iss: 'test-issuer' }); + const callback = sinon.stub().resolves(jwt); + const provider = new ExternalTokenProvider(callback); + + const token = await provider.getToken(); + + expect(token.accessToken).to.equal(jwt); + expect(token.expiresAt).to.be.instanceOf(Date); + expect(Math.floor(token.expiresAt!.getTime() / 1000)).to.equal(exp); + }); + + it('should not parse JWT when parseJWT is false', async () => { + const jwt = createJWT({ exp: Math.floor(Date.now() / 1000) + 3600 }); + const callback = sinon.stub().resolves(jwt); + const provider = new ExternalTokenProvider(callback, { parseJWT: false }); + + const token = await provider.getToken(); + + expect(token.accessToken).to.equal(jwt); + expect(token.expiresAt).to.be.undefined; + }); + + it('should call callback on each getToken call', async () => { + let callCount = 0; + const callback = async () => { + callCount += 1; + return `token-${callCount}`; + }; + const provider = new ExternalTokenProvider(callback); + + const token1 = await provider.getToken(); + const token2 = await provider.getToken(); + + expect(token1.accessToken).to.equal('token-1'); + expect(token2.accessToken).to.equal('token-2'); + }); + + it('should propagate errors from callback', async () => { + const error = new Error('Failed to get token'); + const callback = sinon.stub().rejects(error); + const provider = new ExternalTokenProvider(callback); + + try { + await provider.getToken(); + expect.fail('Should have thrown an error'); + } catch (e) { + expect(e).to.equal(error); + } + }); + }); + + describe('getName', () => { + it('should return default name', () => { + const provider = new ExternalTokenProvider(async () => 'token'); + expect(provider.getName()).to.equal('ExternalTokenProvider'); + }); + + it('should return custom name', () => { + const provider = new ExternalTokenProvider(async () => 'token', { name: 'VaultTokenProvider' }); + expect(provider.getName()).to.equal('VaultTokenProvider'); + }); + }); +}); diff --git a/tests/unit/connection/auth/tokenProvider/FederationProvider.test.ts b/tests/unit/connection/auth/tokenProvider/FederationProvider.test.ts new file mode 100644 index 00000000..4a7c5465 --- /dev/null +++ b/tests/unit/connection/auth/tokenProvider/FederationProvider.test.ts @@ -0,0 +1,79 @@ +import { expect } from 'chai'; +import sinon from 'sinon'; +import FederationProvider from '../../../../../lib/connection/auth/tokenProvider/FederationProvider'; +import ITokenProvider from '../../../../../lib/connection/auth/tokenProvider/ITokenProvider'; +import Token from '../../../../../lib/connection/auth/tokenProvider/Token'; + +function createJWT(payload: Record): string { + const header = Buffer.from(JSON.stringify({ alg: 'HS256', typ: 'JWT' })).toString('base64'); + const body = Buffer.from(JSON.stringify(payload)).toString('base64'); + return `${header}.${body}.signature`; +} + +class MockTokenProvider implements ITokenProvider { + public tokenToReturn: Token; + + constructor(accessToken: string) { + this.tokenToReturn = new Token(accessToken); + } + + async getToken(): Promise { + return this.tokenToReturn; + } + + getName(): string { + return 'MockTokenProvider'; + } +} + +describe('FederationProvider', () => { + describe('getToken', () => { + it('should pass through token if issuer matches Databricks host', async () => { + const jwt = createJWT({ iss: 'https://my-workspace.cloud.databricks.com' }); + const baseProvider = new MockTokenProvider(jwt); + const federationProvider = new FederationProvider(baseProvider, 'my-workspace.cloud.databricks.com'); + + const token = await federationProvider.getToken(); + + expect(token.accessToken).to.equal(jwt); + }); + + it('should pass through non-JWT tokens', async () => { + const baseProvider = new MockTokenProvider('not-a-jwt-token'); + const federationProvider = new FederationProvider(baseProvider, 'my-workspace.cloud.databricks.com'); + + const token = await federationProvider.getToken(); + + expect(token.accessToken).to.equal('not-a-jwt-token'); + }); + + it('should pass through token when issuer matches (case insensitive)', async () => { + const jwt = createJWT({ iss: 'https://MY-WORKSPACE.CLOUD.DATABRICKS.COM' }); + const baseProvider = new MockTokenProvider(jwt); + const federationProvider = new FederationProvider(baseProvider, 'my-workspace.cloud.databricks.com'); + + const token = await federationProvider.getToken(); + + expect(token.accessToken).to.equal(jwt); + }); + + it('should pass through token when issuer matches (ignoring port)', async () => { + const jwt = createJWT({ iss: 'https://my-workspace.cloud.databricks.com:443' }); + const baseProvider = new MockTokenProvider(jwt); + const federationProvider = new FederationProvider(baseProvider, 'my-workspace.cloud.databricks.com'); + + const token = await federationProvider.getToken(); + + expect(token.accessToken).to.equal(jwt); + }); + }); + + describe('getName', () => { + it('should return wrapped name', () => { + const baseProvider = new MockTokenProvider('token'); + const federationProvider = new FederationProvider(baseProvider, 'host.com'); + + expect(federationProvider.getName()).to.equal('federated[MockTokenProvider]'); + }); + }); +}); diff --git a/tests/unit/connection/auth/tokenProvider/StaticTokenProvider.test.ts b/tests/unit/connection/auth/tokenProvider/StaticTokenProvider.test.ts new file mode 100644 index 00000000..976bf84e --- /dev/null +++ b/tests/unit/connection/auth/tokenProvider/StaticTokenProvider.test.ts @@ -0,0 +1,85 @@ +import { expect } from 'chai'; +import StaticTokenProvider from '../../../../../lib/connection/auth/tokenProvider/StaticTokenProvider'; + +function createJWT(payload: Record): string { + const header = Buffer.from(JSON.stringify({ alg: 'HS256', typ: 'JWT' })).toString('base64'); + const body = Buffer.from(JSON.stringify(payload)).toString('base64'); + return `${header}.${body}.signature`; +} + +describe('StaticTokenProvider', () => { + describe('constructor', () => { + it('should create provider with access token only', async () => { + const provider = new StaticTokenProvider('my-access-token'); + const token = await provider.getToken(); + + expect(token.accessToken).to.equal('my-access-token'); + expect(token.tokenType).to.equal('Bearer'); + }); + + it('should create provider with custom options', async () => { + const expiresAt = new Date('2025-01-01T00:00:00Z'); + const provider = new StaticTokenProvider('my-access-token', { + tokenType: 'CustomType', + expiresAt, + refreshToken: 'refresh-token', + scopes: ['read', 'write'], + }); + const token = await provider.getToken(); + + expect(token.accessToken).to.equal('my-access-token'); + expect(token.tokenType).to.equal('CustomType'); + expect(token.expiresAt).to.deep.equal(expiresAt); + expect(token.refreshToken).to.equal('refresh-token'); + expect(token.scopes).to.deep.equal(['read', 'write']); + }); + }); + + describe('fromJWT', () => { + it('should create provider from JWT and extract expiration', async () => { + const exp = Math.floor(Date.now() / 1000) + 3600; + const jwt = createJWT({ exp, iss: 'test-issuer' }); + + const provider = StaticTokenProvider.fromJWT(jwt); + const token = await provider.getToken(); + + expect(token.accessToken).to.equal(jwt); + expect(token.expiresAt).to.be.instanceOf(Date); + expect(Math.floor(token.expiresAt!.getTime() / 1000)).to.equal(exp); + }); + + it('should create provider from JWT with custom options', async () => { + const jwt = createJWT({ exp: Math.floor(Date.now() / 1000) + 3600 }); + + const provider = StaticTokenProvider.fromJWT(jwt, { + tokenType: 'CustomType', + refreshToken: 'refresh', + scopes: ['sql'], + }); + const token = await provider.getToken(); + + expect(token.tokenType).to.equal('CustomType'); + expect(token.refreshToken).to.equal('refresh'); + expect(token.scopes).to.deep.equal(['sql']); + }); + }); + + describe('getToken', () => { + it('should always return the same token', async () => { + const provider = new StaticTokenProvider('my-token'); + + const token1 = await provider.getToken(); + const token2 = await provider.getToken(); + + expect(token1).to.equal(token2); + expect(token1.accessToken).to.equal('my-token'); + }); + }); + + describe('getName', () => { + it('should return provider name', () => { + const provider = new StaticTokenProvider('my-token'); + expect(provider.getName()).to.equal('StaticTokenProvider'); + }); + }); +}); diff --git a/tests/unit/connection/auth/tokenProvider/Token.test.ts b/tests/unit/connection/auth/tokenProvider/Token.test.ts new file mode 100644 index 00000000..febaf712 --- /dev/null +++ b/tests/unit/connection/auth/tokenProvider/Token.test.ts @@ -0,0 +1,162 @@ +import { expect } from 'chai'; +import Token from '../../../../../lib/connection/auth/tokenProvider/Token'; + +function createJWT(payload: Record): string { + const header = Buffer.from(JSON.stringify({ alg: 'HS256', typ: 'JWT' })).toString('base64'); + const body = Buffer.from(JSON.stringify(payload)).toString('base64'); + return `${header}.${body}.signature`; +} + +describe('Token', () => { + describe('constructor', () => { + it('should create token with access token only', () => { + const token = new Token('test-access-token'); + expect(token.accessToken).to.equal('test-access-token'); + expect(token.tokenType).to.equal('Bearer'); + expect(token.expiresAt).to.be.undefined; + expect(token.refreshToken).to.be.undefined; + expect(token.scopes).to.be.undefined; + }); + + it('should create token with all options', () => { + const expiresAt = new Date('2025-01-01T00:00:00Z'); + const token = new Token('test-access-token', { + tokenType: 'CustomType', + expiresAt, + refreshToken: 'refresh-token', + scopes: ['read', 'write'], + }); + expect(token.accessToken).to.equal('test-access-token'); + expect(token.tokenType).to.equal('CustomType'); + expect(token.expiresAt).to.deep.equal(expiresAt); + expect(token.refreshToken).to.equal('refresh-token'); + expect(token.scopes).to.deep.equal(['read', 'write']); + }); + }); + + describe('isExpired', () => { + it('should return false when expiration is not set', () => { + const token = new Token('test-token'); + expect(token.isExpired()).to.be.false; + }); + + it('should return true when token is expired', () => { + const expiresAt = new Date(Date.now() - 60000); // 1 minute ago + const token = new Token('test-token', { expiresAt }); + expect(token.isExpired()).to.be.true; + }); + + it('should return false when token is not expired', () => { + const expiresAt = new Date(Date.now() + 300000); // 5 minutes from now + const token = new Token('test-token', { expiresAt }); + expect(token.isExpired()).to.be.false; + }); + + it('should return true when within 30 second safety buffer', () => { + const expiresAt = new Date(Date.now() + 20000); // 20 seconds from now + const token = new Token('test-token', { expiresAt }); + expect(token.isExpired()).to.be.true; + }); + }); + + describe('setAuthHeader', () => { + it('should set Authorization header with default Bearer type', () => { + const token = new Token('my-token'); + const headers = token.setAuthHeader({}); + expect(headers).to.deep.equal({ Authorization: 'Bearer my-token' }); + }); + + it('should set Authorization header with custom type', () => { + const token = new Token('my-token', { tokenType: 'Basic' }); + const headers = token.setAuthHeader({}); + expect(headers).to.deep.equal({ Authorization: 'Basic my-token' }); + }); + + it('should preserve existing headers', () => { + const token = new Token('my-token'); + const headers = token.setAuthHeader({ 'Content-Type': 'application/json' }); + expect(headers).to.deep.equal({ + 'Content-Type': 'application/json', + Authorization: 'Bearer my-token', + }); + }); + }); + + describe('fromJWT', () => { + it('should extract expiration from JWT payload', () => { + const exp = Math.floor(Date.now() / 1000) + 3600; // 1 hour from now + const jwt = createJWT({ exp, iss: 'test-issuer' }); + const token = Token.fromJWT(jwt); + + expect(token.accessToken).to.equal(jwt); + expect(token.tokenType).to.equal('Bearer'); + expect(token.expiresAt).to.be.instanceOf(Date); + expect(Math.floor(token.expiresAt!.getTime() / 1000)).to.equal(exp); + }); + + it('should handle JWT without expiration', () => { + const jwt = createJWT({ iss: 'test-issuer' }); + const token = Token.fromJWT(jwt); + + expect(token.accessToken).to.equal(jwt); + expect(token.expiresAt).to.be.undefined; + }); + + it('should handle malformed JWT gracefully', () => { + const token = Token.fromJWT('not-a-valid-jwt'); + expect(token.accessToken).to.equal('not-a-valid-jwt'); + expect(token.expiresAt).to.be.undefined; + }); + + it('should handle JWT with invalid base64 payload', () => { + const token = Token.fromJWT('header.!!!invalid-base64!!!.signature'); + expect(token.accessToken).to.equal('header.!!!invalid-base64!!!.signature'); + expect(token.expiresAt).to.be.undefined; + }); + + it('should apply custom options', () => { + const jwt = createJWT({ exp: Math.floor(Date.now() / 1000) + 3600 }); + const token = Token.fromJWT(jwt, { + tokenType: 'CustomType', + refreshToken: 'refresh', + scopes: ['sql'], + }); + + expect(token.tokenType).to.equal('CustomType'); + expect(token.refreshToken).to.equal('refresh'); + expect(token.scopes).to.deep.equal(['sql']); + }); + }); + + describe('toJSON', () => { + it('should serialize token to JSON', () => { + const expiresAt = new Date('2025-01-01T00:00:00Z'); + const token = new Token('test-token', { + tokenType: 'Bearer', + expiresAt, + refreshToken: 'refresh', + scopes: ['read'], + }); + + const json = token.toJSON(); + expect(json).to.deep.equal({ + accessToken: 'test-token', + tokenType: 'Bearer', + expiresAt: '2025-01-01T00:00:00.000Z', + refreshToken: 'refresh', + scopes: ['read'], + }); + }); + + it('should handle undefined optional fields', () => { + const token = new Token('test-token'); + const json = token.toJSON(); + + expect(json.accessToken).to.equal('test-token'); + expect(json.tokenType).to.equal('Bearer'); + expect(json.expiresAt).to.be.undefined; + expect(json.refreshToken).to.be.undefined; + expect(json.scopes).to.be.undefined; + }); + }); +}); diff --git a/tests/unit/connection/auth/tokenProvider/TokenProviderAuthenticator.test.ts b/tests/unit/connection/auth/tokenProvider/TokenProviderAuthenticator.test.ts new file mode 100644 index 00000000..a5a3963e --- /dev/null +++ b/tests/unit/connection/auth/tokenProvider/TokenProviderAuthenticator.test.ts @@ -0,0 +1,108 @@ +import { expect } from 'chai'; +import sinon from 'sinon'; +import TokenProviderAuthenticator from '../../../../../lib/connection/auth/tokenProvider/TokenProviderAuthenticator'; +import ITokenProvider from '../../../../../lib/connection/auth/tokenProvider/ITokenProvider'; +import Token from '../../../../../lib/connection/auth/tokenProvider/Token'; +import ClientContextStub from '../../../.stubs/ClientContextStub'; + +class MockTokenProvider implements ITokenProvider { + private token: Token; + + private name: string; + + constructor(accessToken: string, name: string = 'MockTokenProvider') { + this.token = new Token(accessToken); + this.name = name; + } + + async getToken(): Promise { + return this.token; + } + + getName(): string { + return this.name; + } + + setToken(token: Token): void { + this.token = token; + } +} + +describe('TokenProviderAuthenticator', () => { + let context: ClientContextStub; + + beforeEach(() => { + context = new ClientContextStub(); + }); + + describe('authenticate', () => { + it('should return headers with Authorization', async () => { + const provider = new MockTokenProvider('my-access-token'); + const authenticator = new TokenProviderAuthenticator(provider, context); + + const headers = await authenticator.authenticate(); + + expect(headers).to.deep.equal({ + Authorization: 'Bearer my-access-token', + }); + }); + + it('should include additional headers', async () => { + const provider = new MockTokenProvider('my-access-token'); + const authenticator = new TokenProviderAuthenticator(provider, context, { + 'Content-Type': 'application/json', + 'X-Custom-Header': 'custom-value', + }); + + const headers = await authenticator.authenticate(); + + expect(headers).to.deep.equal({ + 'Content-Type': 'application/json', + 'X-Custom-Header': 'custom-value', + Authorization: 'Bearer my-access-token', + }); + }); + + it('should use token type from token', async () => { + const provider = new MockTokenProvider('my-access-token'); + provider.setToken(new Token('my-token', { tokenType: 'Basic' })); + const authenticator = new TokenProviderAuthenticator(provider, context); + + const headers = await authenticator.authenticate(); + + expect(headers).to.deep.equal({ + Authorization: 'Basic my-token', + }); + }); + + it('should call provider getToken', async () => { + const provider = new MockTokenProvider('my-access-token'); + const getTokenSpy = sinon.spy(provider, 'getToken'); + const authenticator = new TokenProviderAuthenticator(provider, context); + + await authenticator.authenticate(); + + expect(getTokenSpy.calledOnce).to.be.true; + }); + + it('should propagate errors from provider', async () => { + const error = new Error('Failed to get token'); + const provider: ITokenProvider = { + async getToken() { + throw error; + }, + getName() { + return 'ErrorProvider'; + }, + }; + const authenticator = new TokenProviderAuthenticator(provider, context); + + try { + await authenticator.authenticate(); + expect.fail('Should have thrown an error'); + } catch (e) { + expect(e).to.equal(error); + } + }); + }); +}); diff --git a/tests/unit/connection/auth/tokenProvider/utils.test.ts b/tests/unit/connection/auth/tokenProvider/utils.test.ts new file mode 100644 index 00000000..80a91f85 --- /dev/null +++ b/tests/unit/connection/auth/tokenProvider/utils.test.ts @@ -0,0 +1,90 @@ +import { expect } from 'chai'; +import { decodeJWT, getJWTIssuer, isSameHost } from '../../../../../lib/connection/auth/tokenProvider/utils'; + +function createJWT(payload: Record): string { + const header = Buffer.from(JSON.stringify({ alg: 'HS256', typ: 'JWT' })).toString('base64'); + const body = Buffer.from(JSON.stringify(payload)).toString('base64'); + return `${header}.${body}.signature`; +} + +describe('Token Provider Utils', () => { + describe('decodeJWT', () => { + it('should decode valid JWT payload', () => { + const payload = { iss: 'test-issuer', sub: 'user123', exp: 1234567890 }; + const jwt = createJWT(payload); + + const decoded = decodeJWT(jwt); + + expect(decoded).to.deep.equal(payload); + }); + + it('should return null for malformed JWT', () => { + expect(decodeJWT('not-a-jwt')).to.be.null; + expect(decodeJWT('')).to.be.null; + }); + + it('should return null for JWT with invalid base64 payload', () => { + expect(decodeJWT('header.!!!invalid!!!.signature')).to.be.null; + }); + + it('should return null for JWT with non-JSON payload', () => { + const header = Buffer.from('{}').toString('base64'); + const body = Buffer.from('not json').toString('base64'); + expect(decodeJWT(`${header}.${body}.sig`)).to.be.null; + }); + }); + + describe('getJWTIssuer', () => { + it('should extract issuer from JWT', () => { + const jwt = createJWT({ iss: 'https://my-issuer.com', sub: 'user' }); + expect(getJWTIssuer(jwt)).to.equal('https://my-issuer.com'); + }); + + it('should return null if no issuer claim', () => { + const jwt = createJWT({ sub: 'user' }); + expect(getJWTIssuer(jwt)).to.be.null; + }); + + it('should return null if issuer is not a string', () => { + const jwt = createJWT({ iss: 123 }); + expect(getJWTIssuer(jwt)).to.be.null; + }); + + it('should return null for invalid JWT', () => { + expect(getJWTIssuer('not-a-jwt')).to.be.null; + }); + }); + + describe('isSameHost', () => { + it('should match identical hosts', () => { + expect(isSameHost('example.com', 'example.com')).to.be.true; + }); + + it('should match hosts with different protocols', () => { + expect(isSameHost('https://example.com', 'http://example.com')).to.be.true; + }); + + it('should match hosts ignoring ports', () => { + expect(isSameHost('example.com', 'example.com:443')).to.be.true; + expect(isSameHost('https://example.com:443', 'example.com')).to.be.true; + }); + + it('should match hosts case-insensitively', () => { + expect(isSameHost('Example.COM', 'example.com')).to.be.true; + }); + + it('should not match different hosts', () => { + expect(isSameHost('example.com', 'other.com')).to.be.false; + expect(isSameHost('sub.example.com', 'example.com')).to.be.false; + }); + + it('should handle full URLs', () => { + expect(isSameHost('https://my-workspace.cloud.databricks.com/path', 'my-workspace.cloud.databricks.com')).to.be + .true; + }); + + it('should return false for invalid inputs', () => { + expect(isSameHost('', '')).to.be.false; + }); + }); +}); diff --git a/tsconfig.build.json b/tsconfig.build.json index 7b375312..9aa952a0 100644 --- a/tsconfig.build.json +++ b/tsconfig.build.json @@ -4,5 +4,5 @@ "outDir": "./dist/" /* Redirect output structure to the directory. */, "rootDir": "./lib/" /* Specify the root directory of input files. Use to control the output directory structure with --outDir. */ }, - "exclude": ["./tests/**/*", "./dist/**/*"] + "exclude": ["./tests/**/*", "./dist/**/*", "./examples/**/*"] }