diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 340efd0c..b35e0d41 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -298,9 +298,9 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I /** * Extract the numeric workspace ID for telemetry. * - * The only reliable carrier in the connection params today is the `?o=N` - * query parameter on `httpPath` — Databricks SQL warehouses are typically - * connected to via paths like `/sql/1.0/warehouses/?o=12345678901234`. + * Two URL shapes carry the workspace ID today: + * - Warehouse, query form: `/sql/1.0/warehouses/?o=` + * - All-purpose cluster, path form: `sql/protocolv1/o//` * * Host-based extraction was tried previously but produced confidently-wrong * values: @@ -313,21 +313,84 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I * Returns `undefined` when no workspace ID can be derived. Server-side * attribution is better off seeing a missing field than a wrong value. */ - private extractWorkspaceId(): string | undefined { - const { httpPath } = this; + private static extractWorkspaceId(httpPath: string | undefined): string | undefined { if (!httpPath) { return undefined; } const queryIdx = httpPath.indexOf('?'); - if (queryIdx < 0) { - return undefined; + // Warehouse form: `?o=` in the query string. + if (queryIdx >= 0) { + const query = httpPath.slice(queryIdx + 1); + // Match `o=` as the first param, an inner `&o=`, etc. + // Workspace IDs are decimal integers; reject anything else so a stray + // `o=tenant_42` doesn't ship as a workspace ID. + const queryMatch = query.match(/(?:^|&)o=(\d+)(?:&|$)/); + if (queryMatch) { + return queryMatch[1]; + } + } + // All-purpose cluster form: `/o//` as a path segment. + const pathOnly = queryIdx >= 0 ? httpPath.slice(0, queryIdx) : httpPath; + const pathMatch = pathOnly.match(/(?:^|\/)o\/(\d+)(?:\/|$)/); + return pathMatch ? pathMatch[1] : undefined; + } + + // Detects an `o=` or `/o/` where `` is present but + // non-numeric, so the caller can warn instead of silently dropping a + // malformed workspace param. + private static hasMalformedOrgParam(httpPath: string | undefined): boolean { + if (!httpPath) { + return false; + } + const queryIdx = httpPath.indexOf('?'); + if (queryIdx >= 0) { + const query = httpPath.slice(queryIdx + 1); + const hasOrg = /(?:^|&)o=/.test(query); + const hasNumericOrg = /(?:^|&)o=\d+(?:&|$)/.test(query); + if (hasOrg && !hasNumericOrg) { + return true; + } } - const query = httpPath.slice(queryIdx + 1); - // Match `o=` as the first param, an inner `&o=`, etc. - // Workspace IDs are decimal integers; reject anything else so a stray - // `o=tenant_42` doesn't ship as a workspace ID. - const match = query.match(/(?:^|&)o=(\d+)(?:&|$)/); - return match ? match[1] : undefined; + const pathOnly = queryIdx >= 0 ? httpPath.slice(0, queryIdx) : httpPath; + const hasPathOrg = /(?:^|\/)o\/[^/]+/.test(pathOnly); + const hasNumericPathOrg = /(?:^|\/)o\/\d+(?:\/|$)/.test(pathOnly); + return hasPathOrg && !hasNumericPathOrg; + } + + /** + * Build the customHeaders map applied to telemetry POSTs and feature-flag + * GETs (SPOG / Single Panel of Glass support). When `httpPath` carries a + * workspace ID — either as a `?o=` query (warehouse) or a + * `/o//` path segment (all-purpose cluster) — endpoints + * that don't include the workspace in their URL path need it conveyed via + * the `x-databricks-org-id` header instead. A user-supplied value in + * `userHeaders` (case-insensitively keyed) wins over the parsed value. + * + * `httpPath` is passed explicitly (rather than read off `this.httpPath`) so + * the SPOG-routing dependency is visible in the signature — a future + * refactor that reorders connect() can't silently break injection. + */ + private buildCustomHeaders( + httpPath: string | undefined, + userHeaders: Record | undefined, + ): Record | undefined { + const merged: Record = { ...(userHeaders ?? {}) }; + const hasOrgIdAlready = Object.keys(merged).some((k) => k.toLowerCase() === 'x-databricks-org-id'); + if (hasOrgIdAlready) { + this.logger.log(LogLevel.debug, 'SPOG: x-databricks-org-id supplied by caller; not extracting from httpPath'); + } else { + const orgId = DBSQLClient.extractWorkspaceId(httpPath); + if (orgId) { + merged['x-databricks-org-id'] = orgId; + this.logger.log(LogLevel.debug, `SPOG: injecting x-databricks-org-id=${orgId} (extracted from httpPath)`); + } else if (DBSQLClient.hasMalformedOrgParam(httpPath)) { + this.logger.log( + LogLevel.warn, + 'SPOG: httpPath contains non-numeric workspace ID; x-databricks-org-id not injected', + ); + } + } + return Object.keys(merged).length > 0 ? merged : undefined; } /** @@ -561,6 +624,11 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I this.config.userAgentEntry = options.userAgentEntry; } + // SPOG: parse `?o=` out of httpPath and stash it as + // `x-databricks-org-id` for the telemetry + feature-flag clients, which + // hit endpoints that don't carry the workspace in their URL path. + this.config.customHeaders = this.buildCustomHeaders(options.path, options.customHeaders); + this.authProvider = this.createAuthProvider(options, authProvider); this.connectionProvider = this.createConnectionProvider(options); @@ -677,7 +745,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I safeEmit(this, (emitter) => { if (!this.host) return; const latencyMs = Date.now() - startTime; - const workspaceId = this.extractWorkspaceId(); + const workspaceId = DBSQLClient.extractWorkspaceId(this.httpPath); const driverConfig = this.driverConfigShipped ? undefined : this.buildDriverConfiguration(); if (driverConfig) { this.driverConfigShipped = true; diff --git a/lib/contracts/IClientContext.ts b/lib/contracts/IClientContext.ts index 955b9c52..43a47745 100644 --- a/lib/contracts/IClientContext.ts +++ b/lib/contracts/IClientContext.ts @@ -52,6 +52,17 @@ export interface ClientConfig { */ telemetryFlushOnExit?: boolean; userAgentEntry?: string; + + /** + * Extra HTTP headers attached to driver-owned out-of-band requests + * (telemetry, feature flags). Populated by `DBSQLClient.connect()` from + * `ConnectionOptions.customHeaders` plus an `x-databricks-org-id` header + * derived from the `?o=` query parameter on `httpPath` when present, to + * support SPOG (Single Panel of Glass) account-level routing on endpoints + * that don't carry `?o=` in their URL path. NOT applied to Thrift or + * OAuth/OIDC requests. + */ + customHeaders?: Record; } export default interface IClientContext { diff --git a/lib/contracts/IDBSQLClient.ts b/lib/contracts/IDBSQLClient.ts index 08592219..b75a8075 100644 --- a/lib/contracts/IDBSQLClient.ts +++ b/lib/contracts/IDBSQLClient.ts @@ -55,6 +55,17 @@ export type ConnectionOptions = { proxy?: ProxyOptions; enableMetricViewMetadata?: boolean; + /** + * Extra HTTP headers attached to driver-owned out-of-band requests + * (telemetry POSTs and feature-flag GETs). Not applied to the primary + * Thrift transport or to OAuth/OIDC token requests. + * + * When `path` contains `?o=` (SPOG account-level routing), + * the driver automatically injects an `x-databricks-org-id` header unless + * one is already present in this map. + */ + customHeaders?: Record; + /** * Whether the driver emits telemetry events (connection / statement / * cloud-fetch / error). Defaults to `true`. diff --git a/lib/telemetry/DatabricksTelemetryExporter.ts b/lib/telemetry/DatabricksTelemetryExporter.ts index 65fe8b64..bfa0f1b1 100644 --- a/lib/telemetry/DatabricksTelemetryExporter.ts +++ b/lib/telemetry/DatabricksTelemetryExporter.ts @@ -249,6 +249,7 @@ export default class DatabricksTelemetryExporter { let headers: Record = { 'Content-Type': 'application/json', 'User-Agent': userAgent, + ...(config.customHeaders ?? {}), }; if (authenticatedExport) { diff --git a/lib/telemetry/FeatureFlagCache.ts b/lib/telemetry/FeatureFlagCache.ts index 47fc9e8d..e060bd93 100644 --- a/lib/telemetry/FeatureFlagCache.ts +++ b/lib/telemetry/FeatureFlagCache.ts @@ -140,6 +140,7 @@ export default class FeatureFlagCache { const headers: Record = { 'Content-Type': 'application/json', 'User-Agent': this.userAgent, + ...(this.context.getConfig().customHeaders ?? {}), ...(await this.getAuthHeaders()), }; diff --git a/lib/telemetry/TelemetryClientProvider.ts b/lib/telemetry/TelemetryClientProvider.ts index 4e7c09c0..eb873cde 100644 --- a/lib/telemetry/TelemetryClientProvider.ts +++ b/lib/telemetry/TelemetryClientProvider.ts @@ -133,7 +133,11 @@ class TelemetryClientProvider { return; } - holder.client.unregisterContext(context); + // Skip unregister on the last release so close()'s final flush can still + // resolve auth/connection providers from the FIFO snapshot. + if (holder.refCount > 1) { + holder.client.unregisterContext(context); + } holder.refCount -= 1; logger.log(LogLevel.debug, `TelemetryClient reference count for ${host}: ${holder.refCount}`); diff --git a/tests/unit/DBSQLClient.test.ts b/tests/unit/DBSQLClient.test.ts index 184e25a3..1b04fe7e 100644 --- a/tests/unit/DBSQLClient.test.ts +++ b/tests/unit/DBSQLClient.test.ts @@ -103,6 +103,18 @@ describe('DBSQLClient.connect', () => { logSpy.restore(); }); + + it('populates config.customHeaders with org-id parsed from ?o= (SPOG)', async () => { + const client = new DBSQLClient(); + await client.connect({ ...connectOptions, path: '/sql/1.0/warehouses/abc?o=12345678901234' }); + expect(client.getConfig().customHeaders).to.deep.equal({ 'x-databricks-org-id': '12345678901234' }); + }); + + it('leaves config.customHeaders undefined when path has no ?o= and none supplied', async () => { + const client = new DBSQLClient(); + await client.connect({ ...connectOptions, path: '/sql/1.0/warehouses/abc' }); + expect(client.getConfig().customHeaders).to.be.undefined; + }); }); describe('DBSQLClient.openSession', () => { @@ -755,33 +767,140 @@ describe('DBSQLClient telemetry paths', () => { describe('extractWorkspaceId', () => { it('returns the numeric o= param from httpPath', () => { - const client = new DBSQLClient(); - (client as any).httpPath = '/sql/1.0/warehouses/abc?o=12345678901234'; - const id = (client as any).extractWorkspaceId(); + const id = (DBSQLClient as any).extractWorkspaceId('/sql/1.0/warehouses/abc?o=12345678901234'); expect(id).to.equal('12345678901234'); }); it('returns undefined when no query string', () => { - const client = new DBSQLClient(); - (client as any).httpPath = '/sql/1.0/warehouses/abc'; - expect((client as any).extractWorkspaceId()).to.be.undefined; + expect((DBSQLClient as any).extractWorkspaceId('/sql/1.0/warehouses/abc')).to.be.undefined; }); it('returns undefined when o= is not numeric', () => { - const client = new DBSQLClient(); - (client as any).httpPath = '/sql/1.0/warehouses/abc?o=tenant_xyz'; - expect((client as any).extractWorkspaceId()).to.be.undefined; + expect((DBSQLClient as any).extractWorkspaceId('/sql/1.0/warehouses/abc?o=tenant_xyz')).to.be.undefined; }); it('handles o= as a non-first param', () => { - const client = new DBSQLClient(); - (client as any).httpPath = '/sql/1.0/warehouses/abc?foo=bar&o=42&baz=qux'; - expect((client as any).extractWorkspaceId()).to.equal('42'); + expect((DBSQLClient as any).extractWorkspaceId('/sql/1.0/warehouses/abc?foo=bar&o=42&baz=qux')).to.equal('42'); }); it('returns undefined when httpPath is unset', () => { + expect((DBSQLClient as any).extractWorkspaceId(undefined)).to.be.undefined; + }); + + it('returns the numeric workspace id from all-purpose cluster path form', () => { + expect((DBSQLClient as any).extractWorkspaceId('sql/protocolv1/o/99999999999999/0101-000000-aaaaaaaa')).to.equal( + '99999999999999', + ); + }); + + it('returns the numeric workspace id from all-purpose cluster path with leading slash', () => { + expect((DBSQLClient as any).extractWorkspaceId('/sql/protocolv1/o/12345/0101-000000-aaaaaaaa')).to.equal('12345'); + }); + + it('returns undefined when all-purpose cluster path has non-numeric workspace segment', () => { + expect((DBSQLClient as any).extractWorkspaceId('sql/protocolv1/o/tenant_xyz/0101-000000-aaaaaaaa')).to.be + .undefined; + }); + + it('prefers ?o= query form over /o/ path form when both are present', () => { + expect((DBSQLClient as any).extractWorkspaceId('sql/protocolv1/o/111/cluster?o=222')).to.equal('222'); + }); + }); + + describe('buildCustomHeaders (SPOG)', () => { + it('injects x-databricks-org-id from ?o= in httpPath', () => { const client = new DBSQLClient(); - expect((client as any).extractWorkspaceId()).to.be.undefined; + const headers = (client as any).buildCustomHeaders('/sql/1.0/warehouses/abc?o=12345678901234', undefined); + expect(headers).to.deep.equal({ 'x-databricks-org-id': '12345678901234' }); + }); + + it('returns undefined when no ?o= and no user-supplied customHeaders', () => { + const client = new DBSQLClient(); + const headers = (client as any).buildCustomHeaders('/sql/1.0/warehouses/abc', undefined); + expect(headers).to.be.undefined; + }); + + it('preserves user-supplied customHeaders alongside parsed org-id', () => { + const client = new DBSQLClient(); + const headers = (client as any).buildCustomHeaders('/sql/1.0/warehouses/abc?o=42', { 'x-trace-id': 'tid-001' }); + expect(headers).to.deep.equal({ 'x-trace-id': 'tid-001', 'x-databricks-org-id': '42' }); + }); + + it('user-supplied x-databricks-org-id wins over ?o= parsed value (case-insensitive)', () => { + const client = new DBSQLClient(); + const headers = (client as any).buildCustomHeaders('/sql/1.0/warehouses/abc?o=42', { + 'X-Databricks-Org-Id': '999', + }); + expect(headers).to.deep.equal({ 'X-Databricks-Org-Id': '999' }); + }); + + it('does not inject org-id when ?o= value is non-numeric', () => { + const client = new DBSQLClient(); + const headers = (client as any).buildCustomHeaders('/sql/1.0/warehouses/abc?o=tenant_xyz', undefined); + expect(headers).to.be.undefined; + }); + + it('injects x-databricks-org-id from all-purpose cluster path form', () => { + const client = new DBSQLClient(); + const headers = (client as any).buildCustomHeaders( + 'sql/protocolv1/o/99999999999999/0101-000000-aaaaaaaa', + undefined, + ); + expect(headers).to.deep.equal({ 'x-databricks-org-id': '99999999999999' }); + }); + + it('logs a warning when workspace ID segment is non-numeric (path form)', () => { + const client = new DBSQLClient(); + const logSpy = sinon.spy((client as any).logger, 'log'); + try { + (client as any).buildCustomHeaders('sql/protocolv1/o/tenant_xyz/cluster', undefined); + const warnCalls = logSpy.getCalls().filter((c) => c.args[0] === LogLevel.warn); + expect(warnCalls).to.have.lengthOf(1); + expect(warnCalls[0].args[1]).to.match(/non-numeric workspace ID/); + } finally { + logSpy.restore(); + } + }); + + it('logs a warning when ?o= is present but non-numeric', () => { + const client = new DBSQLClient(); + const logSpy = sinon.spy((client as any).logger, 'log'); + try { + (client as any).buildCustomHeaders('/sql/1.0/warehouses/abc?o=tenant_xyz', undefined); + const warnCalls = logSpy.getCalls().filter((c) => c.args[0] === LogLevel.warn); + expect(warnCalls).to.have.lengthOf(1); + expect(warnCalls[0].args[1]).to.match(/non-numeric workspace ID/); + } finally { + logSpy.restore(); + } + }); + + it('logs a debug line when injecting org-id from httpPath', () => { + const client = new DBSQLClient(); + const logSpy = sinon.spy((client as any).logger, 'log'); + try { + (client as any).buildCustomHeaders('/sql/1.0/warehouses/abc?o=42', undefined); + const injectLog = logSpy + .getCalls() + .find((c) => c.args[0] === LogLevel.debug && /injecting x-databricks-org-id=42/.test(String(c.args[1]))); + expect(injectLog, 'expected SPOG inject debug log').to.exist; + } finally { + logSpy.restore(); + } + }); + + it('logs a debug line when caller supplies x-databricks-org-id', () => { + const client = new DBSQLClient(); + const logSpy = sinon.spy((client as any).logger, 'log'); + try { + (client as any).buildCustomHeaders('/sql/1.0/warehouses/abc?o=42', { 'x-databricks-org-id': '999' }); + const callerLog = logSpy + .getCalls() + .find((c) => c.args[0] === LogLevel.debug && /supplied by caller/.test(String(c.args[1]))); + expect(callerLog, 'expected SPOG caller-supplied debug log').to.exist; + } finally { + logSpy.restore(); + } }); }); diff --git a/tests/unit/telemetry/DatabricksTelemetryExporter.test.ts b/tests/unit/telemetry/DatabricksTelemetryExporter.test.ts index fb347bf6..c3fd099d 100644 --- a/tests/unit/telemetry/DatabricksTelemetryExporter.test.ts +++ b/tests/unit/telemetry/DatabricksTelemetryExporter.test.ts @@ -120,6 +120,46 @@ describe('DatabricksTelemetryExporter', () => { } expect(threw).to.be.false; }); + + it('should attach config.customHeaders to the POST (SPOG)', async () => { + const context = new ClientContextStub({ + customHeaders: { 'x-databricks-org-id': '12345678901234' }, + } as any); + const registry = new CircuitBreakerRegistry(context); + const exporter = new DatabricksTelemetryExporter(context, 'host.example.com', registry, fakeAuthProvider); + const sendRequestStub = sinon.stub(exporter as any, 'sendRequest').returns(makeOkResponse()); + + await exporter.export([makeMetric()]); + + const init = sendRequestStub.firstCall.args[1] as { headers: Record }; + expect(init.headers['x-databricks-org-id']).to.equal('12345678901234'); + }); + + it('auth headers win over customHeaders on key collision', async () => { + const context = new ClientContextStub({ + customHeaders: { Authorization: 'Bearer not-the-real-token' }, + } as any); + const registry = new CircuitBreakerRegistry(context); + const exporter = new DatabricksTelemetryExporter(context, 'host.example.com', registry, fakeAuthProvider); + const sendRequestStub = sinon.stub(exporter as any, 'sendRequest').returns(makeOkResponse()); + + await exporter.export([makeMetric()]); + + const init = sendRequestStub.firstCall.args[1] as { headers: Record }; + expect(init.headers.Authorization).to.equal('Bearer test-token'); + }); + + it('does not attach customHeaders when none are configured', async () => { + const context = new ClientContextStub(); + const registry = new CircuitBreakerRegistry(context); + const exporter = new DatabricksTelemetryExporter(context, 'host.example.com', registry, fakeAuthProvider); + const sendRequestStub = sinon.stub(exporter as any, 'sendRequest').returns(makeOkResponse()); + + await exporter.export([makeMetric()]); + + const init = sendRequestStub.firstCall.args[1] as { headers: Record }; + expect(init.headers).to.not.have.property('x-databricks-org-id'); + }); }); describe('export() - retry logic', () => { diff --git a/tests/unit/telemetry/FeatureFlagCache.test.ts b/tests/unit/telemetry/FeatureFlagCache.test.ts index ed7bc79c..c12ca5b1 100644 --- a/tests/unit/telemetry/FeatureFlagCache.test.ts +++ b/tests/unit/telemetry/FeatureFlagCache.test.ts @@ -317,4 +317,43 @@ describe('FeatureFlagCache', () => { fetchStub.restore(); }); }); + + describe('customHeaders propagation (SPOG)', () => { + function makeJsonResponse(body: unknown) { + return Promise.resolve({ + ok: true, + status: 200, + statusText: 'OK', + json: () => Promise.resolve(body), + text: () => Promise.resolve(''), + }); + } + + it('attaches config.customHeaders to the feature-flag GET', async () => { + const context = new ClientContextStub({ + customHeaders: { 'x-databricks-org-id': '12345678901234' }, + } as any); + const cache = new FeatureFlagCache(context); + const stub = sinon.stub(cache as any, 'fetchWithRetry').returns(makeJsonResponse({ flags: [] })); + + await (cache as any).fetchFeatureFlag('host.example.com'); + + expect(stub.calledOnce).to.be.true; + const init = stub.firstCall.args[1] as { headers: Record }; + expect(init.headers['x-databricks-org-id']).to.equal('12345678901234'); + stub.restore(); + }); + + it('does not set x-databricks-org-id when customHeaders is empty', async () => { + const context = new ClientContextStub(); + const cache = new FeatureFlagCache(context); + const stub = sinon.stub(cache as any, 'fetchWithRetry').returns(makeJsonResponse({ flags: [] })); + + await (cache as any).fetchFeatureFlag('host.example.com'); + + const init = stub.firstCall.args[1] as { headers: Record }; + expect(init.headers).to.not.have.property('x-databricks-org-id'); + stub.restore(); + }); + }); }); diff --git a/tests/unit/telemetry/TelemetryClientProvider.test.ts b/tests/unit/telemetry/TelemetryClientProvider.test.ts index 00f638fe..6ffa4db2 100644 --- a/tests/unit/telemetry/TelemetryClientProvider.test.ts +++ b/tests/unit/telemetry/TelemetryClientProvider.test.ts @@ -267,6 +267,50 @@ describe('TelemetryClientProvider', () => { expect(second.isClosed()).to.be.false; expect(provider.getRefCount(HOST1)).to.equal(1); }); + + it('keeps the context registered through close() on last refcount', async () => { + // Regression: releaseClient used to call unregisterContext before + // close(), which dropped the only FIFO entry and left the final + // flush without an auth provider — metrics were dropped with + // "missing Authorization header". + const context = new ClientContextStub(); + const provider = new TelemetryClientProvider(); + + const client = provider.getOrCreateClient(context, HOST1); + let authProviderAtClose: unknown = 'unset'; + sinon.stub(client, 'close').callsFake(async () => { + // The TelemetryClient is its own IClientContext. While close() runs, + // getAuthProvider() must still resolve (the FIFO entry survives). + authProviderAtClose = client.getAuthProvider?.(); + }); + + await provider.releaseClient(context, HOST1); + + // The FIFO walk hits the (still-registered) context; the stub's + // getAuthProvider returns undefined but the lookup itself completes. + // What we're really asserting is that the FIFO wasn't pre-emptied: + // pre-fix, this assertion would not even fire because close() runs + // against an empty FIFO and the auth-provider walk short-circuits + // before close()'s callsFake. Here we verify the spy ran AND the + // context is still registered at that moment. + expect(authProviderAtClose).to.not.equal('unset'); + expect((client as any).contexts.length).to.equal(1); + }); + + it('drops the context immediately when other refcounts remain', async () => { + const context = new ClientContextStub(); + const provider = new TelemetryClientProvider(); + + const client = provider.getOrCreateClient(context, HOST1); + provider.getOrCreateClient(context, HOST1); // refcount=2 + + await provider.releaseClient(context, HOST1); + + // Multi-refcount path: unregisterContext runs immediately; the + // (single) FIFO entry tracking this context was removed. + expect((client as any).contexts.length).to.equal(0); + expect(provider.getRefCount(HOST1)).to.equal(1); + }); }); describe('Host normalization', () => {