diff --git a/auth/authorization_code.go b/auth/authorization_code.go index ac51ea12..1389ec2e 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -11,6 +11,7 @@ import ( "crypto/rand" "errors" "fmt" + "io" "net/http" "net/url" "slices" @@ -20,18 +21,6 @@ import ( "golang.org/x/oauth2" ) -// ClientSecretAuthConfig is used to configure client authentication using client_secret. -// Authentication method will be selected based on the authorization server's supported methods, -// according to the following preference order: -// 1. client_secret_post -// 2. client_secret_basic -type ClientSecretAuthConfig struct { - // ClientID is the client ID to be used for client authentication. - ClientID string - // ClientSecret is the client secret to be used for client authentication. - ClientSecret string -} - // ClientIDMetadataDocumentConfig is used to configure the Client ID Metadata Document // based client registration per // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents. @@ -42,14 +31,6 @@ type ClientIDMetadataDocumentConfig struct { URL string } -// PreregisteredClientConfig is used to configure a pre-registered client per -// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration. -// Currently only "client_secret_basic" and "client_secret_post" authentication methods are supported. -type PreregisteredClientConfig struct { - // ClientSecretAuthConfig is the client_secret based configuration to be used for client authentication. - ClientSecretAuthConfig *ClientSecretAuthConfig -} - // DynamicClientRegistrationConfig is used to configure dynamic client registration per // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration. type DynamicClientRegistrationConfig struct { @@ -67,12 +48,18 @@ type AuthorizationResult struct { State string } -// AuthorizationArgs is the input to [AuthorizationCodeHandlerConfig].AuthorizationCodeFetcher. +// AuthorizationArgs is the input to [AuthorizationCodeFetcher]. type AuthorizationArgs struct { // Authorization URL to be opened in a browser for the user to start the authorization process. URL string } +// AuthorizationCodeFetcher is called to initiate the OAuth authorization flow. +// It is responsible for directing the user to the authorization URL (e.g., opening +// in a browser) and returning the authorization code and state once the Authorization +// Server redirects back to the configured RedirectURL. +type AuthorizationCodeFetcher func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) + // AuthorizationCodeHandlerConfig is the configuration for [AuthorizationCodeHandler]. type AuthorizationCodeHandlerConfig struct { // Client registration configuration. @@ -82,7 +69,7 @@ type AuthorizationCodeHandlerConfig struct { // 3. Dynamic Client Registration // At least one method must be configured. ClientIDMetadataDocumentConfig *ClientIDMetadataDocumentConfig - PreregisteredClientConfig *PreregisteredClientConfig + PreregisteredClient *oauthex.ClientCredentials DynamicClientRegistrationConfig *DynamicClientRegistrationConfig // RedirectURL is a required URL to redirect to after authorization. @@ -97,10 +84,8 @@ type AuthorizationCodeHandlerConfig struct { RedirectURL string // AuthorizationCodeFetcher is a required function called to initiate the authorization flow. - // It is responsible for opening the URL in a browser for the user to start the authorization process. - // It should return the authorization code and state once the Authorization Server - // redirects back to the RedirectURL. - AuthorizationCodeFetcher func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) + // See [AuthorizationCodeFetcher] for details. + AuthorizationCodeFetcher AuthorizationCodeFetcher // Client is an optional HTTP client to use for HTTP requests. // It is used for the following requests: @@ -127,8 +112,6 @@ type AuthorizationCodeHandler struct { var _ OAuthHandler = (*AuthorizationCodeHandler)(nil) -func (h *AuthorizationCodeHandler) isOAuthHandler() {} - func (h *AuthorizationCodeHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { return h.tokenSource, nil } @@ -141,7 +124,7 @@ func NewAuthorizationCodeHandler(config *AuthorizationCodeHandlerConfig) (*Autho return nil, errors.New("config must be provided") } if config.ClientIDMetadataDocumentConfig == nil && - config.PreregisteredClientConfig == nil && + config.PreregisteredClient == nil && config.DynamicClientRegistrationConfig == nil { return nil, errors.New("at least one client registration configuration must be provided") } @@ -151,13 +134,9 @@ func NewAuthorizationCodeHandler(config *AuthorizationCodeHandlerConfig) (*Autho if config.ClientIDMetadataDocumentConfig != nil && !isNonRootHTTPSURL(config.ClientIDMetadataDocumentConfig.URL) { return nil, fmt.Errorf("client ID metadata document URL must be a non-root HTTPS URL") } - preCfg := config.PreregisteredClientConfig - if preCfg != nil { - if preCfg.ClientSecretAuthConfig == nil { - return nil, errors.New("ClientSecretAuthConfig is required for pre-registered client") - } - if preCfg.ClientSecretAuthConfig.ClientID == "" || preCfg.ClientSecretAuthConfig.ClientSecret == "" { - return nil, fmt.Errorf("pre-registered client ID or secret is empty") + if config.PreregisteredClient != nil { + if err := config.PreregisteredClient.Validate(); err != nil { + return nil, fmt.Errorf("invalid PreregisteredClient configuration: %w", err) } } dCfg := config.DynamicClientRegistrationConfig @@ -198,6 +177,7 @@ func isNonRootHTTPSURL(u string) bool { // On success, [AuthorizationCodeHandler.TokenSource] will return a token source with the fetched token. func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { defer resp.Body.Close() + defer io.Copy(io.Discard, resp.Body) wwwChallenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) if err != nil { @@ -218,9 +198,20 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ return err } - asm, err := h.getAuthServerMetadata(ctx, prm) + asm, err := GetAuthServerMetadata(ctx, prm.AuthorizationServers[0], h.config.Client) if err != nil { - return err + return fmt.Errorf("failed to get authorization server metadata: %w", err) + } + if asm == nil { + // Fallback to 2025-03-26 spec: predefined endpoints. + // https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#fallbacks-for-servers-without-metadata-discovery + authServerURL := prm.AuthorizationServers[0] + asm = &oauthex.AuthServerMeta{ + Issuer: authServerURL, + AuthorizationEndpoint: authServerURL + "/authorize", + TokenEndpoint: authServerURL + "/token", + RegistrationEndpoint: authServerURL + "/register", + } } resolvedClientConfig, err := h.handleRegistration(ctx, asm) @@ -365,70 +356,6 @@ func protectedResourceMetadataURLs(metadataURL, resourceURL string) []prmURL { return urls } -// getAuthServerMetadata returns the authorization server metadata. -// The provided Protected Resource Metadata must not be nil and must contain -// at least one authorization server. -// It returns an error if the metadata request fails with non-4xx HTTP status code -// or the fetched metadata fails security checks. -// If no metadata was found, it returns a minimal set of endpoints -// as a fallback to 2025-03-26 spec. -func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, prm *oauthex.ProtectedResourceMetadata) (*oauthex.AuthServerMeta, error) { - authServerURL := prm.AuthorizationServers[0] - for _, u := range authorizationServerMetadataURLs(authServerURL) { - asm, err := oauthex.GetAuthServerMeta(ctx, u, authServerURL, h.config.Client) - if err != nil { - return nil, fmt.Errorf("failed to get authorization server metadata: %w", err) - } - if asm != nil { - return asm, nil - } - } - - // Fallback to 2025-03-26 spec: predefined endpoints. - // https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#fallbacks-for-servers-without-metadata-discovery - asm := &oauthex.AuthServerMeta{ - Issuer: authServerURL, - AuthorizationEndpoint: authServerURL + "/authorize", - TokenEndpoint: authServerURL + "/token", - RegistrationEndpoint: authServerURL + "/register", - } - return asm, nil -} - -// authorizationServerMetadataURLs returns a list of URLs to try when looking for -// authorization server metadata as mandated by the MCP specification: -// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery. -func authorizationServerMetadataURLs(issuerURL string) []string { - var urls []string - - baseURL, err := url.Parse(issuerURL) - if err != nil { - return nil - } - - if baseURL.Path == "" { - // "OAuth 2.0 Authorization Server Metadata". - baseURL.Path = "/.well-known/oauth-authorization-server" - urls = append(urls, baseURL.String()) - // "OpenID Connect Discovery 1.0". - baseURL.Path = "/.well-known/openid-configuration" - urls = append(urls, baseURL.String()) - return urls - } - - originalPath := baseURL.Path - // "OAuth 2.0 Authorization Server Metadata with path insertion". - baseURL.Path = "/.well-known/oauth-authorization-server/" + strings.TrimLeft(originalPath, "/") - urls = append(urls, baseURL.String()) - // "OpenID Connect Discovery 1.0 with path insertion". - baseURL.Path = "/.well-known/openid-configuration/" + strings.TrimLeft(originalPath, "/") - urls = append(urls, baseURL.String()) - // "OpenID Connect Discovery 1.0 with path appending". - baseURL.Path = "/" + strings.Trim(originalPath, "/") + "/.well-known/openid-configuration" - urls = append(urls, baseURL.String()) - return urls -} - type registrationType int const ( @@ -491,13 +418,17 @@ func (h *AuthorizationCodeHandler) handleRegistration(ctx context.Context, asm * }, nil } // 2. Attempt to use pre-registered client configuration. - pCfg := h.config.PreregisteredClientConfig - if pCfg != nil { + preCfg := h.config.PreregisteredClient + if preCfg != nil { authStyle := selectTokenAuthMethod(asm.TokenEndpointAuthMethodsSupported) + clientSecret := "" + if preCfg.ClientSecretAuth != nil { + clientSecret = preCfg.ClientSecretAuth.ClientSecret + } return &resolvedClientConfig{ registrationType: registrationTypePreregistered, - clientID: pCfg.ClientSecretAuthConfig.ClientID, - clientSecret: pCfg.ClientSecretAuthConfig.ClientSecret, + clientID: preCfg.ClientID, + clientSecret: clientSecret, authStyle: authStyle, }, nil } diff --git a/auth/authorization_code_test.go b/auth/authorization_code_test.go index d371cba9..879031ad 100644 --- a/auth/authorization_code_test.go +++ b/auth/authorization_code_test.go @@ -46,9 +46,9 @@ func TestAuthorize(t *testing.T) { handler, err := NewAuthorizationCodeHandler(&AuthorizationCodeHandlerConfig{ RedirectURL: "http://localhost:12345/callback", - PreregisteredClientConfig: &PreregisteredClientConfig{ - ClientSecretAuthConfig: &ClientSecretAuthConfig{ - ClientID: "test_client_id", + PreregisteredClient: &oauthex.ClientCredentials{ + ClientID: "test_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ ClientSecret: "test_client_secret", }, }, @@ -154,9 +154,9 @@ func TestNewAuthorizationCodeHandler_Success(t *testing.T) { { name: "PreregisteredClientConfig", config: &AuthorizationCodeHandlerConfig{ - PreregisteredClientConfig: &PreregisteredClientConfig{ - ClientSecretAuthConfig: &ClientSecretAuthConfig{ - ClientID: "test_client_id", + PreregisteredClient: &oauthex.ClientCredentials{ + ClientID: "test_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ ClientSecret: "test_client_secret", }, }, @@ -223,7 +223,7 @@ func TestNewAuthorizationCodeHandler_Error(t *testing.T) { config: func() *AuthorizationCodeHandlerConfig { cfg := validConfig() cfg.ClientIDMetadataDocumentConfig = nil - cfg.PreregisteredClientConfig = nil + cfg.PreregisteredClient = nil cfg.DynamicClientRegistrationConfig = nil return cfg }, @@ -256,7 +256,7 @@ func TestNewAuthorizationCodeHandler_Error(t *testing.T) { name: "InvalidPreregistered_MissingSecretConfig", config: func() *AuthorizationCodeHandlerConfig { cfg := validConfig() - cfg.PreregisteredClientConfig = &PreregisteredClientConfig{} + cfg.PreregisteredClient = &oauthex.ClientCredentials{} return cfg }, }, @@ -264,8 +264,9 @@ func TestNewAuthorizationCodeHandler_Error(t *testing.T) { name: "InvalidPreregistered_EmptyID", config: func() *AuthorizationCodeHandlerConfig { cfg := validConfig() - cfg.PreregisteredClientConfig = &PreregisteredClientConfig{ - ClientSecretAuthConfig: &ClientSecretAuthConfig{ + cfg.PreregisteredClient = &oauthex.ClientCredentials{ + ClientID: "", + ClientSecretAuth: &oauthex.ClientSecretAuth{ ClientSecret: "secret", }, } @@ -276,9 +277,10 @@ func TestNewAuthorizationCodeHandler_Error(t *testing.T) { name: "InvalidPreregistered_EmptySecret", config: func() *AuthorizationCodeHandlerConfig { cfg := validConfig() - cfg.PreregisteredClientConfig = &PreregisteredClientConfig{ - ClientSecretAuthConfig: &ClientSecretAuthConfig{ - ClientID: "test_client_id", + cfg.PreregisteredClient = &oauthex.ClientCredentials{ + ClientID: "test_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "", }, } return cfg @@ -438,90 +440,6 @@ func TestGetProtectedResourceMetadata_Error(t *testing.T) { } } -func TestGetAuthServerMetadata(t *testing.T) { - handler, err := NewAuthorizationCodeHandler(validConfig()) - if err != nil { - t.Fatalf("NewAuthorizationCodeHandler() error = %v", err) - } - - tests := []struct { - name string - issuerPath string - endpointConfig *oauthtest.MetadataEndpointConfig - }{ - { - name: "OAuthEndpoint_Root", - issuerPath: "", - endpointConfig: &oauthtest.MetadataEndpointConfig{ - ServeOAuthInsertedEndpoint: true, - }, - }, - { - name: "OpenIDEndpoint_Root", - issuerPath: "", - endpointConfig: &oauthtest.MetadataEndpointConfig{ - ServeOpenIDInsertedEndpoint: true, - }, - }, - { - name: "OAuthEndpoint_Path", - issuerPath: "/oauth", - endpointConfig: &oauthtest.MetadataEndpointConfig{ - ServeOAuthInsertedEndpoint: true, - }, - }, - { - name: "OpenIDEndpoint_Path", - issuerPath: "/openid", - endpointConfig: &oauthtest.MetadataEndpointConfig{ - ServeOpenIDInsertedEndpoint: true, - }, - }, - { - name: "OpenIDAppendedEndpoint_Path", - issuerPath: "/openid", - endpointConfig: &oauthtest.MetadataEndpointConfig{ - ServeOpenIDAppendedEndpoint: true, - }, - }, - { - name: "NoMetadata", - issuerPath: "", - endpointConfig: &oauthtest.MetadataEndpointConfig{ - // All metadata endpoints disabled. - ServeOAuthInsertedEndpoint: false, - ServeOpenIDInsertedEndpoint: false, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{ - IssuerPath: tt.issuerPath, - MetadataEndpointConfig: tt.endpointConfig, - }) - s.Start(t) - issuerURL := s.URL() + tt.issuerPath - prm := &oauthex.ProtectedResourceMetadata{ - Resource: "https://example.com/resource", - AuthorizationServers: []string{issuerURL}, - } - - got, err := handler.getAuthServerMetadata(t.Context(), prm) - if err != nil { - t.Fatalf("getAuthServerMetadata() error = %v, want nil", err) - } - if got == nil { - t.Fatal("getAuthServerMetadata() got nil, want metadata") - } - if got.Issuer != issuerURL { - t.Errorf("getAuthServerMetadata() issuer = %q, want %q", got.Issuer, issuerURL) - } - }) - } -} - func TestSelectTokenAuthMethod(t *testing.T) { tests := []struct { name string @@ -587,9 +505,9 @@ func TestHandleRegistration(t *testing.T) { }, }, handlerConfig: &AuthorizationCodeHandlerConfig{ - PreregisteredClientConfig: &PreregisteredClientConfig{ - ClientSecretAuthConfig: &ClientSecretAuthConfig{ - ClientID: "pre_client_id", + PreregisteredClient: &oauthex.ClientCredentials{ + ClientID: "pre_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ ClientSecret: "pre_client_secret", }, }, @@ -622,11 +540,9 @@ func TestHandleRegistration(t *testing.T) { if err != nil { t.Fatalf("NewAuthorizationCodeHandler() error = %v, want nil", err) } - asm, err := handler.getAuthServerMetadata(t.Context(), &oauthex.ProtectedResourceMetadata{ - AuthorizationServers: []string{s.URL()}, - }) + asm, err := GetAuthServerMetadata(t.Context(), s.URL(), http.DefaultClient) if err != nil { - t.Fatalf("getAuthServerMetadata() error = %v, want nil", err) + t.Fatalf("GetAuthServerMetadata() unexpected error = %v", err) } got, err := handler.handleRegistration(t.Context(), asm) if err != nil { @@ -672,11 +588,9 @@ func TestDynamicRegistration(t *testing.T) { if err != nil { t.Fatalf("NewAuthorizationCodeHandler() error = %v", err) } - asm, err := handler.getAuthServerMetadata(t.Context(), &oauthex.ProtectedResourceMetadata{ - AuthorizationServers: []string{s.URL()}, - }) + asm, err := GetAuthServerMetadata(t.Context(), s.URL(), http.DefaultClient) if err != nil { - t.Fatalf("getAuthServerMetadata() error = %v, want nil", err) + t.Fatalf("GetAuthServerMetadata() unexpected error = %v", err) } got, err := handler.handleRegistration(t.Context(), asm) if err != nil { diff --git a/auth/client.go b/auth/client.go index 0af6963f..db32d97a 100644 --- a/auth/client.go +++ b/auth/client.go @@ -22,8 +22,6 @@ import ( // [github.com/modelcontextprotocol/go-sdk/mcp.StreamableClientTransport] // for an example. type OAuthHandler interface { - isOAuthHandler() - // TokenSource returns a token source to be used for outgoing requests. // Returned token source might be nil. In that case, the transport will not // add any authorization headers to the request. diff --git a/auth/extauth/enterprise_handler.go b/auth/extauth/enterprise_handler.go new file mode 100644 index 00000000..1d45d4a9 --- /dev/null +++ b/auth/extauth/enterprise_handler.go @@ -0,0 +1,256 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package extauth provides OAuth handler implementations for MCP authorization extensions. +// This package implements Enterprise Managed Authorization as defined in SEP-990. + +//go:build mcp_go_client_oauth + +package extauth + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +// grantTypeJWTBearer is the grant type for RFC 7523 JWT Bearer authorization grant. +const grantTypeJWTBearer = "urn:ietf:params:oauth:grant-type:jwt-bearer" + +// IDTokenFetcher is called to obtain an ID Token from the enterprise IdP. +// This is typically done via OIDC login flow where the user authenticates +// with their enterprise identity provider. +// +// Returns an oauth2.Token where Extra("id_token") contains the OpenID Connect ID Token (JWT). +type IDTokenFetcher func(ctx context.Context) (*oauth2.Token, error) + +// EnterpriseHandlerConfig is the configuration for [EnterpriseHandler]. +type EnterpriseHandlerConfig struct { + // IdP configuration (where the user authenticates) + + // IdPIssuerURL is the enterprise IdP's issuer URL (e.g., "https://acme.okta.com"). + // Used for OIDC discovery to find the token endpoint. + // REQUIRED. + IdPIssuerURL string + + // IdPCredentials contains the MCP Client's credentials registered at the IdP. + // REQUIRED. These credentials are used for token exchange at the IdP. + // The ClientID is always required. ClientSecretAuth is optional and only needed + // if the IdP requires client authentication (confidential clients). + IdPCredentials *oauthex.ClientCredentials + + // MCP Server configuration (the resource being accessed) + + // MCPAuthServerURL is the MCP Server's authorization server issuer URL. + // Used as the audience for token exchange and for metadata discovery. + // REQUIRED. + MCPAuthServerURL string + + // MCPResourceURI is the MCP Server's resource identifier (RFC 9728). + // Used as the resource parameter in token exchange. + // REQUIRED. + MCPResourceURI string + + // MCPCredentials contains the MCP Client's credentials registered at the MCP Server. + // REQUIRED. These credentials are used for JWT Bearer grant at the MCP Server. + // The ClientID is always required. ClientSecretAuth is optional and only needed + // if the MCP Server requires client authentication. + MCPCredentials *oauthex.ClientCredentials + + // MCPScopes is the list of scopes to request at the MCP Server. + // OPTIONAL. + MCPScopes []string + + // IDTokenFetcher is called to obtain an ID Token when authorization is needed. + // The implementation should handle the OIDC login flow (e.g., browser redirect, + // callback handling) and return the ID token. + // REQUIRED. + IDTokenFetcher IDTokenFetcher + + // HTTPClient is an optional HTTP client for customization. + // If nil, http.DefaultClient is used. + // OPTIONAL. + HTTPClient *http.Client +} + +// EnterpriseHandler is an implementation of [auth.OAuthHandler] that uses +// Enterprise Managed Authorization (SEP-990) to obtain access tokens. +// +// The flow consists of: +// 1. OIDC Login: User authenticates with enterprise IdP → ID Token +// 2. Token Exchange (RFC 8693): ID Token → ID-JAG at IdP +// 3. JWT Bearer Grant (RFC 7523): ID-JAG → Access Token at MCP Server +type EnterpriseHandler struct { + config *EnterpriseHandlerConfig + + // tokenSource is the token source obtained after authorization. + tokenSource oauth2.TokenSource +} + +// Compile-time check that EnterpriseHandler implements auth.OAuthHandler. +var _ auth.OAuthHandler = (*EnterpriseHandler)(nil) + +// NewEnterpriseHandler creates a new EnterpriseHandler. +// It performs validation of the configuration and returns an error if invalid. +func NewEnterpriseHandler(config *EnterpriseHandlerConfig) (*EnterpriseHandler, error) { + if config == nil { + return nil, errors.New("config must be provided") + } + if config.IdPIssuerURL == "" { + return nil, errors.New("IdPIssuerURL is required") + } + if config.IdPCredentials == nil { + return nil, errors.New("IdPCredentials is required") + } + if err := config.IdPCredentials.Validate(); err != nil { + return nil, fmt.Errorf("invalid IdPCredentials: %w", err) + } + if config.MCPAuthServerURL == "" { + return nil, errors.New("MCPAuthServerURL is required") + } + if config.MCPResourceURI == "" { + return nil, errors.New("MCPResourceURI is required") + } + if config.MCPCredentials == nil { + return nil, errors.New("MCPCredentials is required") + } + if err := config.MCPCredentials.Validate(); err != nil { + return nil, fmt.Errorf("invalid MCPCredentials: %w", err) + } + if config.IDTokenFetcher == nil { + return nil, errors.New("IDTokenFetcher is required") + } + return &EnterpriseHandler{config: config}, nil +} + +// TokenSource returns the token source for outgoing requests. +// Returns nil if authorization has not been performed yet. +func (h *EnterpriseHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return h.tokenSource, nil +} + +// Authorize performs the Enterprise Managed Authorization flow. +// It is called when a request fails with 401 or 403. +func (h *EnterpriseHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + defer resp.Body.Close() + defer io.Copy(io.Discard, resp.Body) + + httpClient := h.config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + + // Step 1: Get ID Token via the configured fetcher (e.g., OIDC login) + oidcToken, err := h.config.IDTokenFetcher(ctx) + if err != nil { + return fmt.Errorf("failed to obtain ID token: %w", err) + } + + // Extract ID token from the oauth2.Token + idToken, ok := oidcToken.Extra("id_token").(string) + if !ok || idToken == "" { + return fmt.Errorf("id_token not found in OIDC token response") + } + + // Step 2: Discover IdP token endpoint via OIDC discovery + idpMeta, err := auth.GetAuthServerMetadata(ctx, h.config.IdPIssuerURL, httpClient) + if err != nil { + return fmt.Errorf("failed to discover IdP metadata: %w", err) + } + if idpMeta == nil { + return fmt.Errorf("no authorization server metadata found for IdP: %s", h.config.IdPIssuerURL) + } + + // Step 3: Token Exchange (ID Token → ID-JAG) + tokenExchangeReq := &oauthex.TokenExchangeRequest{ + RequestedTokenType: oauthex.TokenTypeIDJAG, + Audience: h.config.MCPAuthServerURL, + Resource: h.config.MCPResourceURI, + Scope: h.config.MCPScopes, + SubjectToken: idToken, + SubjectTokenType: oauthex.TokenTypeIDToken, + } + + idJAGToken, err := oauthex.ExchangeToken( + ctx, + idpMeta.TokenEndpoint, + tokenExchangeReq, + h.config.IdPCredentials, + httpClient, + ) + if err != nil { + return fmt.Errorf("token exchange failed: %w", err) + } + + // Step 4: Discover MCP Server token endpoint + mcpMeta, err := auth.GetAuthServerMetadata(ctx, h.config.MCPAuthServerURL, httpClient) + if err != nil { + return fmt.Errorf("failed to discover MCP auth server metadata: %w", err) + } + if mcpMeta == nil { + return fmt.Errorf("no authorization server metadata found for MCP server: %s", h.config.MCPAuthServerURL) + } + + // Step 5: JWT Bearer Grant (ID-JAG → Access Token) + // The ID-JAG is in the AccessToken field of the token (despite the name) + accessToken, err := exchangeJWTBearer( + ctx, + mcpMeta.TokenEndpoint, + idJAGToken.AccessToken, + h.config.MCPCredentials, + httpClient, + ) + if err != nil { + return fmt.Errorf("JWT bearer grant failed: %w", err) + } + + // Store the token source for subsequent requests + h.tokenSource = oauth2.StaticTokenSource(accessToken) + return nil +} + +// exchangeJWTBearer exchanges an Identity Assertion JWT Authorization Grant (ID-JAG) +// for an access token using JWT Bearer Grant per RFC 7523. +func exchangeJWTBearer( + ctx context.Context, + tokenEndpoint string, + assertion string, + clientCreds *oauthex.ClientCredentials, + httpClient *http.Client, +) (*oauth2.Token, error) { + cfg := &oauth2.Config{ + ClientID: clientCreds.ClientID, + Endpoint: oauth2.Endpoint{ + TokenURL: tokenEndpoint, + AuthStyle: oauth2.AuthStyleInParams, + }, + } + // Set ClientSecret if ClientSecretAuth is configured + if clientCreds.ClientSecretAuth != nil { + cfg.ClientSecret = clientCreds.ClientSecretAuth.ClientSecret + } + + if httpClient == nil { + httpClient = http.DefaultClient + } + ctxWithClient := context.WithValue(ctx, oauth2.HTTPClient, httpClient) + + token, err := cfg.Exchange( + ctxWithClient, + "", + oauth2.SetAuthURLParam("grant_type", grantTypeJWTBearer), + oauth2.SetAuthURLParam("assertion", assertion), + ) + if err != nil { + return nil, fmt.Errorf("JWT bearer grant request failed: %w", err) + } + + return token, nil +} diff --git a/auth/extauth/enterprise_handler_test.go b/auth/extauth/enterprise_handler_test.go new file mode 100644 index 00000000..08c042b1 --- /dev/null +++ b/auth/extauth/enterprise_handler_test.go @@ -0,0 +1,497 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package extauth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +// TestNewEnterpriseHandler_Validation tests validation in NewEnterpriseHandler. +func TestNewEnterpriseHandler_Validation(t *testing.T) { + validConfig := &EnterpriseHandlerConfig{ + IdPIssuerURL: "https://idp.example.com", + IdPCredentials: &oauthex.ClientCredentials{ + ClientID: "idp_client_id", + }, + MCPAuthServerURL: "https://mcp-auth.example.com", + MCPResourceURI: "https://mcp.example.com", + MCPCredentials: &oauthex.ClientCredentials{ + ClientID: "mcp_client_id", + }, + IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { + token := &oauth2.Token{ + AccessToken: "mock_access_token", + TokenType: "Bearer", + } + return token.WithExtra(map[string]interface{}{"id_token": "mock_id_token"}), nil + }, + } + + tests := []struct { + name string + config *EnterpriseHandlerConfig + wantError string + }{ + { + name: "nil config", + config: nil, + wantError: "config must be provided", + }, + { + name: "missing IdPIssuerURL", + config: &EnterpriseHandlerConfig{ + IdPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, + MCPAuthServerURL: "https://mcp-auth.example.com", + MCPResourceURI: "https://mcp.example.com", + MCPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, + IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, + }, + wantError: "IdPIssuerURL is required", + }, + { + name: "nil IdPCredentials", + config: &EnterpriseHandlerConfig{ + IdPIssuerURL: "https://idp.example.com", + IdPCredentials: nil, + MCPAuthServerURL: "https://mcp-auth.example.com", + MCPResourceURI: "https://mcp.example.com", + MCPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, + IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, + }, + wantError: "IdPCredentials is required", + }, + { + name: "invalid IdPCredentials - empty ClientID", + config: &EnterpriseHandlerConfig{ + IdPIssuerURL: "https://idp.example.com", + IdPCredentials: &oauthex.ClientCredentials{ + ClientID: "", // Invalid - empty + }, + MCPAuthServerURL: "https://mcp-auth.example.com", + MCPResourceURI: "https://mcp.example.com", + MCPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, + IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, + }, + wantError: "invalid IdPCredentials", + }, + { + name: "invalid IdPCredentials - empty ClientSecret in ClientSecretAuth", + config: &EnterpriseHandlerConfig{ + IdPIssuerURL: "https://idp.example.com", + IdPCredentials: &oauthex.ClientCredentials{ + ClientID: "idp_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "", // Invalid - empty secret when ClientSecretAuth is set + }, + }, + MCPAuthServerURL: "https://mcp-auth.example.com", + MCPResourceURI: "https://mcp.example.com", + MCPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, + IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, + }, + wantError: "invalid IdPCredentials", + }, + { + name: "missing MCPAuthServerURL", + config: &EnterpriseHandlerConfig{ + IdPIssuerURL: "https://idp.example.com", + IdPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, + MCPAuthServerURL: "", + MCPResourceURI: "https://mcp.example.com", + MCPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, + IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, + }, + wantError: "MCPAuthServerURL is required", + }, + { + name: "missing MCPResourceURI", + config: &EnterpriseHandlerConfig{ + IdPIssuerURL: "https://idp.example.com", + IdPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, + MCPAuthServerURL: "https://mcp-auth.example.com", + MCPResourceURI: "", + MCPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, + IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, + }, + wantError: "MCPResourceURI is required", + }, + { + name: "nil MCPCredentials", + config: &EnterpriseHandlerConfig{ + IdPIssuerURL: "https://idp.example.com", + IdPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, + MCPAuthServerURL: "https://mcp-auth.example.com", + MCPResourceURI: "https://mcp.example.com", + MCPCredentials: nil, + IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, + }, + wantError: "MCPCredentials is required", + }, + { + name: "invalid MCPCredentials - empty ClientID", + config: &EnterpriseHandlerConfig{ + IdPIssuerURL: "https://idp.example.com", + IdPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, + MCPAuthServerURL: "https://mcp-auth.example.com", + MCPResourceURI: "https://mcp.example.com", + MCPCredentials: &oauthex.ClientCredentials{ + ClientID: "", // Invalid - empty + }, + IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, + }, + wantError: "invalid MCPCredentials", + }, + { + name: "missing IDTokenFetcher", + config: &EnterpriseHandlerConfig{ + IdPIssuerURL: "https://idp.example.com", + IdPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, + MCPAuthServerURL: "https://mcp-auth.example.com", + MCPResourceURI: "https://mcp.example.com", + MCPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, + IDTokenFetcher: nil, + }, + wantError: "IDTokenFetcher is required", + }, + { + name: "valid config - public clients (no ClientSecretAuth)", + config: validConfig, + wantError: "", + }, + { + name: "valid config - confidential clients (with ClientSecretAuth)", + config: &EnterpriseHandlerConfig{ + IdPIssuerURL: "https://idp.example.com", + IdPCredentials: &oauthex.ClientCredentials{ + ClientID: "idp_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "idp_secret", + }, + }, + MCPAuthServerURL: "https://mcp-auth.example.com", + MCPResourceURI: "https://mcp.example.com", + MCPCredentials: &oauthex.ClientCredentials{ + ClientID: "mcp_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "mcp_secret", + }, + }, + IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { + token := &oauth2.Token{ + AccessToken: "mock_access_token", + TokenType: "Bearer", + } + return token.WithExtra(map[string]interface{}{"id_token": "mock_id_token"}), nil + }, + }, + wantError: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler, err := NewEnterpriseHandler(tt.config) + if tt.wantError != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantError) + } + if !strings.Contains(err.Error(), tt.wantError) { + t.Fatalf("expected error containing %q, got %v", tt.wantError, err) + } + } else { + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if handler == nil { + t.Fatal("expected handler to be non-nil") + } + } + }) + } +} + +// TestEnterpriseHandler_Authorize_E2E tests the complete enterprise authorization flow. +func TestEnterpriseHandler_Authorize_E2E(t *testing.T) { + // Set up IdP (Identity Provider) fake server with token exchange support + idpServer := setupIdPServer(t) + + // Set up MCP authorization server with JWT bearer grant support + mcpAuthServer := setupMCPAuthServer(t) + + // Create enterprise handler + handler, err := NewEnterpriseHandler(&EnterpriseHandlerConfig{ + IdPIssuerURL: idpServer.URL, + IdPCredentials: &oauthex.ClientCredentials{ + ClientID: "idp_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "idp_secret", + }, + }, + MCPAuthServerURL: mcpAuthServer.URL, + MCPResourceURI: "https://mcp.example.com", + MCPCredentials: &oauthex.ClientCredentials{ + ClientID: "mcp_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "mcp_secret", + }, + }, + MCPScopes: []string{"read", "write"}, + IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { + token := &oauth2.Token{ + AccessToken: "mock_access_token", + TokenType: "Bearer", + } + return token.WithExtra(map[string]interface{}{"id_token": "mock_id_token_from_user_login"}), nil + }, + }) + if err != nil { + t.Fatalf("NewEnterpriseHandler failed: %v", err) + } + + // Simulate a 401 response from MCP server + req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/api", nil) + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: make(http.Header), + Body: http.NoBody, + Request: req, + } + + // Perform authorization + if err := handler.Authorize(context.Background(), req, resp); err != nil { + t.Fatalf("Authorize failed: %v", err) + } + + // Verify token source is set + tokenSource, err := handler.TokenSource(context.Background()) + if err != nil { + t.Fatalf("TokenSource failed: %v", err) + } + if tokenSource == nil { + t.Fatal("expected token source to be set after authorization") + } + + // Verify we can get a token + token, err := tokenSource.Token() + if err != nil { + t.Fatalf("Token() failed: %v", err) + } + if token.AccessToken != "mcp_access_token_from_jwt_bearer" { + t.Errorf("unexpected access token: got %q, want %q", + token.AccessToken, "mcp_access_token_from_jwt_bearer") + } +} + +// setupIdPServer creates a fake IdP server that supports token exchange. +func setupIdPServer(t *testing.T) *httptest.Server { + t.Helper() + mux := http.NewServeMux() + + var server *httptest.Server + + // OAuth/OIDC metadata endpoint - uses closure to get server URL + mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": server.URL, + "token_endpoint": server.URL + "/token", + "authorization_endpoint": server.URL + "/authorize", + "code_challenge_methods_supported": []string{"S256"}, + }) + }) + + // Token endpoint - supports token exchange (RFC 8693) + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + + grantType := r.Form.Get("grant_type") + if grantType != oauthex.GrantTypeTokenExchange { + http.Error(w, fmt.Sprintf("unsupported grant_type: %s", grantType), http.StatusBadRequest) + return + } + + // Verify client authentication + clientID := r.Form.Get("client_id") + clientSecret := r.Form.Get("client_secret") + if clientID != "idp_client_id" || clientSecret != "idp_secret" { + http.Error(w, "invalid client credentials", http.StatusUnauthorized) + return + } + + // Verify token exchange parameters + if r.Form.Get("requested_token_type") != oauthex.TokenTypeIDJAG { + http.Error(w, "invalid requested_token_type", http.StatusBadRequest) + return + } + if r.Form.Get("subject_token_type") != oauthex.TokenTypeIDToken { + http.Error(w, "invalid subject_token_type", http.StatusBadRequest) + return + } + if r.Form.Get("subject_token") == "" { + http.Error(w, "missing subject_token", http.StatusBadRequest) + return + } + + // Return ID-JAG (Identity Assertion JWT Authorization Grant) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "id-jag-token-from-idp", + "issued_token_type": oauthex.TokenTypeIDJAG, + "token_type": "N_A", + "expires_in": 300, + "scope": "read write", + }) + }) + + server = httptest.NewServer(mux) + t.Cleanup(server.Close) + + return server +} + +// setupMCPAuthServer creates a fake MCP authorization server that supports JWT bearer grant. +func setupMCPAuthServer(t *testing.T) *httptest.Server { + t.Helper() + mux := http.NewServeMux() + + var server *httptest.Server + + // OAuth metadata endpoint - uses closure to get server URL + mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": server.URL, + "token_endpoint": server.URL + "/token", + "code_challenge_methods_supported": []string{"S256"}, + }) + }) + + // Token endpoint - supports JWT bearer grant (RFC 7523) + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + + grantType := r.Form.Get("grant_type") + if grantType != "urn:ietf:params:oauth:grant-type:jwt-bearer" { + http.Error(w, fmt.Sprintf("unsupported grant_type: %s", grantType), http.StatusBadRequest) + return + } + + // Verify client authentication + clientID := r.Form.Get("client_id") + clientSecret := r.Form.Get("client_secret") + if clientID != "mcp_client_id" || clientSecret != "mcp_secret" { + http.Error(w, "invalid client credentials", http.StatusUnauthorized) + return + } + + // Verify assertion (ID-JAG) + assertion := r.Form.Get("assertion") + if assertion != "id-jag-token-from-idp" { + http.Error(w, "invalid assertion", http.StatusBadRequest) + return + } + + // Return access token + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "mcp_access_token_from_jwt_bearer", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "read write", + }) + }) + + server = httptest.NewServer(mux) + t.Cleanup(server.Close) + + return server +} + +// TestEnterpriseHandler_Authorize_IDTokenFetcherError tests error handling when IDTokenFetcher fails. +func TestEnterpriseHandler_Authorize_IDTokenFetcherError(t *testing.T) { + handler, err := NewEnterpriseHandler(&EnterpriseHandlerConfig{ + IdPIssuerURL: "https://idp.example.com", + IdPCredentials: &oauthex.ClientCredentials{ + ClientID: "idp_client_id", + }, + MCPAuthServerURL: "https://mcp-auth.example.com", + MCPResourceURI: "https://mcp.example.com", + MCPCredentials: &oauthex.ClientCredentials{ + ClientID: "mcp_client_id", + }, + IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { + return nil, fmt.Errorf("user cancelled login") + }, + }) + if err != nil { + t.Fatalf("NewEnterpriseHandler failed: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/api", nil) + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: make(http.Header), + Body: http.NoBody, + Request: req, + } + + err = handler.Authorize(context.Background(), req, resp) + if err == nil { + t.Fatal("expected error from Authorize, got nil") + } + if !strings.Contains(err.Error(), "failed to obtain ID token") { + t.Errorf("expected error about ID token, got: %v", err) + } +} + +// TestEnterpriseHandler_TokenSource_BeforeAuthorization tests TokenSource before authorization. +func TestEnterpriseHandler_TokenSource_BeforeAuthorization(t *testing.T) { + handler, err := NewEnterpriseHandler(&EnterpriseHandlerConfig{ + IdPIssuerURL: "https://idp.example.com", + IdPCredentials: &oauthex.ClientCredentials{ + ClientID: "idp_client_id", + }, + MCPAuthServerURL: "https://mcp-auth.example.com", + MCPResourceURI: "https://mcp.example.com", + MCPCredentials: &oauthex.ClientCredentials{ + ClientID: "mcp_client_id", + }, + IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { + token := &oauth2.Token{ + AccessToken: "mock_access_token", + TokenType: "Bearer", + } + return token.WithExtra(map[string]interface{}{"id_token": "mock_id_token"}), nil + }, + }) + if err != nil { + t.Fatalf("NewEnterpriseHandler failed: %v", err) + } + + tokenSource, err := handler.TokenSource(context.Background()) + if err != nil { + t.Fatalf("TokenSource failed: %v", err) + } + if tokenSource != nil { + t.Errorf("expected nil token source before authorization, got %v", tokenSource) + } +} diff --git a/auth/extauth/oidc_login.go b/auth/extauth/oidc_login.go new file mode 100644 index 00000000..0bf8e6f6 --- /dev/null +++ b/auth/extauth/oidc_login.go @@ -0,0 +1,221 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements OIDC Authorization Code flow for obtaining ID tokens +// as part of Enterprise Managed Authorization (SEP-990). +// See https://openid.net/specs/openid-connect-core-1_0.html + +//go:build mcp_go_client_oauth + +package extauth + +import ( + "context" + "crypto/rand" + "fmt" + "net/http" + "slices" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +// OIDCLoginConfig configures the OIDC Authorization Code flow for obtaining +// an ID Token. This is used with [PerformOIDCLogin] to authenticate users +// with an enterprise IdP before calling the Enterprise Managed Authorization flow. +type OIDCLoginConfig struct { + // IssuerURL is the IdP's issuer URL (e.g., "https://acme.okta.com"). + // REQUIRED. + IssuerURL string + // Credentials contains the MCP Client's credentials registered at the IdP. + // The ClientID field is REQUIRED. The ClientSecret field is OPTIONAL + // (only required if the client is confidential, not a public client). + // REQUIRED (struct itself), but ClientSecret field can be empty. + Credentials *oauthex.ClientCredentials + // RedirectURL is the OAuth2 redirect URI registered with the IdP. + // This must match exactly what was registered with the IdP. + // REQUIRED. + RedirectURL string + // Scopes are the OAuth2/OIDC scopes to request. + // "openid" is REQUIRED for OIDC. Common values: ["openid", "profile", "email"] + // REQUIRED. + Scopes []string + // LoginHint is a hint to the IdP about the user's identity. + // Some IdPs may require this (e.g., as an email address for routing to SSO providers). + // Example: "user@example.com" + // OPTIONAL. + LoginHint string + // HTTPClient is the HTTP client for making requests. + // If nil, http.DefaultClient is used. + // OPTIONAL. + HTTPClient *http.Client +} + +// PerformOIDCLogin performs the complete OIDC Authorization Code flow with PKCE +// in a single function call. This is the recommended approach for obtaining an +// ID Token for use with [EnterpriseHandler]. +// +// Returns an oauth2.Token where: +// - Extra("id_token") contains the OpenID Connect ID Token (JWT) +// - AccessToken contains the OAuth2 access token (if issued by IdP) +// - RefreshToken contains the OAuth2 refresh token (if issued by IdP) +// - TokenType is the token type (typically "Bearer") +// - Expiry is when the token expires +func PerformOIDCLogin( + ctx context.Context, + config *OIDCLoginConfig, + authCodeFetcher auth.AuthorizationCodeFetcher, +) (*oauth2.Token, error) { + if authCodeFetcher == nil { + return nil, fmt.Errorf("authCodeFetcher is required") + } + + authReq, oauth2Config, err := initiateOIDCLogin(ctx, config) + if err != nil { + return nil, fmt.Errorf("failed to initiate OIDC login: %w", err) + } + + authResult, err := authCodeFetcher(ctx, &auth.AuthorizationArgs{URL: authReq.authURL}) + if err != nil { + return nil, fmt.Errorf("failed to fetch authorization code: %w", err) + } + + if authResult.State != authReq.state { + return nil, fmt.Errorf("state mismatch: expected %q, got %q", authReq.state, authResult.State) + } + + tokens, err := completeOIDCLogin(ctx, config, oauth2Config, authResult.Code, authReq.codeVerifier) + if err != nil { + return nil, fmt.Errorf("failed to complete OIDC login: %w", err) + } + + return tokens, nil +} + +// oidcAuthorizationRequest holds internal state for OIDC authorization. +type oidcAuthorizationRequest struct { + authURL string + state string + codeVerifier string +} + +// initiateOIDCLogin initiates an OIDC Authorization Code flow with PKCE. +func initiateOIDCLogin( + ctx context.Context, + config *OIDCLoginConfig, +) (*oidcAuthorizationRequest, *oauth2.Config, error) { + if config == nil { + return nil, nil, fmt.Errorf("config is required") + } + if config.IssuerURL == "" { + return nil, nil, fmt.Errorf("IssuerURL is required") + } + if config.Credentials == nil || config.Credentials.ClientID == "" { + return nil, nil, fmt.Errorf("Credentials.ClientID is required") + } + if config.RedirectURL == "" { + return nil, nil, fmt.Errorf("RedirectURL is required") + } + if len(config.Scopes) == 0 { + return nil, nil, fmt.Errorf("Scopes is required (must include 'openid')") + } + + if !slices.Contains(config.Scopes, "openid") { + return nil, nil, fmt.Errorf("Scopes must include 'openid' for OIDC") + } + + if err := oauthex.CheckURLScheme(config.IssuerURL); err != nil { + return nil, nil, fmt.Errorf("invalid IssuerURL: %w", err) + } + if err := oauthex.CheckURLScheme(config.RedirectURL); err != nil { + return nil, nil, fmt.Errorf("invalid RedirectURL: %w", err) + } + + httpClient := config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + meta, err := auth.GetAuthServerMetadata(ctx, config.IssuerURL, httpClient) + if err != nil { + return nil, nil, fmt.Errorf("failed to discover OIDC metadata: %w", err) + } + if meta == nil { + return nil, nil, fmt.Errorf("no authorization server metadata found for OIDC issuer: %s", config.IssuerURL) + } + if meta.AuthorizationEndpoint == "" { + return nil, nil, fmt.Errorf("authorization_endpoint not found in OIDC metadata") + } + + codeVerifier := oauth2.GenerateVerifier() + state := rand.Text() + + oauth2Config := &oauth2.Config{ + ClientID: config.Credentials.ClientID, + RedirectURL: config.RedirectURL, + Scopes: config.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: meta.AuthorizationEndpoint, + TokenURL: meta.TokenEndpoint, + }, + } + // Set ClientSecret if ClientSecretAuth is configured + if config.Credentials.ClientSecretAuth != nil { + oauth2Config.ClientSecret = config.Credentials.ClientSecretAuth.ClientSecret + } + + authURLOpts := []oauth2.AuthCodeOption{ + oauth2.S256ChallengeOption(codeVerifier), + } + if config.LoginHint != "" { + authURLOpts = append(authURLOpts, oauth2.SetAuthURLParam("login_hint", config.LoginHint)) + } + authURL := oauth2Config.AuthCodeURL(state, authURLOpts...) + + return &oidcAuthorizationRequest{ + authURL: authURL, + state: state, + codeVerifier: codeVerifier, + }, oauth2Config, nil +} + +// completeOIDCLogin completes the OIDC Authorization Code flow by exchanging +// the authorization code for tokens. +func completeOIDCLogin( + ctx context.Context, + config *OIDCLoginConfig, + oauth2Config *oauth2.Config, + authCode string, + codeVerifier string, +) (*oauth2.Token, error) { + if authCode == "" { + return nil, fmt.Errorf("authCode is required") + } + if codeVerifier == "" { + return nil, fmt.Errorf("codeVerifier is required") + } + + httpClient := config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + ctxWithClient := context.WithValue(ctx, oauth2.HTTPClient, httpClient) + + oauth2Token, err := oauth2Config.Exchange( + ctxWithClient, + authCode, + oauth2.VerifierOption(codeVerifier), + ) + if err != nil { + return nil, fmt.Errorf("token exchange failed: %w", err) + } + + // Validate that id_token is present in the response + idToken, ok := oauth2Token.Extra("id_token").(string) + if !ok || idToken == "" { + return nil, fmt.Errorf("id_token not found in token response") + } + + return oauth2Token, nil +} diff --git a/auth/extauth/oidc_login_test.go b/auth/extauth/oidc_login_test.go new file mode 100644 index 00000000..298158ba --- /dev/null +++ b/auth/extauth/oidc_login_test.go @@ -0,0 +1,505 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package extauth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +// TestInitiateOIDCLogin tests the OIDC authorization request generation. +func TestInitiateOIDCLogin(t *testing.T) { + // Create mock IdP server + idpServer := createMockOIDCServer(t) + defer idpServer.Close() + config := &OIDCLoginConfig{ + IssuerURL: idpServer.URL, + Credentials: &oauthex.ClientCredentials{ + ClientID: "test-client", + }, + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{"openid", "profile", "email"}, + HTTPClient: idpServer.Client(), + } + t.Run("successful initiation", func(t *testing.T) { + authReq, _, err := initiateOIDCLogin(context.Background(), config) + if err != nil { + t.Fatalf("initiateOIDCLogin failed: %v", err) + } + // Validate authURL + if authReq.authURL == "" { + t.Error("authURL is empty") + } + // Parse and validate URL parameters + u, err := url.Parse(authReq.authURL) + if err != nil { + t.Fatalf("Failed to parse authURL: %v", err) + } + q := u.Query() + if q.Get("response_type") != "code" { + t.Errorf("expected response_type 'code', got '%s'", q.Get("response_type")) + } + if q.Get("client_id") != "test-client" { + t.Errorf("expected client_id 'test-client', got '%s'", q.Get("client_id")) + } + if q.Get("redirect_uri") != "http://localhost:8080/callback" { + t.Errorf("expected redirect_uri 'http://localhost:8080/callback', got '%s'", q.Get("redirect_uri")) + } + if q.Get("scope") != "openid profile email" { + t.Errorf("expected scope 'openid profile email', got '%s'", q.Get("scope")) + } + if q.Get("code_challenge_method") != "S256" { + t.Errorf("expected code_challenge_method 'S256', got '%s'", q.Get("code_challenge_method")) + } + // Validate state is generated + if authReq.state == "" { + t.Error("state is empty") + } + if q.Get("state") != authReq.state { + t.Errorf("state in URL doesn't match returned state") + } + // Validate PKCE parameters + if authReq.codeVerifier == "" { + t.Error("codeVerifier is empty") + } + if q.Get("code_challenge") == "" { + t.Error("code_challenge is empty") + } + }) + t.Run("with login_hint", func(t *testing.T) { + configWithHint := *config + configWithHint.LoginHint = "user@example.com" + authReq, _, err := initiateOIDCLogin(context.Background(), &configWithHint) + if err != nil { + t.Fatalf("initiateOIDCLogin failed: %v", err) + } + u, err := url.Parse(authReq.authURL) + if err != nil { + t.Fatalf("Failed to parse authURL: %v", err) + } + q := u.Query() + if q.Get("login_hint") != "user@example.com" { + t.Errorf("expected login_hint 'user@example.com', got '%s'", q.Get("login_hint")) + } + }) + t.Run("without login_hint", func(t *testing.T) { + authReq, _, err := initiateOIDCLogin(context.Background(), config) + if err != nil { + t.Fatalf("initiateOIDCLogin failed: %v", err) + } + u, err := url.Parse(authReq.authURL) + if err != nil { + t.Fatalf("Failed to parse authURL: %v", err) + } + q := u.Query() + if q.Has("login_hint") { + t.Errorf("expected no login_hint parameter, but got '%s'", q.Get("login_hint")) + } + }) + t.Run("nil config", func(t *testing.T) { + _, _, err := initiateOIDCLogin(context.Background(), nil) + if err == nil { + t.Error("expected error for nil config, got nil") + } + }) + t.Run("missing openid scope", func(t *testing.T) { + badConfig := *config + badConfig.Scopes = []string{"profile", "email"} // Missing "openid" + _, _, err := initiateOIDCLogin(context.Background(), &badConfig) + if err == nil { + t.Error("expected error for missing openid scope, got nil") + } + if !strings.Contains(err.Error(), "openid") { + t.Errorf("expected error about missing 'openid', got: %v", err) + } + }) + t.Run("missing required fields", func(t *testing.T) { + tests := []struct { + name string + mutate func(*OIDCLoginConfig) + expectErr string + }{ + { + name: "missing IssuerURL", + mutate: func(c *OIDCLoginConfig) { c.IssuerURL = "" }, + expectErr: "IssuerURL is required", + }, + { + name: "missing ClientID", + mutate: func(c *OIDCLoginConfig) { c.Credentials.ClientID = "" }, + expectErr: "ClientID is required", + }, + { + name: "missing RedirectURL", + mutate: func(c *OIDCLoginConfig) { + c.RedirectURL = "" + // Ensure ClientID is present to test RedirectURL validation + c.Credentials = &oauthex.ClientCredentials{ClientID: "test"} + }, + expectErr: "RedirectURL is required", + }, + { + name: "missing Scopes", + mutate: func(c *OIDCLoginConfig) { + c.Scopes = nil + // Ensure required fields are present to test Scopes validation + c.Credentials = &oauthex.ClientCredentials{ClientID: "test"} + c.RedirectURL = "http://localhost:8080/callback" + }, + expectErr: "Scopes is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + badConfig := *config + tt.mutate(&badConfig) + _, _, err := initiateOIDCLogin(context.Background(), &badConfig) + if err == nil { + t.Error("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.expectErr) { + t.Errorf("expected error containing '%s', got: %v", tt.expectErr, err) + } + }) + } + }) +} + +// TestCompleteOIDCLogin tests the authorization code exchange. +func TestCompleteOIDCLogin(t *testing.T) { + // Create mock IdP server + idpServer := createMockOIDCServerWithToken(t) + defer idpServer.Close() + config := &OIDCLoginConfig{ + IssuerURL: idpServer.URL, + Credentials: &oauthex.ClientCredentials{ + ClientID: "test-client", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "test-secret", + }, + }, + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{"openid", "profile", "email"}, + HTTPClient: idpServer.Client(), + } + t.Run("successful code exchange", func(t *testing.T) { + // First initiate to get oauth2Config + _, oauth2Config, err := initiateOIDCLogin(context.Background(), config) + if err != nil { + t.Fatalf("initiateOIDCLogin failed: %v", err) + } + + token, err := completeOIDCLogin( + context.Background(), + config, + oauth2Config, + "test-auth-code", + "test-code-verifier", + ) + if err != nil { + t.Fatalf("completeOIDCLogin failed: %v", err) + } + // Validate tokens + idToken, ok := token.Extra("id_token").(string) + if !ok || idToken == "" { + t.Error("id_token is missing or empty") + } + if token.AccessToken == "" { + t.Error("AccessToken is empty") + } + if token.TokenType != "Bearer" { + t.Errorf("expected TokenType 'Bearer', got '%s'", token.TokenType) + } + if token.Expiry.IsZero() { + t.Error("Expiry is zero") + } + }) + t.Run("missing parameters", func(t *testing.T) { + _, oauth2Config, _ := initiateOIDCLogin(context.Background(), config) + + tests := []struct { + name string + authCode string + codeVerifier string + expectErr string + }{ + { + name: "missing authCode", + authCode: "", + codeVerifier: "test-verifier", + expectErr: "authCode is required", + }, + { + name: "missing codeVerifier", + authCode: "test-code", + codeVerifier: "", + expectErr: "codeVerifier is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := completeOIDCLogin( + context.Background(), + config, + oauth2Config, + tt.authCode, + tt.codeVerifier, + ) + if err == nil { + t.Error("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.expectErr) { + t.Errorf("expected error containing '%s', got: %v", tt.expectErr, err) + } + }) + } + }) +} + +// TestOIDCLoginE2E tests the complete OIDC login flow end-to-end. +func TestOIDCLoginE2E(t *testing.T) { + // Create mock IdP server + idpServer := createMockOIDCServerWithToken(t) + defer idpServer.Close() + config := &OIDCLoginConfig{ + IssuerURL: idpServer.URL, + Credentials: &oauthex.ClientCredentials{ + ClientID: "test-client", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "test-secret", + }, + }, + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{"openid", "profile", "email"}, + HTTPClient: idpServer.Client(), + } + // Step 1: Initiate login + authReq, oauth2Config, err := initiateOIDCLogin(context.Background(), config) + if err != nil { + t.Fatalf("initiateOIDCLogin failed: %v", err) + } + // Step 2: Simulate user authentication and redirect + // (In real flow, user would visit authReq.authURL and IdP would redirect back) + // Here we just use a mock authorization code + mockAuthCode := "mock-authorization-code" + // Step 3: Complete login with authorization code + token, err := completeOIDCLogin( + context.Background(), + config, + oauth2Config, + mockAuthCode, + authReq.codeVerifier, + ) + if err != nil { + t.Fatalf("completeOIDCLogin failed: %v", err) + } + // Validate we got an ID token + idToken, ok := token.Extra("id_token").(string) + if !ok || idToken == "" { + t.Error("Expected ID token, got empty or missing") + } + // Validate ID token is a JWT (has 3 parts) + parts := strings.Split(idToken, ".") + if len(parts) != 3 { + t.Errorf("Expected JWT with 3 parts, got %d parts", len(parts)) + } +} + +// createMockOIDCServer creates a mock OIDC server for testing initiateOIDCLogin. +func createMockOIDCServer(t *testing.T) *httptest.Server { + var serverURL string + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle OIDC discovery + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": serverURL, + "authorization_endpoint": serverURL + "/authorize", + "token_endpoint": serverURL + "/token", + "jwks_uri": serverURL + "/.well-known/jwks.json", + "response_types_supported": []string{"code"}, + "code_challenge_methods_supported": []string{"S256"}, + "grant_types_supported": []string{"authorization_code"}, + }) + return + } + http.NotFound(w, r) + })) + serverURL = server.URL + return server +} + +// createMockOIDCServerWithToken creates a mock OIDC server that also handles token exchange. +func createMockOIDCServerWithToken(t *testing.T) *httptest.Server { + var serverURL string + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle OIDC discovery + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": serverURL, + "authorization_endpoint": serverURL + "/authorize", + "token_endpoint": serverURL + "/token", + "jwks_uri": serverURL + "/.well-known/jwks.json", + "response_types_supported": []string{"code"}, + "code_challenge_methods_supported": []string{"S256"}, + "grant_types_supported": []string{"authorization_code"}, + }) + return + } + // Handle token endpoint + if r.URL.Path == "/token" { + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + // Validate grant type + if r.FormValue("grant_type") != "authorization_code" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + // Create mock ID token (JWT) + now := time.Now().Unix() + idToken := fmt.Sprintf("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.%s.mock-signature", + base64EncodeClaims(map[string]interface{}{ + "iss": serverURL, + "sub": "test-user", + "aud": "test-client", + "exp": now + 3600, + "iat": now, + "email": "test@example.com", + })) + // Return token response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "mock-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "mock-refresh-token", + "id_token": idToken, + }) + return + } + http.NotFound(w, r) + })) + serverURL = server.URL + return server +} + +// base64EncodeClaims encodes JWT claims for testing. +func base64EncodeClaims(claims map[string]interface{}) string { + claimsJSON, _ := json.Marshal(claims) + return base64.RawURLEncoding.EncodeToString(claimsJSON) +} + +// TestPerformOIDCLogin tests the combined OIDC login flow with callback. +func TestPerformOIDCLogin(t *testing.T) { + // Create mock IdP server + idpServer := createMockOIDCServerWithToken(t) + defer idpServer.Close() + config := &OIDCLoginConfig{ + IssuerURL: idpServer.URL, + Credentials: &oauthex.ClientCredentials{ + ClientID: "test-client", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "test-secret", + }, + }, + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{"openid", "profile", "email"}, + HTTPClient: idpServer.Client(), + } + + t.Run("successful flow", func(t *testing.T) { + token, err := PerformOIDCLogin(context.Background(), config, + func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + // Validate authURL has required parameters + u, err := url.Parse(args.URL) + if err != nil { + return nil, fmt.Errorf("invalid authURL: %w", err) + } + q := u.Query() + if q.Get("response_type") != "code" { + return nil, fmt.Errorf("missing response_type") + } + if q.Get("state") == "" { + return nil, fmt.Errorf("missing state") + } + + // Simulate successful user authentication + return &auth.AuthorizationResult{ + Code: "mock-auth-code", + State: q.Get("state"), // Return the expected state from URL + }, nil + }) + + if err != nil { + t.Fatalf("PerformOIDCLogin failed: %v", err) + } + + idToken, ok := token.Extra("id_token").(string) + if !ok || idToken == "" { + t.Error("id_token is missing or empty") + } + if token.AccessToken == "" { + t.Error("AccessToken is empty") + } + }) + + t.Run("state mismatch", func(t *testing.T) { + _, err := PerformOIDCLogin(context.Background(), config, + func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + // Return wrong state to simulate CSRF attack + return &auth.AuthorizationResult{ + Code: "mock-auth-code", + State: "wrong-state", + }, nil + }) + + if err == nil { + t.Error("expected error for state mismatch, got nil") + } + if !strings.Contains(err.Error(), "state mismatch") { + t.Errorf("expected state mismatch error, got: %v", err) + } + }) + + t.Run("fetcher error", func(t *testing.T) { + _, err := PerformOIDCLogin(context.Background(), config, + func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + return nil, fmt.Errorf("user cancelled") + }) + + if err == nil { + t.Error("expected error, got nil") + } + if !strings.Contains(err.Error(), "user cancelled") { + t.Errorf("expected 'user cancelled' error, got: %v", err) + } + }) + + t.Run("nil fetcher", func(t *testing.T) { + _, err := PerformOIDCLogin(context.Background(), config, nil) + if err == nil { + t.Error("expected error for nil fetcher, got nil") + } + if !strings.Contains(err.Error(), "authCodeFetcher is required") { + t.Errorf("expected 'authCodeFetcher is required' error, got: %v", err) + } + }) +} diff --git a/auth/shared.go b/auth/shared.go new file mode 100644 index 00000000..55837896 --- /dev/null +++ b/auth/shared.go @@ -0,0 +1,79 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file contains shared utilities for OAuth handlers. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "net/http" + "net/url" + "strings" + + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +// GetAuthServerMetadata fetches authorization server metadata for the given issuer URL. +// It tries standard well-known endpoints (OAuth 2.0 and OIDC) and returns the first successful result. +// +// Returns (nil, nil) when no metadata endpoints respond (404s), allowing callers to implement +// fallback logic. Returns an error only for actual failures (network errors, invalid JSON, etc.). +func GetAuthServerMetadata(ctx context.Context, issuerURL string, httpClient *http.Client) (*oauthex.AuthServerMeta, error) { + var lastErr error + for _, metadataURL := range authorizationServerMetadataURLs(issuerURL) { + asm, err := oauthex.GetAuthServerMeta(ctx, metadataURL, issuerURL, httpClient) + if err != nil { + // Store the error but continue trying other endpoints + lastErr = err + continue + } + if asm != nil { + return asm, nil + } + } + // If we got actual errors (not just 404s), return the last error + // Otherwise return (nil, nil) to indicate no metadata found (fallback needed) + if lastErr != nil { + return nil, lastErr + } + return nil, nil +} + +// authorizationServerMetadataURLs returns a list of URLs to try when looking for +// authorization server metadata as mandated by the MCP specification: +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery. +func authorizationServerMetadataURLs(issuerURL string) []string { + var urls []string + + baseURL, err := url.Parse(issuerURL) + if err != nil { + return nil + } + + if baseURL.Path == "" { + // "OAuth 2.0 Authorization Server Metadata". + baseURL.Path = "/.well-known/oauth-authorization-server" + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0". + baseURL.Path = "/.well-known/openid-configuration" + urls = append(urls, baseURL.String()) + return urls + } + + originalPath := baseURL.Path + // "OAuth 2.0 Authorization Server Metadata with path insertion". + baseURL.Path = "/.well-known/oauth-authorization-server/" + strings.TrimLeft(originalPath, "/") + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0 with path insertion". + baseURL.Path = "/.well-known/openid-configuration/" + strings.TrimLeft(originalPath, "/") + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0 with path appending". + baseURL.Path = "/" + strings.Trim(originalPath, "/") + "/.well-known/openid-configuration" + urls = append(urls, baseURL.String()) + + return urls +} diff --git a/auth/shared_test.go b/auth/shared_test.go new file mode 100644 index 00000000..9f51d7ed --- /dev/null +++ b/auth/shared_test.go @@ -0,0 +1,101 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "net/http" + "testing" + + "github.com/modelcontextprotocol/go-sdk/internal/oauthtest" +) + +func TestGetAuthServerMetadata(t *testing.T) { + tests := []struct { + name string + issuerPath string + endpointConfig *oauthtest.MetadataEndpointConfig + wantNil bool + }{ + { + name: "OAuthEndpoint_Root", + issuerPath: "", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOAuthInsertedEndpoint: true, + }, + }, + { + name: "OpenIDEndpoint_Root", + issuerPath: "", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOpenIDInsertedEndpoint: true, + }, + }, + { + name: "OAuthEndpoint_Path", + issuerPath: "/oauth", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOAuthInsertedEndpoint: true, + }, + }, + { + name: "OpenIDEndpoint_Path", + issuerPath: "/openid", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOpenIDInsertedEndpoint: true, + }, + }, + { + name: "OpenIDAppendedEndpoint_Path", + issuerPath: "/openid", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOpenIDAppendedEndpoint: true, + }, + }, + { + name: "NoMetadata", + issuerPath: "", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + // All metadata endpoints disabled. + ServeOAuthInsertedEndpoint: false, + ServeOpenIDInsertedEndpoint: false, + }, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{ + IssuerPath: tt.issuerPath, + MetadataEndpointConfig: tt.endpointConfig, + }) + s.Start(t) + issuerURL := s.URL() + tt.issuerPath + + got, err := GetAuthServerMetadata(t.Context(), issuerURL, http.DefaultClient) + if tt.wantNil { + // When no metadata is found, GetAuthServerMetadata returns (nil, nil). + if err != nil { + t.Fatalf("GetAuthServerMetadata() unexpected error = %v, want nil", err) + } + if got != nil { + t.Fatal("GetAuthServerMetadata() expected nil for no metadata, got metadata") + } + return + } + if err != nil { + t.Fatalf("GetAuthServerMetadata() error = %v, want nil", err) + } + if got == nil { + t.Fatal("GetAuthServerMetadata() got nil, want metadata") + } + if got.Issuer != issuerURL { + t.Errorf("GetAuthServerMetadata() issuer = %q, want %q", got.Issuer, issuerURL) + } + }) + } +} diff --git a/conformance/everything-client/client_private.go b/conformance/everything-client/client_private.go index 3b0c6592..6cddc958 100644 --- a/conformance/everything-client/client_private.go +++ b/conformance/everything-client/client_private.go @@ -98,9 +98,9 @@ func runAuthClient(ctx context.Context, serverURL string, configCtx map[string]a // Try pre-registered client information if provided in the context. if clientID, ok := configCtx["client_id"].(string); ok { if clientSecret, ok := configCtx["client_secret"].(string); ok { - authConfig.PreregisteredClientConfig = &auth.PreregisteredClientConfig{ - ClientSecretAuthConfig: &auth.ClientSecretAuthConfig{ - ClientID: clientID, + authConfig.PreregisteredClient = &oauthex.ClientCredentials{ + ClientID: clientID, + ClientSecretAuth: &oauthex.ClientSecretAuth{ ClientSecret: clientSecret, }, } diff --git a/docs/protocol.md b/docs/protocol.md index af92f931..72b901d9 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -354,8 +354,77 @@ client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, session, err := client.Connect(ctx, transport, nil) ``` -The `auth.AuthorizationCodeHandler` automatically manages token refreshing -and step-up authentication (when the server returns `insufficient_scope` error). +The `auth.AuthorizationCodeHandler` automatically manages token refreshing (if the server provides a refresh token) and step-up authentication (when the server returns `insufficient_scope` error). + +#### Enterprise Managed Authorization (SEP-990) + +For enterprise SSO scenarios where users authenticate with an enterprise Identity Provider (IdP), +the SDK provides +[`extauth.EnterpriseHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth/extauth#EnterpriseHandler), +an implementation of `OAuthHandler` that automates the Enterprise Managed Authorization flow: + +1. **OIDC Login**: User authenticates with enterprise IdP → ID Token +2. **Token Exchange** (RFC 8693): ID Token → ID-JAG at IdP +3. **JWT Bearer Grant** (RFC 7523): ID-JAG → Access Token at MCP Server + +To use enterprise managed authorization, create an `EnterpriseHandler` and assign it to your transport: + +```go +// Create ID token fetcher using OIDC login +idTokenFetcher := func(ctx context.Context) (*oauth2.Token, error) { + oidcConfig := &extauth.OIDCLoginConfig{ + IssuerURL: "https://company.okta.com", + Credentials: &oauthex.ClientCredentials{ + ClientID: "idp-client-id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "idp-client-secret", + }, + }, + RedirectURL: "http://localhost:3142", + Scopes: []string{"openid", "profile", "email"}, + } + + tokens, err := extauth.PerformOIDCLogin(ctx, oidcConfig, authCodeFetcher) + if err != nil { + return nil, err + } + + return tokens, nil +} + +// Create Enterprise Handler +enterpriseHandler, err := extauth.NewEnterpriseHandler(&extauth.EnterpriseHandlerConfig{ + IdPIssuerURL: "https://company.okta.com", + IdPCredentials: &oauthex.ClientCredentials{ + ClientID: "idp-client-id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "idp-client-secret", + }, + }, + MCPAuthServerURL: "https://auth.mcpserver.example", + MCPResourceURI: "https://mcp.mcpserver.example", + MCPCredentials: &oauthex.ClientCredentials{ + ClientID: "mcp-client-id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "mcp-client-secret", + }, + }, + MCPScopes: []string{"read", "write"}, + IDTokenFetcher: idTokenFetcher, +}) + +// Use with transport +transport := &mcp.StreamableClientTransport{ + Endpoint: "https://example.com/mcp", + OAuthHandler: enterpriseHandler, +} +client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) +session, err := client.Connect(ctx, transport, nil) +``` + +The `EnterpriseHandler` automatically manages the token exchange flow. Note that it intentionally does not support refresh tokens - when an access token expires, the entire authorization flow is repeated to ensure enterprise policies are consistently enforced. + +For a complete working example, see [examples/auth/enterprise](https://github.com/modelcontextprotocol/go-sdk/tree/main/examples/auth/enterprise). ## Security @@ -574,3 +643,4 @@ func Example_progress() { // frobbing widgets 2/2 } ``` + diff --git a/examples/auth/client/main.go b/examples/auth/client/main.go index b9ddf4d7..a245a8aa 100644 --- a/examples/auth/client/main.go +++ b/examples/auth/client/main.go @@ -87,10 +87,11 @@ func main() { RedirectURL: fmt.Sprintf("http://localhost:%d", *callbackPort), AuthorizationCodeFetcher: receiver.getAuthorizationCode, // Uncomment the client configuration you want to use. - // PreregisteredClientConfig: &auth.PreregisteredClientConfig{ - // ClientSecretAuthConfig: &auth.ClientSecretAuthConfig{ + // PreregisteredClient: &oauthex.ClientCredentials{ // ClientID: "", - // ClientSecret: "", + // ClientSecretAuth: &oauthex.ClientSecretAuth{ + // ClientSecret: "", + // }, // }, // }, // DynamicClientRegistrationConfig: &auth.DynamicClientRegistrationConfig{ diff --git a/examples/auth/enterprise/main.go b/examples/auth/enterprise/main.go new file mode 100644 index 00000000..4698f257 --- /dev/null +++ b/examples/auth/enterprise/main.go @@ -0,0 +1,229 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/auth/extauth" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +var ( + // IdP (Identity Provider) configuration. + idpIssuerURL = flag.String("idp_issuer", "https://your-idp.okta.com", "IdP issuer URL (e.g., https://your-company.okta.com)") + idpClientID = flag.String("idp_client_id", "", "Client ID registered at the IdP") + idpClientSecret = flag.String("idp_client_secret", "", "Client secret at the IdP (optional for public clients)") + + // MCP Server configuration. + mcpServerURL = flag.String("mcp_server", "http://localhost:8000/mcp", "URL of the MCP server") + mcpAuthServerURL = flag.String("mcp_auth_server", "https://auth.mcpserver.example", "MCP server's authorization server URL") + mcpResourceURI = flag.String("mcp_resource_uri", "https://mcp.mcpserver.example", "MCP server's resource identifier (RFC 9728)") + mcpClientID = flag.String("mcp_client_id", "", "Client ID at the MCP server (optional)") + mcpClientSecret = flag.String("mcp_client_secret", "", "Client secret at the MCP server (optional)") + + // OAuth callback configuration. + callbackPort = flag.Int("callback_port", 3142, "Port for the local HTTP server that will receive the OAuth callback") +) + +// codeReceiver handles the OAuth callback from the IdP's authorization endpoint. +// It starts a local HTTP server to receive the authorization code after the user +// authenticates with their enterprise IdP. +type codeReceiver struct { + authChan chan *auth.AuthorizationResult + errChan chan error + listener net.Listener + server *http.Server +} + +// serveRedirectHandler starts an HTTP server to handle the OAuth redirect callback. +func (r *codeReceiver) serveRedirectHandler(listener net.Listener) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + // Extract the authorization code and state from the callback URL. + r.authChan <- &auth.AuthorizationResult{ + Code: req.URL.Query().Get("code"), + State: req.URL.Query().Get("state"), + } + fmt.Fprint(w, "Authentication successful. You can close this window.") + }) + + r.server = &http.Server{ + Addr: fmt.Sprintf("localhost:%d", *callbackPort), + Handler: mux, + } + if err := r.server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + r.errChan <- err + } +} + +// getAuthorizationCode implements the AuthorizationCodeFetcher interface. +// It displays the authorization URL to the user and waits for the callback. +func (r *codeReceiver) getAuthorizationCode(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + fmt.Printf("\nPlease open the following URL in your browser to authenticate:\n%s\n\n", args.URL) + select { + case authRes := <-r.authChan: + return authRes, nil + case err := <-r.errChan: + return nil, err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// close shuts down the HTTP server. +func (r *codeReceiver) close() { + if r.server != nil { + r.server.Close() + } +} + +func main() { + flag.Parse() + + // Validate required configuration. + if *idpClientID == "" { + log.Fatal("--idp_client_id is required") + } + + // Set up the OAuth callback receiver. + receiver := &codeReceiver{ + authChan: make(chan *auth.AuthorizationResult), + errChan: make(chan error), + } + listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *callbackPort)) + if err != nil { + log.Fatalf("failed to listen on port %d: %v", *callbackPort, err) + } + go receiver.serveRedirectHandler(listener) + defer receiver.close() + + log.Printf("OAuth callback server listening on http://localhost:%d", *callbackPort) + + // Create an ID Token fetcher that performs OIDC login with the enterprise IdP. + idTokenFetcher := func(ctx context.Context) (*oauth2.Token, error) { + log.Println("Starting OIDC login flow...") + + creds := &oauthex.ClientCredentials{ + ClientID: *idpClientID, + } + if *idpClientSecret != "" { + creds.ClientSecretAuth = &oauthex.ClientSecretAuth{ + ClientSecret: *idpClientSecret, + } + } + + oidcConfig := &extauth.OIDCLoginConfig{ + IssuerURL: *idpIssuerURL, + Credentials: creds, + RedirectURL: fmt.Sprintf("http://localhost:%d", *callbackPort), + Scopes: []string{"openid", "profile", "email"}, + } + + // PerformOIDCLogin handles the complete OIDC Authorization Code flow with PKCE. + tokens, err := extauth.PerformOIDCLogin(ctx, oidcConfig, receiver.getAuthorizationCode) + if err != nil { + return nil, fmt.Errorf("OIDC login failed: %w", err) + } + + log.Println("OIDC login successful, obtained ID token") + return tokens, nil + } + + // Create the Enterprise Handler. + // This handler implements the complete Enterprise Managed Authorization flow: + // 1. OIDC Login: User authenticates with enterprise IdP → ID Token (via idTokenFetcher). + // 2. Token Exchange (RFC 8693): ID Token → ID-JAG at IdP. + // 3. JWT Bearer Grant (RFC 7523): ID-JAG → Access Token at MCP Server. + log.Println("Creating enterprise authorization handler...") + + // Prepare IdP credentials + idpCreds := &oauthex.ClientCredentials{ + ClientID: *idpClientID, + } + if *idpClientSecret != "" { + idpCreds.ClientSecretAuth = &oauthex.ClientSecretAuth{ + ClientSecret: *idpClientSecret, + } + } + + // Prepare MCP credentials + var mcpCreds *oauthex.ClientCredentials + if *mcpClientID != "" { + mcpCreds = &oauthex.ClientCredentials{ + ClientID: *mcpClientID, + } + if *mcpClientSecret != "" { + mcpCreds.ClientSecretAuth = &oauthex.ClientSecretAuth{ + ClientSecret: *mcpClientSecret, + } + } + } + + enterpriseHandler, err := extauth.NewEnterpriseHandler(&extauth.EnterpriseHandlerConfig{ + // IdP configuration (where the user authenticates). + IdPIssuerURL: *idpIssuerURL, + IdPCredentials: idpCreds, + + // MCP Server configuration (the resource being accessed). + MCPAuthServerURL: *mcpAuthServerURL, + MCPResourceURI: *mcpResourceURI, + MCPCredentials: mcpCreds, + MCPScopes: []string{"read", "write"}, + + // ID Token fetcher (performs OIDC login when needed). + IDTokenFetcher: idTokenFetcher, + }) + if err != nil { + log.Fatalf("failed to create enterprise handler: %v", err) + } + + // Create the MCP client transport with the enterprise handler. + transport := &mcp.StreamableClientTransport{ + Endpoint: *mcpServerURL, + OAuthHandler: enterpriseHandler, + } + + ctx := context.Background() + client := mcp.NewClient(&mcp.Implementation{ + Name: "enterprise-client-example", + Version: "1.0.0", + }, nil) + + log.Printf("Connecting to MCP server at %s...", *mcpServerURL) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + log.Fatalf("failed to connect to MCP server: %v", err) + } + defer session.Close() + + log.Println("Successfully connected to MCP server!") + + // List available tools as a demonstration. + tools, err := session.ListTools(ctx, nil) + if err != nil { + log.Fatalf("failed to list tools: %v", err) + } + + log.Println("\nAvailable tools:") + if len(tools.Tools) == 0 { + log.Println(" (no tools available)") + } else { + for _, tool := range tools.Tools { + log.Printf(" - %q: %s", tool.Name, tool.Description) + } + } +} diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 5758e032..5de26d95 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -279,8 +279,77 @@ client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, session, err := client.Connect(ctx, transport, nil) ``` -The `auth.AuthorizationCodeHandler` automatically manages token refreshing -and step-up authentication (when the server returns `insufficient_scope` error). +The `auth.AuthorizationCodeHandler` automatically manages token refreshing (if the server provides a refresh token) and step-up authentication (when the server returns `insufficient_scope` error). + +#### Enterprise Managed Authorization (SEP-990) + +For enterprise SSO scenarios where users authenticate with an enterprise Identity Provider (IdP), +the SDK provides +[`extauth.EnterpriseHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth/extauth#EnterpriseHandler), +an implementation of `OAuthHandler` that automates the Enterprise Managed Authorization flow: + +1. **OIDC Login**: User authenticates with enterprise IdP → ID Token +2. **Token Exchange** (RFC 8693): ID Token → ID-JAG at IdP +3. **JWT Bearer Grant** (RFC 7523): ID-JAG → Access Token at MCP Server + +To use enterprise managed authorization, create an `EnterpriseHandler` and assign it to your transport: + +```go +// Create ID token fetcher using OIDC login +idTokenFetcher := func(ctx context.Context) (*oauth2.Token, error) { + oidcConfig := &extauth.OIDCLoginConfig{ + IssuerURL: "https://company.okta.com", + Credentials: &oauthex.ClientCredentials{ + ClientID: "idp-client-id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "idp-client-secret", + }, + }, + RedirectURL: "http://localhost:3142", + Scopes: []string{"openid", "profile", "email"}, + } + + tokens, err := extauth.PerformOIDCLogin(ctx, oidcConfig, authCodeFetcher) + if err != nil { + return nil, err + } + + return tokens, nil +} + +// Create Enterprise Handler +enterpriseHandler, err := extauth.NewEnterpriseHandler(&extauth.EnterpriseHandlerConfig{ + IdPIssuerURL: "https://company.okta.com", + IdPCredentials: &oauthex.ClientCredentials{ + ClientID: "idp-client-id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "idp-client-secret", + }, + }, + MCPAuthServerURL: "https://auth.mcpserver.example", + MCPResourceURI: "https://mcp.mcpserver.example", + MCPCredentials: &oauthex.ClientCredentials{ + ClientID: "mcp-client-id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "mcp-client-secret", + }, + }, + MCPScopes: []string{"read", "write"}, + IDTokenFetcher: idTokenFetcher, +}) + +// Use with transport +transport := &mcp.StreamableClientTransport{ + Endpoint: "https://example.com/mcp", + OAuthHandler: enterpriseHandler, +} +client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) +session, err := client.Connect(ctx, transport, nil) +``` + +The `EnterpriseHandler` automatically manages the token exchange flow. Note that it intentionally does not support refresh tokens - when an access token expires, the entire authorization flow is repeated to ensure enterprise policies are consistently enforced. + +For a complete working example, see [examples/auth/enterprise](https://github.com/modelcontextprotocol/go-sdk/tree/main/examples/auth/enterprise). ## Security @@ -397,3 +466,4 @@ or Issue #460 discusses some potential ergonomic improvements to this API. %include ../../mcp/mcp_example_test.go progress - + diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index 36210576..da21ca54 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -185,7 +185,7 @@ func validateAuthServerMetaURLs(asm *AuthServerMeta) error { } for _, u := range urls { - if err := checkURLScheme(u.value); err != nil { + if err := CheckURLScheme(u.value); err != nil { return fmt.Errorf("%s: %w", u.name, err) } } diff --git a/oauthex/client.go b/oauthex/client.go new file mode 100644 index 00000000..8fd2f763 --- /dev/null +++ b/oauthex/client.go @@ -0,0 +1,59 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package oauthex + +import "errors" + +// ClientCredentials holds client authentication credentials for OAuth token requests. +// It supports multiple authentication methods, but only one method should be set at a time. +// Use the Validate method to ensure proper configuration. +type ClientCredentials struct { + // ClientID is the OAuth2 client identifier. + // REQUIRED for all authentication methods. + ClientID string + + // ClientSecretAuth configures client authentication using a client secret. + // This is the most common authentication method for confidential clients. + // OPTIONAL. If not provided, the client is treated as a public client. + ClientSecretAuth *ClientSecretAuth +} + +// ClientSecretAuth holds client secret authentication credentials. +// This authentication method supports both "client_secret_basic" and "client_secret_post" +// methods as defined in RFC 6749 Section 2.3.1. +type ClientSecretAuth struct { + // ClientSecret is the OAuth2 client secret for confidential clients. + // REQUIRED when using ClientSecretAuth. + ClientSecret string +} + +// Validate checks that the ClientCredentials are properly configured. +// It ensures that: +// - ClientID is not empty. +// - At most one authentication method is configured. +// - If ClientSecretAuth is set, ClientSecret is not empty. +func (c *ClientCredentials) Validate() error { + if c.ClientID == "" { + return errors.New("ClientID is required") + } + + // Count how many auth methods are configured. + authMethodCount := 0 + if c.ClientSecretAuth != nil { + authMethodCount++ + if c.ClientSecretAuth.ClientSecret == "" { + return errors.New("ClientSecret is required when using ClientSecretAuth") + } + } + + // Allow zero auth methods (public client) or exactly one auth method. + if authMethodCount > 1 { + return errors.New("only one client authentication method can be configured") + } + + return nil +} diff --git a/oauthex/dcr.go b/oauthex/dcr.go index 6db30255..3159a07f 100644 --- a/oauthex/dcr.go +++ b/oauthex/dcr.go @@ -237,7 +237,7 @@ func RegisterClient(ctx context.Context, registrationEndpoint string, clientMeta func validateClientRegistrationURLs(meta *ClientRegistrationMetadata) error { // Validate redirect URIs for i, uri := range meta.RedirectURIs { - if err := checkURLScheme(uri); err != nil { + if err := CheckURLScheme(uri); err != nil { return fmt.Errorf("redirect_uris[%d]: %w", i, err) } } @@ -255,7 +255,7 @@ func validateClientRegistrationURLs(meta *ClientRegistrationMetadata) error { } for _, u := range urls { - if err := checkURLScheme(u.value); err != nil { + if err := CheckURLScheme(u.value); err != nil { return fmt.Errorf("%s: %w", u.name, err) } } diff --git a/oauthex/oauth2.go b/oauthex/oauth2.go index d8aeb3c2..39a91f35 100644 --- a/oauthex/oauth2.go +++ b/oauthex/oauth2.go @@ -63,10 +63,10 @@ func getJSON[T any](ctx context.Context, c *http.Client, url string, limit int64 return &t, nil } -// checkURLScheme ensures that its argument is a valid URL with a scheme +// CheckURLScheme ensures that its argument is a valid URL with a scheme // that prevents XSS attacks. // See #526. -func checkURLScheme(u string) error { +func CheckURLScheme(u string) error { if u == "" { return nil } diff --git a/oauthex/resource_meta.go b/oauthex/resource_meta.go index 4680c153..43557101 100644 --- a/oauthex/resource_meta.go +++ b/oauthex/resource_meta.go @@ -112,7 +112,7 @@ func GetProtectedResourceMetadata(ctx context.Context, metadataURL, resourceURL } // Validate the authorization server URLs to prevent XSS attacks (see #526). for i, u := range prm.AuthorizationServers { - if err := checkURLScheme(u); err != nil { + if err := CheckURLScheme(u); err != nil { return nil, fmt.Errorf("authorization_servers[%d]: %v", i, err) } if err := checkHTTPSOrLoopback(u); err != nil { diff --git a/oauthex/token_exchange.go b/oauthex/token_exchange.go new file mode 100644 index 00000000..aaab02a8 --- /dev/null +++ b/oauthex/token_exchange.go @@ -0,0 +1,185 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Token Exchange (RFC 8693) for Enterprise Managed Authorization. +// See https://datatracker.ietf.org/doc/html/rfc8693 + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "fmt" + "net/http" + "strings" + + "golang.org/x/oauth2" +) + +// Token type identifiers defined by RFC 8693 and SEP-990. +const ( + // TokenTypeIDToken is the URN for OpenID Connect ID Tokens. + TokenTypeIDToken = "urn:ietf:params:oauth:token-type:id_token" + + // TokenTypeSAML2 is the URN for SAML 2.0 assertions. + TokenTypeSAML2 = "urn:ietf:params:oauth:token-type:saml2" + + // TokenTypeIDJAG is the URN for Identity Assertion JWT Authorization Grants. + // This is the token type returned by IdP during token exchange for SEP-990. + TokenTypeIDJAG = "urn:ietf:params:oauth:token-type:id-jag" + + // GrantTypeTokenExchange is the grant type for RFC 8693 token exchange. + GrantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange" +) + +// TokenExchangeRequest represents a Token Exchange request per RFC 8693. +// This is used for Enterprise Managed Authorization (SEP-990) where an MCP Client +// exchanges an ID Token from an enterprise IdP for an ID-JAG that can be used +// to obtain an access token from an MCP Server's authorization server. +type TokenExchangeRequest struct { + // RequestedTokenType indicates the type of security token being requested. + // For SEP-990, this MUST be TokenTypeIDJAG. + RequestedTokenType string + + // Audience is the logical name of the target service where the client + // intends to use the requested token. For SEP-990, this MUST be the + // Issuer URL of the MCP Server's authorization server. + Audience string + + // Resource is the physical location or identifier of the target resource. + // For SEP-990, this MUST be the RFC9728 Resource Identifier of the MCP Server. + Resource string + + // Scope is a list of space-separated scopes for the requested token. + // This is OPTIONAL per RFC 8693 but commonly used in SEP-990. + Scope []string + + // SubjectToken is the security token that represents the identity of the + // party on behalf of whom the request is being made. For SEP-990, this is + // typically an OpenID Connect ID Token. + SubjectToken string + + // SubjectTokenType is the type of the security token in SubjectToken. + // For SEP-990 with OIDC, this MUST be TokenTypeIDToken. + SubjectTokenType string +} + +// ExchangeToken performs a token exchange request per RFC 8693 for Enterprise +// Managed Authorization (SEP-990). It exchanges an identity assertion (typically +// an ID Token) for an Identity Assertion JWT Authorization Grant (ID-JAG) that +// can be used to obtain an access token from an MCP Server. +// +// The tokenEndpoint parameter should be the IdP's token endpoint (typically +// obtained from the IdP's authorization server metadata). +// +// Returns an oauth2.Token where: +// - Extra("issued_token_type") contains the type of the issued token (e.g., TokenTypeIDJAG) +// - AccessToken contains the ID-JAG JWT (despite the name, this is not an OAuth access token) +// - TokenType is typically "N_A" for SEP-990 +// - Extra("scope") may contain the scope if different from the request +// - Expiry is when the token expires +func ExchangeToken( + ctx context.Context, + tokenEndpoint string, + req *TokenExchangeRequest, + clientCreds *ClientCredentials, + httpClient *http.Client, +) (*oauth2.Token, error) { + if tokenEndpoint == "" { + return nil, fmt.Errorf("token endpoint is required") + } + if req == nil { + return nil, fmt.Errorf("token exchange request is required") + } + if clientCreds == nil { + return nil, fmt.Errorf("client credentials are required") + } + if err := clientCreds.Validate(); err != nil { + return nil, fmt.Errorf("invalid client credentials: %w", err) + } + + // Validate required fields per SEP-990 Section 4. + if req.RequestedTokenType == "" { + return nil, fmt.Errorf("requested_token_type is required") + } + if req.Audience == "" { + return nil, fmt.Errorf("audience is required") + } + if req.Resource == "" { + return nil, fmt.Errorf("resource is required") + } + if req.SubjectToken == "" { + return nil, fmt.Errorf("subject_token is required") + } + if req.SubjectTokenType == "" { + return nil, fmt.Errorf("subject_token_type is required") + } + + // Validate URL schemes to prevent XSS attacks (see #526). + if err := CheckURLScheme(tokenEndpoint); err != nil { + return nil, fmt.Errorf("invalid token endpoint: %w", err) + } + if err := CheckURLScheme(req.Audience); err != nil { + return nil, fmt.Errorf("invalid audience: %w", err) + } + if err := CheckURLScheme(req.Resource); err != nil { + return nil, fmt.Errorf("invalid resource: %w", err) + } + + // Per RFC 6749 Section 3.2, parameters sent without a value (like the empty + // "code" parameter) MUST be treated as if they were omitted from the request. + // The oauth2 library's Exchange method sends an empty code, but compliant + // servers should ignore it. + cfg := &oauth2.Config{ + ClientID: clientCreds.ClientID, + Endpoint: oauth2.Endpoint{ + TokenURL: tokenEndpoint, + AuthStyle: oauth2.AuthStyleInParams, + }, + } + // Set ClientSecret if ClientSecretAuth is configured. + if clientCreds.ClientSecretAuth != nil { + cfg.ClientSecret = clientCreds.ClientSecretAuth.ClientSecret + } + + // Use custom HTTP client if provided. + if httpClient == nil { + httpClient = http.DefaultClient + } + ctxWithClient := context.WithValue(ctx, oauth2.HTTPClient, httpClient) + + // Build token exchange parameters per RFC 8693. + opts := []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("grant_type", GrantTypeTokenExchange), + oauth2.SetAuthURLParam("requested_token_type", req.RequestedTokenType), + oauth2.SetAuthURLParam("audience", req.Audience), + oauth2.SetAuthURLParam("resource", req.Resource), + oauth2.SetAuthURLParam("subject_token", req.SubjectToken), + oauth2.SetAuthURLParam("subject_token_type", req.SubjectTokenType), + } + if len(req.Scope) > 0 { + opts = append(opts, oauth2.SetAuthURLParam("scope", strings.Join(req.Scope, " "))) + } + + // Exchange with token exchange grant type. + // SetAuthURLParam overrides the default grant_type and adds all required parameters. + token, err := cfg.Exchange( + ctxWithClient, + "", // empty code - per RFC 6749 Section 3.2, empty params should be ignored + opts..., + ) + if err != nil { + return nil, fmt.Errorf("token exchange request failed: %w", err) + } + + // Validate that issued_token_type is present in the response. + // The oauth2 library stores additional response fields in Extra. + issuedTokenType, _ := token.Extra("issued_token_type").(string) + if issuedTokenType == "" { + return nil, fmt.Errorf("response missing required field: issued_token_type") + } + + return token, nil +} diff --git a/oauthex/token_exchange_test.go b/oauthex/token_exchange_test.go new file mode 100644 index 00000000..d5a8a16e --- /dev/null +++ b/oauthex/token_exchange_test.go @@ -0,0 +1,239 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +// TestExchangeToken tests the basic token exchange flow. +func TestExchangeToken(t *testing.T) { + // Create a test IdP server that implements token exchange + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method and content type + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + contentType := r.Header.Get("Content-Type") + if contentType != "application/x-www-form-urlencoded" { + t.Errorf("expected application/x-www-form-urlencoded, got %s", contentType) + http.Error(w, "invalid content type", http.StatusBadRequest) + return + } + + // Parse form data + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + + // Verify required parameters per SEP-990 Section 4 + grantType := r.FormValue("grant_type") + if grantType != GrantTypeTokenExchange { + t.Errorf("expected grant_type %s, got %s", GrantTypeTokenExchange, grantType) + writeErrorResponse(w, "invalid_grant", "invalid grant_type") + return + } + + requestedTokenType := r.FormValue("requested_token_type") + if requestedTokenType != TokenTypeIDJAG { + t.Errorf("expected requested_token_type %s, got %s", TokenTypeIDJAG, requestedTokenType) + writeErrorResponse(w, "invalid_request", "invalid requested_token_type") + return + } + + audience := r.FormValue("audience") + if audience == "" { + t.Error("audience is required") + writeErrorResponse(w, "invalid_request", "missing audience") + return + } + + resource := r.FormValue("resource") + if resource == "" { + t.Error("resource is required") + writeErrorResponse(w, "invalid_request", "missing resource") + return + } + + subjectToken := r.FormValue("subject_token") + if subjectToken == "" { + t.Error("subject_token is required") + writeErrorResponse(w, "invalid_request", "missing subject_token") + return + } + + subjectTokenType := r.FormValue("subject_token_type") + if subjectTokenType != TokenTypeIDToken { + t.Errorf("expected subject_token_type %s, got %s", TokenTypeIDToken, subjectTokenType) + writeErrorResponse(w, "invalid_request", "invalid subject_token_type") + return + } + + // Verify client authentication + clientID := r.FormValue("client_id") + clientSecret := r.FormValue("client_secret") + if clientID == "" || clientSecret == "" { + t.Error("client authentication required") + writeErrorResponse(w, "invalid_client", "client authentication failed") + return + } + + if clientID != "test-client-id" || clientSecret != "test-client-secret" { + t.Error("invalid client credentials") + writeErrorResponse(w, "invalid_client", "invalid credentials") + return + } + + // Return successful token exchange response per SEP-990 Section 4.2 + resp := map[string]interface{}{ + "issued_token_type": TokenTypeIDJAG, + "access_token": "fake-id-jag-token", + "token_type": "N_A", + "scope": r.FormValue("scope"), + "expires_in": 300, + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + // Test successful token exchange + t.Run("successful exchange", func(t *testing.T) { + req := &TokenExchangeRequest{ + RequestedTokenType: TokenTypeIDJAG, + Audience: "https://auth.mcpserver.example/", + Resource: "https://mcp.mcpserver.example/", + Scope: []string{"read", "write"}, + SubjectToken: "fake-id-token", + SubjectTokenType: TokenTypeIDToken, + } + + token, err := ExchangeToken( + context.Background(), + server.URL, + req, + &ClientCredentials{ + ClientID: "test-client-id", + ClientSecretAuth: &ClientSecretAuth{ + ClientSecret: "test-client-secret", + }, + }, + server.Client(), + ) + + if err != nil { + t.Fatalf("ExchangeToken failed: %v", err) + } + + issuedTokenType, ok := token.Extra("issued_token_type").(string) + if !ok || issuedTokenType != TokenTypeIDJAG { + t.Errorf("expected issued_token_type %s, got %s", TokenTypeIDJAG, issuedTokenType) + } + + if token.AccessToken != "fake-id-jag-token" { + t.Errorf("expected access_token 'fake-id-jag-token', got %s", token.AccessToken) + } + + if token.TokenType != "N_A" { + t.Errorf("expected token_type 'N_A', got %s", token.TokenType) + } + + scope, ok := token.Extra("scope").(string) + if !ok || scope != "read write" { + t.Errorf("expected scope 'read write', got %s", scope) + } + + // expires_in should be available in Extra + expiresIn, ok := token.Extra("expires_in").(float64) + if !ok || int(expiresIn) != 300 { + t.Errorf("expected expires_in 300, got %v", expiresIn) + } + }) + + // Test missing required fields + t.Run("missing audience", func(t *testing.T) { + req := &TokenExchangeRequest{ + RequestedTokenType: TokenTypeIDJAG, + Resource: "https://mcp.mcpserver.example/", + SubjectToken: "fake-id-token", + SubjectTokenType: TokenTypeIDToken, + } + + _, err := ExchangeToken( + context.Background(), + server.URL, + req, + &ClientCredentials{ + ClientID: "test-client-id", + ClientSecretAuth: &ClientSecretAuth{ + ClientSecret: "test-client-secret", + }, + }, + server.Client(), + ) + + if err == nil { + t.Error("expected error for missing audience, got nil") + } + }) + + // Test invalid URL schemes + t.Run("invalid audience URL scheme", func(t *testing.T) { + req := &TokenExchangeRequest{ + RequestedTokenType: TokenTypeIDJAG, + Audience: "javascript:alert(1)", + Resource: "https://mcp.mcpserver.example/", + SubjectToken: "fake-id-token", + SubjectTokenType: TokenTypeIDToken, + } + + _, err := ExchangeToken( + context.Background(), + server.URL, + req, + &ClientCredentials{ + ClientID: "test-client-id", + ClientSecretAuth: &ClientSecretAuth{ + ClientSecret: "test-client-secret", + }, + }, + server.Client(), + ) + + if err == nil { + t.Error("expected error for invalid audience URL scheme, got nil") + } + }) +} + +// writeErrorResponse writes an OAuth 2.0 error response per RFC 6749 Section 5.2. +func writeErrorResponse(w http.ResponseWriter, errorCode, errorDescription string) { + errResp := struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` + }{ + Error: errorCode, + ErrorDescription: errorDescription, + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(errResp) +} diff --git a/oauthex/url_scheme_test.go b/oauthex/url_scheme_test.go index 83eeb5e1..c5221a62 100644 --- a/oauthex/url_scheme_test.go +++ b/oauthex/url_scheme_test.go @@ -15,7 +15,7 @@ import ( "testing" ) -// TestCheckURLScheme tests the checkURLScheme function directly. +// TestCheckURLScheme tests the CheckURLScheme function directly. func TestCheckURLScheme(t *testing.T) { tests := []struct { name string @@ -40,9 +40,9 @@ func TestCheckURLScheme(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := checkURLScheme(tt.url) + err := CheckURLScheme(tt.url) if (err != nil) != tt.wantErr { - t.Errorf("checkURLScheme(%q): got err %v, want err %v", tt.url, err != nil, tt.wantErr) + t.Errorf("CheckURLScheme(%q): got err %v, want err %v", tt.url, err != nil, tt.wantErr) } }) }