Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions packages/client/src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,9 @@ export async function auth(
fetchFn?: FetchLike;
}
): Promise<AuthResult> {
// Validate clientMetadataUrl before any network calls
validateClientMetadataUrl(provider.clientMetadataUrl);

try {
return await authInternal(provider, options);
} catch (error) {
Expand Down Expand Up @@ -614,6 +617,29 @@ async function authInternal(
return 'REDIRECT';
}

/**
* Validates that the given `clientMetadataUrl` is a valid HTTPS URL with a non-root pathname.
*
* This function is a no-op when `url` is `undefined` or empty (providers that do not use
* URL-based client IDs are unaffected). When the value is defined but invalid, it throws
* an {@linkcode OAuthError} with code {@linkcode OAuthErrorCode.InvalidClientMetadata}.
*
* SDK transports and the `auth()` function call this automatically. Consumers who implement
* {@linkcode OAuthClientProvider} directly can call this in their own constructors for
* early validation.
*
* @param url - The `clientMetadataUrl` value to validate (from `OAuthClientProvider.clientMetadataUrl`)
* @throws {OAuthError} When `url` is defined but is not a valid HTTPS URL with a non-root pathname
*/
export function validateClientMetadataUrl(url: string | undefined): void {
if (url && !isHttpsUrl(url)) {
throw new OAuthError(
OAuthErrorCode.InvalidClientMetadata,
`clientMetadataUrl must be a valid HTTPS URL with a non-root pathname, got: ${url}`
);
}
}

/**
* SEP-991: URL-based Client IDs
* Validate that the `client_id` is a valid URL with `https` scheme
Expand Down
11 changes: 7 additions & 4 deletions packages/client/src/client/middleware.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { FetchLike } from '@modelcontextprotocol/core';

import type { OAuthClientProvider } from './auth.js';
import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js';
import { auth, extractWWWAuthenticateParams, UnauthorizedError, validateClientMetadataUrl } from './auth.js';

/**
* Middleware function that wraps and enhances fetch functionality.
Expand Down Expand Up @@ -35,9 +35,11 @@ export type Middleware = (next: FetchLike) => FetchLike;
* @param baseUrl - Base URL for OAuth server discovery (defaults to request URL domain)
* @returns A fetch middleware function
*/
export const withOAuth =
(provider: OAuthClientProvider, baseUrl?: string | URL): Middleware =>
next => {
export const withOAuth = (provider: OAuthClientProvider, baseUrl?: string | URL): Middleware => {
// Validate clientMetadataUrl at factory creation time (fail-fast)
validateClientMetadataUrl(provider.clientMetadataUrl);

return next => {
return async (input, init) => {
const makeRequest = async (): Promise<Response> => {
const headers = new Headers(init?.headers);
Expand Down Expand Up @@ -95,6 +97,7 @@ export const withOAuth =
return response;
};
};
};

/**
* Logger function type for HTTP requests
Expand Down
7 changes: 6 additions & 1 deletion packages/client/src/client/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import type { ErrorEvent, EventSourceInit } from 'eventsource';
import { EventSource } from 'eventsource';

import type { AuthResult, OAuthClientProvider } from './auth.js';
import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js';
import { auth, extractWWWAuthenticateParams, UnauthorizedError, validateClientMetadataUrl } from './auth.js';

export class SseError extends Error {
constructor(
Expand Down Expand Up @@ -81,6 +81,11 @@ export class SSEClientTransport implements Transport {
onmessage?: (message: JSONRPCMessage) => void;

constructor(url: URL, opts?: SSEClientTransportOptions) {
// Validate clientMetadataUrl at construction time (fail-fast)
if (opts?.authProvider) {
validateClientMetadataUrl(opts.authProvider.clientMetadataUrl);
}

this._url = url;
this._resourceMetadataUrl = undefined;
this._scope = undefined;
Expand Down
7 changes: 6 additions & 1 deletion packages/client/src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import {
import { EventSourceParserStream } from 'eventsource-parser/stream';

import type { AuthResult, OAuthClientProvider } from './auth.js';
import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js';
import { auth, extractWWWAuthenticateParams, UnauthorizedError, validateClientMetadataUrl } from './auth.js';

// Default reconnection options for StreamableHTTP connections
const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = {
Expand Down Expand Up @@ -147,6 +147,11 @@ export class StreamableHTTPClientTransport implements Transport {
onmessage?: (message: JSONRPCMessage) => void;

constructor(url: URL, opts?: StreamableHTTPClientTransportOptions) {
// Validate clientMetadataUrl at construction time (fail-fast)
if (opts?.authProvider) {
validateClientMetadataUrl(opts.authProvider.clientMetadataUrl);
}

this._url = url;
this._resourceMetadataUrl = undefined;
this._scope = undefined;
Expand Down
104 changes: 103 additions & 1 deletion packages/client/test/client/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ import {
refreshAuthorization,
registerClient,
selectClientAuthMethod,
startAuthorization
startAuthorization,
validateClientMetadataUrl
} from '../../src/client/auth.js';
import { createPrivateKeyJwtAuth } from '../../src/client/authExtensions.js';

Expand Down Expand Up @@ -1993,6 +1994,28 @@ describe('OAuth Authorization', () => {
vi.clearAllMocks();
});

it('throws OAuthError with InvalidClientMetadata before any network request when clientMetadataUrl is invalid', async () => {
const providerWithInvalidUrl: OAuthClientProvider = {
...mockProvider,
clientMetadataUrl: 'http://invalid.example.com/metadata'
};

await expect(
auth(providerWithInvalidUrl, {
serverUrl: 'https://server.example.com'
})
).rejects.toThrow(OAuthError);

await expect(
auth(providerWithInvalidUrl, {
serverUrl: 'https://server.example.com'
})
).rejects.toMatchObject({ code: OAuthErrorCode.InvalidClientMetadata });

// Verify no fetch was called
expect(mockFetch).not.toHaveBeenCalled();
});

it('performs client_credentials with private_key_jwt when provider has addClientAuthentication', async () => {
// Arrange: metadata discovery for PRM and AS
mockFetch.mockImplementation(url => {
Expand Down Expand Up @@ -3406,6 +3429,85 @@ describe('OAuth Authorization', () => {
});
});

describe('validateClientMetadataUrl', () => {
it('passes for valid HTTPS URL with path', () => {
expect(() => validateClientMetadataUrl('https://client.example.com/.well-known/oauth-client')).not.toThrow();
});

it('passes for valid HTTPS URL with multi-segment path', () => {
expect(() => validateClientMetadataUrl('https://example.com/clients/metadata.json')).not.toThrow();
});

it('throws OAuthError for HTTP URL', () => {
expect(() => validateClientMetadataUrl('http://client.example.com/.well-known/oauth-client')).toThrow(OAuthError);
try {
validateClientMetadataUrl('http://client.example.com/.well-known/oauth-client');
} catch (error) {
expect(error).toBeInstanceOf(OAuthError);
expect((error as OAuthError).code).toBe(OAuthErrorCode.InvalidClientMetadata);
expect((error as OAuthError).message).toContain('http://client.example.com/.well-known/oauth-client');
}
});

it('throws OAuthError for non-URL string', () => {
expect(() => validateClientMetadataUrl('not-a-url')).toThrow(OAuthError);
try {
validateClientMetadataUrl('not-a-url');
} catch (error) {
expect(error).toBeInstanceOf(OAuthError);
expect((error as OAuthError).code).toBe(OAuthErrorCode.InvalidClientMetadata);
expect((error as OAuthError).message).toContain('not-a-url');
}
});

it('passes silently for empty string', () => {
// Empty string is falsy, so it passes through without throwing
expect(() => validateClientMetadataUrl('')).not.toThrow();
});

it('throws OAuthError for root-path HTTPS URL with trailing slash', () => {
expect(() => validateClientMetadataUrl('https://client.example.com/')).toThrow(OAuthError);
try {
validateClientMetadataUrl('https://client.example.com/');
} catch (error) {
expect(error).toBeInstanceOf(OAuthError);
expect((error as OAuthError).code).toBe(OAuthErrorCode.InvalidClientMetadata);
expect((error as OAuthError).message).toContain('https://client.example.com/');
}
});

it('throws OAuthError for root-path HTTPS URL without trailing slash', () => {
expect(() => validateClientMetadataUrl('https://client.example.com')).toThrow(OAuthError);
try {
validateClientMetadataUrl('https://client.example.com');
} catch (error) {
expect(error).toBeInstanceOf(OAuthError);
expect((error as OAuthError).code).toBe(OAuthErrorCode.InvalidClientMetadata);
expect((error as OAuthError).message).toContain('https://client.example.com');
}
});

it('passes silently for undefined', () => {
expect(() => validateClientMetadataUrl(undefined)).not.toThrow();
});

it('passes silently when called with no arguments', () => {
// eslint-disable-next-line unicorn/no-useless-undefined
expect(() => validateClientMetadataUrl(undefined)).not.toThrow();
});

it('error message matches expected format', () => {
try {
validateClientMetadataUrl('http://example.com/path');
} catch (error) {
expect(error).toBeInstanceOf(OAuthError);
expect((error as OAuthError).message).toBe(
'clientMetadataUrl must be a valid HTTPS URL with a non-root pathname, got: http://example.com/path'
);
}
});
});

describe('SEP-991: URL-based Client ID fallback logic', () => {
const validClientMetadata = {
redirect_uris: ['http://localhost:3000/callback'],
Expand Down
32 changes: 32 additions & 0 deletions packages/client/test/client/middleware.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { FetchLike } from '@modelcontextprotocol/core';
import { OAuthError, OAuthErrorCode } from '@modelcontextprotocol/core';
import type { Mocked, MockedFunction, MockInstance } from 'vitest';

import type { OAuthClientProvider } from '../../src/client/auth.js';
Expand Down Expand Up @@ -44,6 +45,37 @@ describe('withOAuth', () => {
mockFetch = vi.fn();
});

describe('clientMetadataUrl validation', () => {
it('throws OAuthError at factory call when provider.clientMetadataUrl is an invalid HTTP URL', () => {
const providerWithInvalidUrl: Mocked<OAuthClientProvider> = {
...mockProvider,
clientMetadataUrl: 'http://example.com/metadata'
};

expect(() => withOAuth(providerWithInvalidUrl, 'https://server.example.com')).toThrow(OAuthError);
try {
withOAuth(providerWithInvalidUrl, 'https://server.example.com');
} catch (error) {
expect(error).toBeInstanceOf(OAuthError);
expect((error as OAuthError).code).toBe(OAuthErrorCode.InvalidClientMetadata);
expect((error as OAuthError).message).toContain('http://example.com/metadata');
}
});

it('succeeds when clientMetadataUrl is a valid HTTPS URL', () => {
const providerWithValidUrl: Mocked<OAuthClientProvider> = {
...mockProvider,
clientMetadataUrl: 'https://client.example.com/.well-known/oauth-client'
};

expect(() => withOAuth(providerWithValidUrl, 'https://server.example.com')).not.toThrow();
});

it('succeeds when clientMetadataUrl is undefined', () => {
expect(() => withOAuth(mockProvider, 'https://server.example.com')).not.toThrow();
});
});

it('should add Authorization header when tokens are available (with explicit baseUrl)', async () => {
mockProvider.tokens.mockResolvedValue({
access_token: 'test-token',
Expand Down
61 changes: 61 additions & 0 deletions packages/client/test/client/sse.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,67 @@ import type { OAuthClientProvider } from '../../src/client/auth.js';
import { UnauthorizedError } from '../../src/client/auth.js';
import { SSEClientTransport } from '../../src/client/sse.js';

describe('SSEClientTransport clientMetadataUrl validation', () => {
it('throws OAuthError when authProvider.clientMetadataUrl is an invalid HTTP URL', () => {
const provider: OAuthClientProvider = {
get redirectUrl() {
return 'http://localhost/callback';
},
get clientMetadata() {
return { redirect_uris: ['http://localhost/callback'] };
},
clientMetadataUrl: 'http://example.com/metadata',
clientInformation: vi.fn(),
tokens: vi.fn(),
saveTokens: vi.fn(),
redirectToAuthorization: vi.fn(),
saveCodeVerifier: vi.fn(),
codeVerifier: vi.fn()
};

expect(() => new SSEClientTransport(new URL('https://server.example.com/sse'), { authProvider: provider })).toThrow(OAuthError);
});

it('succeeds when clientMetadataUrl is a valid HTTPS URL', () => {
const provider: OAuthClientProvider = {
get redirectUrl() {
return 'http://localhost/callback';
},
get clientMetadata() {
return { redirect_uris: ['http://localhost/callback'] };
},
clientMetadataUrl: 'https://client.example.com/.well-known/oauth-client',
clientInformation: vi.fn(),
tokens: vi.fn(),
saveTokens: vi.fn(),
redirectToAuthorization: vi.fn(),
saveCodeVerifier: vi.fn(),
codeVerifier: vi.fn()
};

expect(() => new SSEClientTransport(new URL('https://server.example.com/sse'), { authProvider: provider })).not.toThrow();
});

it('succeeds when clientMetadataUrl is undefined', () => {
const provider: OAuthClientProvider = {
get redirectUrl() {
return 'http://localhost/callback';
},
get clientMetadata() {
return { redirect_uris: ['http://localhost/callback'] };
},
clientInformation: vi.fn(),
tokens: vi.fn(),
saveTokens: vi.fn(),
redirectToAuthorization: vi.fn(),
saveCodeVerifier: vi.fn(),
codeVerifier: vi.fn()
};

expect(() => new SSEClientTransport(new URL('https://server.example.com/sse'), { authProvider: provider })).not.toThrow();
});
});

/**
* Parses HTTP Basic auth from a request's Authorization header.
* Returns the decoded client_id and client_secret, or undefined if the header is absent or malformed.
Expand Down
Loading
Loading