diff --git a/README.md b/README.md index a12f1531ac..e957d6c2d8 100644 --- a/README.md +++ b/README.md @@ -974,6 +974,12 @@ The following sets of tools are available: organization Organizations +- **check_org_membership** - Check organization membership + - **Required OAuth Scopes**: `read:org` + - **Accepted OAuth Scopes**: `admin:org`, `read:org`, `write:org` + - `org`: GitHub organization login (string, required) + - `username`: GitHub username to check (string, required) + - **search_orgs** - Search organizations - **Required OAuth Scopes**: `read:org` - **Accepted OAuth Scopes**: `admin:org`, `read:org`, `write:org` @@ -1502,6 +1508,18 @@ Following tools will filter out content from users lacking the push access: - `pull_request_read:get_review_comments` - `pull_request_read:get_reviews` +## Pull Request Author Allowlist + +To restrict mutating pull request tools to bot-authored PRs, use `--allowed-pr-authors` or `GITHUB_ALLOWED_PR_AUTHORS` with a comma-separated list of GitHub logins: + +```bash +GITHUB_ALLOWED_PR_AUTHORS='renovate[bot],github-actions[bot]' ./github-mcp-server stdio --toolsets=pull_requests,actions +``` + +When set, tools such as `merge_pull_request`, `update_pull_request`, review-write tools, and PR branch updates fetch the target PR and reject the call unless `pr.User.Login` is in the allowlist. Read-only PR tools and `create_pull_request` are not restricted. `actions_run_trigger` is not gated by this setting because it targets a ref rather than a PR number. + +In HTTP mode, `GITHUB_PERSONAL_ACCESS_TOKEN` can also be used as a server-side default token for trusted local deployments. Requests with an `Authorization` header still use the request token; requests without one fall back to the configured server token. This means the server's GitHub identity is used for any unauthenticated HTTP request, so only enable this when the HTTP endpoint is on a trusted network. + ## i18n / Overriding Descriptions The descriptions of the tools can be overridden by creating a diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 8f2ae58525..0117c83753 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -69,6 +69,13 @@ var ( } } + var allowedPRAuthors []string + if viper.IsSet("allowed_pr_authors") { + if err := viper.UnmarshalKey("allowed_pr_authors", &allowedPRAuthors); err != nil { + return fmt.Errorf("failed to unmarshal allowed-pr-authors: %w", err) + } + } + // Parse enabled features (similar to toolsets) var enabledFeatures []string if viper.IsSet("features") { @@ -92,6 +99,7 @@ var ( LogFilePath: viper.GetString("log-file"), ContentWindowSize: viper.GetInt("content-window-size"), LockdownMode: viper.GetBool("lockdown-mode"), + AllowedPRAuthors: allowedPRAuthors, InsidersMode: viper.GetBool("insiders"), ExcludeTools: excludeTools, RepoAccessCacheTTL: &ttl, @@ -127,10 +135,18 @@ var ( } } + var allowedPRAuthors []string + if viper.IsSet("allowed_pr_authors") { + if err := viper.UnmarshalKey("allowed_pr_authors", &allowedPRAuthors); err != nil { + return fmt.Errorf("failed to unmarshal allowed-pr-authors: %w", err) + } + } + ttl := viper.GetDuration("repo-access-cache-ttl") httpConfig := ghhttp.ServerConfig{ Version: version, Host: viper.GetString("host"), + Token: viper.GetString("personal_access_token"), Port: viper.GetInt("port"), BaseURL: viper.GetString("base-url"), ResourcePath: viper.GetString("base-path"), @@ -139,6 +155,7 @@ var ( LogFilePath: viper.GetString("log-file"), ContentWindowSize: viper.GetInt("content-window-size"), LockdownMode: viper.GetBool("lockdown-mode"), + AllowedPRAuthors: allowedPRAuthors, RepoAccessCacheTTL: &ttl, ScopeChallenge: viper.GetBool("scope-challenge"), ReadOnly: viper.GetBool("read-only"), @@ -173,6 +190,7 @@ func init() { rootCmd.PersistentFlags().String("gh-host", "", "Specify the GitHub hostname (for GitHub Enterprise etc.)") rootCmd.PersistentFlags().Int("content-window-size", 5000, "Specify the content window size") rootCmd.PersistentFlags().Bool("lockdown-mode", false, "Enable lockdown mode") + rootCmd.PersistentFlags().StringSlice("allowed-pr-authors", nil, "Comma-separated list of pull request author logins allowed for mutating pull request tools") rootCmd.PersistentFlags().Bool("insiders", false, "Enable insiders features") rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") @@ -195,6 +213,7 @@ func init() { _ = viper.BindPFlag("host", rootCmd.PersistentFlags().Lookup("gh-host")) _ = viper.BindPFlag("content-window-size", rootCmd.PersistentFlags().Lookup("content-window-size")) _ = viper.BindPFlag("lockdown-mode", rootCmd.PersistentFlags().Lookup("lockdown-mode")) + _ = viper.BindPFlag("allowed_pr_authors", rootCmd.PersistentFlags().Lookup("allowed-pr-authors")) _ = viper.BindPFlag("insiders", rootCmd.PersistentFlags().Lookup("insiders")) _ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl")) _ = viper.BindPFlag("port", httpCmd.Flags().Lookup("port")) diff --git a/docs/server-configuration.md b/docs/server-configuration.md index 693c096a1b..26aaedccb1 100644 --- a/docs/server-configuration.md +++ b/docs/server-configuration.md @@ -10,9 +10,11 @@ We currently support the following ways in which the GitHub MCP Server can be co | Toolsets | `X-MCP-Toolsets` header or `/x/{toolset}` URL | `--toolsets` flag or `GITHUB_TOOLSETS` env var | | Individual Tools | `X-MCP-Tools` header | `--tools` flag or `GITHUB_TOOLS` env var | | Exclude Tools | `X-MCP-Exclude-Tools` header | `--exclude-tools` flag or `GITHUB_EXCLUDE_TOOLS` env var | +| Organization Membership Lookup | Enable `orgs` toolset or `check_org_membership` tool | Enable `orgs` toolset or `check_org_membership` tool | | Read-Only Mode | `X-MCP-Readonly` header or `/readonly` URL | `--read-only` flag or `GITHUB_READ_ONLY` env var | | Dynamic Mode | Not available | `--dynamic-toolsets` flag or `GITHUB_DYNAMIC_TOOLSETS` env var | | Lockdown Mode | `X-MCP-Lockdown` header | `--lockdown-mode` flag or `GITHUB_LOCKDOWN_MODE` env var | +| PR Author Allowlist | Server `--allowed-pr-authors` flag or `GITHUB_ALLOWED_PR_AUTHORS` env var | `--allowed-pr-authors` flag or `GITHUB_ALLOWED_PR_AUTHORS` env var | | Insiders Mode | `X-MCP-Insiders` header or `/insiders` URL | `--insiders` flag or `GITHUB_INSIDERS` env var | | Feature Flags | `X-MCP-Features` header | `--features` flag | | Scope Filtering | Always enabled | Always enabled | @@ -30,6 +32,8 @@ Note: **read-only** mode acts as a strict security filter that takes precedence Note: **excluded tools** takes precedence over toolsets and individual tools — listed tools are always excluded, even if their toolset is enabled or they are explicitly added via `--tools` / `X-MCP-Tools`. +Note: **PR author allowlist** restricts mutating pull request tools to existing pull requests authored by the configured GitHub logins. Read-only PR tools and `create_pull_request` are not restricted. `actions_run_trigger` is not restricted by this setting because it targets a ref rather than a pull request number. + --- ## Configuration Examples @@ -387,6 +391,33 @@ Lockdown mode ensures the server only surfaces content in public repositories fr --- +### PR Author Allowlist + +**Best for:** Automation workflows that may mutate bot-authored pull requests but should never mutate human-authored pull requests. + +When set, mutating pull request tools first fetch the target pull request and check `pr.User.Login`. If the author is not in the allowlist, the tool returns an error before making the mutation. Empty or unset means unrestricted behavior. + +```json +{ + "type": "stdio", + "command": "go", + "args": [ + "run", + "./cmd/github-mcp-server", + "stdio", + "--toolsets=pull_requests,actions", + "--allowed-pr-authors=renovate[bot],github-actions[bot]" + ], + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "${input:github_token}" + } +} +``` + +Known limitations: `actions_run_trigger` operates on refs, not pull request numbers, so it is not gated by this setting. Review-thread resolve and unresolve tools take only opaque thread IDs and are not gated by the PR author allowlist. The allowlist checks `pr.User.Login`; PRs from forks authored by allowed bots still pass. Enabling the allowlist adds one API call before a mutating PR operation when the handler does not already have the pull request. + +--- + ### Insiders Mode **Best for:** Users who want early access to experimental features and new tools before they reach general availability. diff --git a/docs/streamable-http.md b/docs/streamable-http.md index 0a11c5ea76..a26060df52 100644 --- a/docs/streamable-http.md +++ b/docs/streamable-http.md @@ -91,3 +91,15 @@ To provide PAT credentials, or to customize server behavior preferences, you can ``` See [Remote Server](./remote-server.md) documentation for more details on client configuration options. + +### Using a Server-Side Default Token + +For trusted local deployments, HTTP mode can use `GITHUB_PERSONAL_ACCESS_TOKEN` as a fallback when a request does not include an `Authorization` header: + +```bash +GITHUB_PERSONAL_ACCESS_TOKEN=ghp_yourtokenhere github-mcp-server http +``` + +If a request includes `Authorization: Bearer ...`, that request token takes precedence. If no request token is provided and no server-side token is configured, the server returns `401 Unauthorized`. + +When this fallback is enabled, the server's GitHub identity is used for every HTTP request without an `Authorization` header. Only expose the endpoint on a trusted network. diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index b1925bffd3..7a5befd566 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -135,6 +135,7 @@ func NewStdioMCPServer(ctx context.Context, cfg github.MCPServerConfig) (*mcp.Se cfg.ContentWindowSize, featureChecker, obs, + cfg.AllowedPRAuthors, ) // Build and register the tool/resource/prompt inventory inventoryBuilder := github.NewInventory(cfg.Translator). @@ -220,6 +221,10 @@ type StdioServerConfig struct { // LockdownMode indicates if we should enable lockdown mode LockdownMode bool + // AllowedPRAuthors restricts mutating pull request tools to PRs authored by + // one of these GitHub logins. Empty means unrestricted. + AllowedPRAuthors []string + // InsidersMode indicates if we should enable experimental features InsidersMode bool @@ -255,6 +260,9 @@ func RunStdioServer(cfg StdioServerConfig) error { } logger := slog.New(slogHandler) logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "dynamicToolsets", cfg.DynamicToolsets, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode) + if len(cfg.AllowedPRAuthors) > 0 { + logger.Info("PR author allowlist enforced", "authors", cfg.AllowedPRAuthors) + } // Fetch token scopes for scope-based tool filtering (PAT tokens only) // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. @@ -284,6 +292,7 @@ func RunStdioServer(cfg StdioServerConfig) error { Translator: t, ContentWindowSize: cfg.ContentWindowSize, LockdownMode: cfg.LockdownMode, + AllowedPRAuthors: cfg.AllowedPRAuthors, InsidersMode: cfg.InsidersMode, ExcludeTools: cfg.ExcludeTools, Logger: logger, diff --git a/pkg/github/__toolsnaps__/check_org_membership.snap b/pkg/github/__toolsnaps__/check_org_membership.snap new file mode 100644 index 0000000000..709222b070 --- /dev/null +++ b/pkg/github/__toolsnaps__/check_org_membership.snap @@ -0,0 +1,25 @@ +{ + "annotations": { + "readOnlyHint": true, + "title": "Check organization membership" + }, + "description": "Check whether a GitHub user is a member of an organization, and report whether that membership is public, private, or not visible.", + "inputSchema": { + "properties": { + "org": { + "description": "GitHub organization login", + "type": "string" + }, + "username": { + "description": "GitHub username to check", + "type": "string" + } + }, + "required": [ + "org", + "username" + ], + "type": "object" + }, + "name": "check_org_membership" +} \ No newline at end of file diff --git a/pkg/github/copilot.go b/pkg/github/copilot.go index d95357e738..33543e017a 100644 --- a/pkg/github/copilot.go +++ b/pkg/github/copilot.go @@ -507,6 +507,10 @@ func RequestCopilotReview(t translations.TranslationHelperFunc) inventory.Server return utils.NewToolResultError(err.Error()), nil, nil } + if result, err := enforcePRAuthorAllowlist(ctx, deps, owner, repo, pullNumber, nil); result != nil || err != nil { + return result, nil, err + } + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil diff --git a/pkg/github/copilot_test.go b/pkg/github/copilot_test.go index 0a1d5ef3b6..6f21d835c0 100644 --- a/pkg/github/copilot_test.go +++ b/pkg/github/copilot_test.go @@ -961,3 +961,31 @@ func Test_RequestCopilotReview(t *testing.T) { }) } } + +func Test_RequestCopilotReview_PRAuthorDenied(t *testing.T) { + serverTool := RequestCopilotReview(translations.NullTranslationHelper) + client := github.NewClient(MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepoByPullNumber: mockResponse(t, http.StatusOK, &github.PullRequest{ + User: &github.User{Login: github.Ptr("alice")}, + }), + PostReposPullsRequestedReviewersByOwnerByRepoByPullNumber: func(w http.ResponseWriter, _ *http.Request) { + t.Fatal("reviewer request endpoint should not be called when PR author is denied") + }, + })) + deps := BaseDeps{ + Client: client, + allowedPRAuthors: buildPRAuthorAllowlist([]string{"renovate[bot]"}), + } + handler := serverTool.Handler(deps) + request := createMCPRequest(map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + }) + + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.True(t, result.IsError) + assert.Contains(t, getErrorResult(t, result).Text, `pull request author "alice" is not in --allowed-pr-authors`) +} diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index aad213e4e5..1c6f46556e 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -105,6 +105,10 @@ type ToolDependencies interface { // Metrics returns the metrics client Metrics(ctx context.Context) metrics.Metrics + + // IsPRAuthorAllowed checks whether a pull request author is allowed for + // mutating pull request tools. enforced is false when no allowlist is set. + IsPRAuthorAllowed(login string) (allowed bool, enforced bool) } // BaseDeps is the standard implementation of ToolDependencies for the local server. @@ -127,6 +131,8 @@ type BaseDeps struct { // Observability exporters (includes logger) Obsv observability.Exporters + + allowedPRAuthors map[string]struct{} } // Compile-time assertion to verify that BaseDeps implements the ToolDependencies interface. @@ -143,6 +149,7 @@ func NewBaseDeps( contentWindowSize int, featureChecker inventory.FeatureFlagChecker, obsv observability.Exporters, + allowedPRAuthors ...[]string, ) *BaseDeps { return &BaseDeps{ Client: client, @@ -154,6 +161,7 @@ func NewBaseDeps( ContentWindowSize: contentWindowSize, featureChecker: featureChecker, Obsv: obsv, + allowedPRAuthors: buildPRAuthorAllowlist(firstStringSlice(allowedPRAuthors)), } } @@ -196,6 +204,11 @@ func (d BaseDeps) Metrics(ctx context.Context) metrics.Metrics { return d.Obsv.Metrics(ctx) } +// IsPRAuthorAllowed implements ToolDependencies. +func (d BaseDeps) IsPRAuthorAllowed(login string) (bool, bool) { + return isPRAuthorAllowed(d.allowedPRAuthors, login) +} + // IsFeatureEnabled checks if a feature flag is enabled. // Returns false if the feature checker is nil, flag name is empty, or an error occurs. // This allows tools to conditionally change behavior based on feature flags. @@ -276,6 +289,8 @@ type RequestDeps struct { // Observability exporters (includes logger) obsv observability.Exporters + + allowedPRAuthors map[string]struct{} } // NewRequestDeps creates a RequestDeps with the provided clients and configuration. @@ -288,6 +303,7 @@ func NewRequestDeps( contentWindowSize int, featureChecker inventory.FeatureFlagChecker, obsv observability.Exporters, + allowedPRAuthors ...[]string, ) *RequestDeps { return &RequestDeps{ apiHosts: apiHosts, @@ -298,6 +314,7 @@ func NewRequestDeps( ContentWindowSize: contentWindowSize, featureChecker: featureChecker, obsv: obsv, + allowedPRAuthors: buildPRAuthorAllowlist(firstStringSlice(allowedPRAuthors)), } } @@ -420,6 +437,11 @@ func (d *RequestDeps) Metrics(ctx context.Context) metrics.Metrics { return d.obsv.Metrics(ctx) } +// IsPRAuthorAllowed implements ToolDependencies. +func (d *RequestDeps) IsPRAuthorAllowed(login string) (bool, bool) { + return isPRAuthorAllowed(d.allowedPRAuthors, login) +} + // IsFeatureEnabled checks if a feature flag is enabled. func (d *RequestDeps) IsFeatureEnabled(ctx context.Context, flagName string) bool { if d.featureChecker == nil || flagName == "" { diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go index 67a05fd6c0..ff5dc71375 100644 --- a/pkg/github/helper_test.go +++ b/pkg/github/helper_test.go @@ -140,6 +140,11 @@ const ( GetSearchUsers = "GET /search/users" GetSearchRepositories = "GET /search/repositories" + // Organization endpoints + GetOrgsByOrg = "GET /orgs/{org}" + GetOrgsMembersByOrgByUsername = "GET /orgs/{org}/members/{username}" + GetOrgsPublicMembersByOrgByUsername = "GET /orgs/{org}/public_members/{username}" + // Raw content endpoints (used for GitHub raw content API, not standard API) // These are used with the raw content client that interacts with raw.githubusercontent.com GetRawReposContentsByOwnerByRepoByPath = "GET /{owner}/{repo}/HEAD/{path:.*}" diff --git a/pkg/github/issues.go b/pkg/github/issues.go index e3e1f6b223..1563a1881e 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -663,6 +663,10 @@ func AddIssueComment(t translations.TranslationHelperFunc) inventory.ServerTool if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } + if result, err := enforceIssueCommentPRAuthorAllowlist(ctx, deps, client, owner, repo, issueNumber); result != nil || err != nil { + return result, nil, err + } + createdComment, resp, err := client.Issues.CreateComment(ctx, owner, repo, issueNumber, comment) if err != nil { return utils.NewToolResultErrorFromErr("failed to create comment", err), nil, nil diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 49ce2dde9c..6e1e70183c 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -386,6 +386,41 @@ func Test_AddIssueComment(t *testing.T) { } } +func Test_AddIssueComment_PRAuthorDenied(t *testing.T) { + serverTool := AddIssueComment(translations.NullTranslationHelper) + client := github.NewClient(MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposIssuesByOwnerByRepoByIssueNumber: mockResponse(t, http.StatusOK, &github.Issue{ + Number: github.Ptr(42), + PullRequestLinks: &github.PullRequestLinks{ + URL: github.Ptr("https://api.github.com/repos/owner/repo/pulls/42"), + }, + }), + GetReposPullsByOwnerByRepoByPullNumber: mockResponse(t, http.StatusOK, &github.PullRequest{ + User: &github.User{Login: github.Ptr("alice")}, + }), + PostReposIssuesCommentsByOwnerByRepoByIssueNumber: func(w http.ResponseWriter, _ *http.Request) { + t.Fatal("issue comment endpoint should not be called when PR author is denied") + }, + })) + deps := BaseDeps{ + Client: client, + allowedPRAuthors: buildPRAuthorAllowlist([]string{"renovate[bot]"}), + } + handler := serverTool.Handler(deps) + request := createMCPRequest(map[string]any{ + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + "body": "comment", + }) + + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.True(t, result.IsError) + assert.Contains(t, getErrorResult(t, result).Text, `pull request author "alice" is not in --allowed-pr-authors`) +} + func Test_SearchIssues(t *testing.T) { // Verify tool definition once serverTool := SearchIssues(translations.NullTranslationHelper) diff --git a/pkg/github/orgs.go b/pkg/github/orgs.go new file mode 100644 index 0000000000..b1ec95aa22 --- /dev/null +++ b/pkg/github/orgs.go @@ -0,0 +1,141 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + + ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/scopes" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +const orgMembershipVisibilityNote = "result reflects caller visibility; private members of orgs you can't see appear as non-members." + +type CheckOrgMembershipInput struct { + Org string `json:"org"` + Username string `json:"username"` +} + +type CheckOrgMembershipOutput struct { + Org string `json:"org"` + Username string `json:"username"` + IsMember bool `json:"isMember"` + Visibility string `json:"visibility"` + Note string `json:"note,omitempty"` +} + +// CheckOrgMembership creates a tool to check whether a GitHub user is a member of an organization. +func CheckOrgMembership(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool[CheckOrgMembershipInput, CheckOrgMembershipOutput]( + ToolsetMetadataOrgs, + mcp.Tool{ + Name: "check_org_membership", + Description: t("TOOL_CHECK_ORG_MEMBERSHIP_DESCRIPTION", "Check whether a GitHub user is a member of an organization, and report whether that membership is public, private, or not visible."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_CHECK_ORG_MEMBERSHIP_USER_TITLE", "Check organization membership"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "org": { + Type: "string", + Description: "GitHub organization login", + }, + "username": { + Type: "string", + Description: "GitHub username to check", + }, + }, + Required: []string{"org", "username"}, + }, + }, + []scopes.Scope{scopes.ReadOrg}, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args CheckOrgMembershipInput) (*mcp.CallToolResult, CheckOrgMembershipOutput, error) { + if args.Org == "" { + return utils.NewToolResultError("missing required parameter: org"), CheckOrgMembershipOutput{}, nil + } + if args.Username == "" { + return utils.NewToolResultError("missing required parameter: username"), CheckOrgMembershipOutput{}, nil + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), CheckOrgMembershipOutput{}, nil + } + + isMember, res, err := client.Organizations.IsMember(ctx, args.Org, args.Username) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to check organization membership", + res, + err, + ), CheckOrgMembershipOutput{}, nil + } + + isPublicMember, res, err := client.Organizations.IsPublicMember(ctx, args.Org, args.Username) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to check public organization membership", + res, + err, + ), CheckOrgMembershipOutput{}, nil + } + + output := CheckOrgMembershipOutput{ + Org: args.Org, + Username: args.Username, + } + switch { + case isPublicMember: + output.IsMember = true + output.Visibility = "public" + case isMember: + output.IsMember = true + output.Visibility = "private" + default: + if errResult := verifyOrganizationExists(ctx, args.Org, deps); errResult != nil { + return errResult, CheckOrgMembershipOutput{}, nil + } + output.Visibility = "none" + output.Note = orgMembershipVisibilityNote + } + + r, err := json.Marshal(output) + if err != nil { + return nil, CheckOrgMembershipOutput{}, err + } + return utils.NewToolResultText(string(r)), output, nil + }, + ) +} + +func verifyOrganizationExists(ctx context.Context, org string, deps ToolDependencies) *mcp.CallToolResult { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err) + } + + _, res, err := client.Organizations.Get(ctx, org) + if err == nil { + return nil + } + if res != nil && res.Response != nil && res.Response.StatusCode == http.StatusNotFound { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get organization", + res, + err, + ) + } + + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to verify organization exists", + res, + err, + ) +} diff --git a/pkg/github/orgs_test.go b/pkg/github/orgs_test.go new file mode 100644 index 0000000000..c5daaa4e51 --- /dev/null +++ b/pkg/github/orgs_test.go @@ -0,0 +1,142 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/github/github-mcp-server/internal/toolsnaps" + "github.com/github/github-mcp-server/pkg/translations" + gogithub "github.com/google/go-github/v82/github" + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_CheckOrgMembership(t *testing.T) { + serverTool := CheckOrgMembership(translations.NullTranslationHelper) + tool := serverTool.Tool + + require.NoError(t, toolsnaps.Test(tool.Name, tool)) + + assert.Equal(t, "check_org_membership", tool.Name) + assert.NotEmpty(t, tool.Description) + + schema, ok := tool.InputSchema.(*jsonschema.Schema) + require.True(t, ok, "InputSchema should be *jsonschema.Schema") + assert.Contains(t, schema.Properties, "org") + assert.Contains(t, schema.Properties, "username") + assert.ElementsMatch(t, schema.Required, []string{"org", "username"}) + assert.True(t, serverTool.IsReadOnly()) + assert.ElementsMatch(t, []string{"read:org"}, serverTool.RequiredScopes) +} + +func Test_CheckOrgMembership_PublicMember(t *testing.T) { + output := runCheckOrgMembership(t, MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetOrgsMembersByOrgByUsername: expectPath(t, "/orgs/canonical/members/octocat").andThen(mockStatus(http.StatusNoContent)), + GetOrgsPublicMembersByOrgByUsername: expectPath(t, "/orgs/canonical/public_members/octocat").andThen(mockStatus(http.StatusNoContent)), + }), false) + + assert.Equal(t, CheckOrgMembershipOutput{ + Org: "canonical", + Username: "octocat", + IsMember: true, + Visibility: "public", + }, output) +} + +func Test_CheckOrgMembership_PrivateMember(t *testing.T) { + output := runCheckOrgMembership(t, MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetOrgsMembersByOrgByUsername: expectPath(t, "/orgs/canonical/members/octocat").andThen(mockStatus(http.StatusNoContent)), + GetOrgsPublicMembersByOrgByUsername: expectPath(t, "/orgs/canonical/public_members/octocat").andThen(mockStatus(http.StatusNotFound)), + }), false) + + assert.Equal(t, CheckOrgMembershipOutput{ + Org: "canonical", + Username: "octocat", + IsMember: true, + Visibility: "private", + }, output) +} + +func Test_CheckOrgMembership_NotAMember(t *testing.T) { + output := runCheckOrgMembership(t, MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetOrgsMembersByOrgByUsername: expectPath(t, "/orgs/canonical/members/octocat").andThen(mockStatus(http.StatusNotFound)), + GetOrgsPublicMembersByOrgByUsername: expectPath(t, "/orgs/canonical/public_members/octocat").andThen(mockStatus(http.StatusNotFound)), + GetOrgsByOrg: expectPath(t, "/orgs/canonical").andThen(mockResponse(t, http.StatusOK, &gogithub.Organization{Login: gogithub.Ptr("canonical")})), + }), false) + + assert.Equal(t, "canonical", output.Org) + assert.Equal(t, "octocat", output.Username) + assert.False(t, output.IsMember) + assert.Equal(t, "none", output.Visibility) + assert.Equal(t, "result reflects caller visibility; private members of orgs you can't see appear as non-members.", output.Note) +} + +func Test_CheckOrgMembership_OrgNotFound(t *testing.T) { + text := runCheckOrgMembershipError(t, MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetOrgsMembersByOrgByUsername: expectPath(t, "/orgs/missing-org/members/octocat").andThen(mockStatus(http.StatusNotFound)), + GetOrgsPublicMembersByOrgByUsername: expectPath(t, "/orgs/missing-org/public_members/octocat").andThen(mockStatus(http.StatusNotFound)), + GetOrgsByOrg: expectPath(t, "/orgs/missing-org").andThen(mockResponse(t, http.StatusNotFound, map[string]string{"message": "Not Found"})), + }), map[string]any{"org": "missing-org", "username": "octocat"}) + + assert.Contains(t, text, "failed to get organization") + assert.Contains(t, text, "404") +} + +func Test_CheckOrgMembership_ScopeError(t *testing.T) { + for _, status := range []int{http.StatusUnauthorized, http.StatusForbidden} { + t.Run(http.StatusText(status), func(t *testing.T) { + text := runCheckOrgMembershipError(t, MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetOrgsMembersByOrgByUsername: expectPath(t, "/orgs/canonical/members/octocat").andThen( + mockResponse(t, status, map[string]string{"message": http.StatusText(status)}), + ), + }), map[string]any{"org": "canonical", "username": "octocat"}) + + assert.Contains(t, text, "failed to check organization membership") + assert.Contains(t, text, http.StatusText(status)) + }) + } +} + +func runCheckOrgMembership(t *testing.T, httpClient *http.Client, expectError bool) CheckOrgMembershipOutput { + t.Helper() + + result := callCheckOrgMembership(t, httpClient, map[string]any{"org": "canonical", "username": "octocat"}) + require.Equal(t, expectError, result.IsError) + + textContent := getTextResult(t, result) + var output CheckOrgMembershipOutput + require.NoError(t, json.Unmarshal([]byte(textContent.Text), &output)) + return output +} + +func runCheckOrgMembershipError(t *testing.T, httpClient *http.Client, args map[string]any) string { + t.Helper() + + result := callCheckOrgMembership(t, httpClient, args) + textContent := getErrorResult(t, result) + return textContent.Text +} + +func callCheckOrgMembership(t *testing.T, httpClient *http.Client, args map[string]any) *mcp.CallToolResult { + t.Helper() + + client := gogithub.NewClient(httpClient) + deps := BaseDeps{Client: client} + serverTool := CheckOrgMembership(translations.NullTranslationHelper) + handler := serverTool.Handler(deps) + request := createMCPRequest(args) + + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + require.NoError(t, err) + return result +} + +func mockStatus(status int) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(status) + } +} diff --git a/pkg/github/pr_author_allowlist.go b/pkg/github/pr_author_allowlist.go new file mode 100644 index 0000000000..e3f0ab70cb --- /dev/null +++ b/pkg/github/pr_author_allowlist.go @@ -0,0 +1,129 @@ +package github + +import ( + "context" + "fmt" + "strings" + + ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/utils" + gogithub "github.com/google/go-github/v82/github" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func firstStringSlice(values [][]string) []string { + if len(values) == 0 { + return nil + } + return values[0] +} + +func buildPRAuthorAllowlist(authors []string) map[string]struct{} { + if len(authors) == 0 { + return nil + } + + allowlist := make(map[string]struct{}, len(authors)) + for _, author := range authors { + author = strings.TrimSpace(author) + if author == "" { + continue + } + allowlist[strings.ToLower(author)] = struct{}{} + } + if len(allowlist) == 0 { + return nil + } + return allowlist +} + +func isPRAuthorAllowed(allowlist map[string]struct{}, login string) (bool, bool) { + if len(allowlist) == 0 { + return true, false + } + _, ok := allowlist[strings.ToLower(strings.TrimSpace(login))] + return ok, true +} + +// enforcePRAuthorAllowlist returns a tool result error if an allowlist is +// configured and the PR's author is not on it. Callers that already have the +// pull request can pass it to avoid a duplicate API fetch. +func enforcePRAuthorAllowlist( + ctx context.Context, + deps ToolDependencies, + owner, repo string, + pullNumber int, + pr *gogithub.PullRequest, +) (*mcp.CallToolResult, error) { + if _, enforced := deps.IsPRAuthorAllowed(""); !enforced { + return nil, nil + } + + if pr == nil { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil + } + + var resp *gogithub.Response + pr, resp, err = client.PullRequests.Get(ctx, owner, repo, pullNumber) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get pull request", resp, err), nil + } + if resp != nil && resp.Body != nil { + defer func() { _ = resp.Body.Close() }() + } + } + + login := pr.GetUser().GetLogin() + if allowed, _ := deps.IsPRAuthorAllowed(login); allowed { + return nil, nil + } + + logPRAuthorAllowlistDenied(ctx, deps, owner, repo, pullNumber, login) + return utils.NewToolResultError(fmt.Sprintf("pull request author %q is not in --allowed-pr-authors", login)), nil +} + +// enforceIssueCommentPRAuthorAllowlist applies the PR author allowlist when +// an issue-comment target is actually a pull request. +func enforceIssueCommentPRAuthorAllowlist( + ctx context.Context, + deps ToolDependencies, + client *gogithub.Client, + owner, repo string, + issueNumber int, +) (*mcp.CallToolResult, error) { + if _, enforced := deps.IsPRAuthorAllowed(""); !enforced { + return nil, nil + } + + issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get issue", resp, err), nil + } + if resp != nil && resp.Body != nil { + defer func() { _ = resp.Body.Close() }() + } + + if issue.GetPullRequestLinks() == nil { + return nil, nil + } + + return enforcePRAuthorAllowlist(ctx, deps, owner, repo, issueNumber, nil) +} + +func logPRAuthorAllowlistDenied(ctx context.Context, deps ToolDependencies, owner, repo string, pullNumber int, login string) { + defer func() { + _ = recover() + }() + + if logger := deps.Logger(ctx); logger != nil { + logger.Warn( + "PR mutation denied by allowlist", + "owner", owner, + "repo", repo, + "pr", pullNumber, + "author", login, + ) + } +} diff --git a/pkg/github/pr_author_allowlist_test.go b/pkg/github/pr_author_allowlist_test.go new file mode 100644 index 0000000000..716f4d1816 --- /dev/null +++ b/pkg/github/pr_author_allowlist_test.go @@ -0,0 +1,112 @@ +package github + +import ( + "context" + "net/http" + "testing" + + "github.com/github/github-mcp-server/pkg/translations" + gogithub "github.com/google/go-github/v82/github" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEnforcePRAuthorAllowlist_NoAllowlist_PermitsAll(t *testing.T) { + deps := BaseDeps{} + + result, err := enforcePRAuthorAllowlist(context.Background(), deps, "owner", "repo", 1, nil) + + require.NoError(t, err) + require.Nil(t, result) +} + +func TestEnforcePRAuthorAllowlist_AuthorAllowed(t *testing.T) { + deps := BaseDeps{allowedPRAuthors: buildPRAuthorAllowlist([]string{"renovate[bot]"})} + pr := &gogithub.PullRequest{User: &gogithub.User{Login: gogithub.Ptr("renovate[bot]")}} + + result, err := enforcePRAuthorAllowlist(context.Background(), deps, "owner", "repo", 1, pr) + + require.NoError(t, err) + require.Nil(t, result) +} + +func TestEnforcePRAuthorAllowlist_AuthorDenied(t *testing.T) { + deps := BaseDeps{allowedPRAuthors: buildPRAuthorAllowlist([]string{"renovate[bot]"})} + pr := &gogithub.PullRequest{User: &gogithub.User{Login: gogithub.Ptr("alice")}} + + result, err := enforcePRAuthorAllowlist(context.Background(), deps, "owner", "repo", 1, pr) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.IsError) + assert.Contains(t, getErrorResult(t, result).Text, `pull request author "alice" is not in --allowed-pr-authors`) +} + +func TestEnforcePRAuthorAllowlist_FetchFailure(t *testing.T) { + client := gogithub.NewClient(MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepoByPullNumber: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message":"Not Found"}`)) + }, + })) + deps := BaseDeps{ + Client: client, + allowedPRAuthors: buildPRAuthorAllowlist([]string{"renovate[bot]"}), + } + + result, err := enforcePRAuthorAllowlist(context.Background(), deps, "owner", "repo", 1, nil) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.IsError) + assert.Contains(t, getErrorResult(t, result).Text, "failed to get pull request") +} + +func TestEnforcePRAuthorAllowlist_UsesProvidedPR(t *testing.T) { + calls := 0 + client := gogithub.NewClient(MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepoByPullNumber: func(w http.ResponseWriter, _ *http.Request) { + calls++ + w.WriteHeader(http.StatusInternalServerError) + }, + })) + deps := BaseDeps{ + Client: client, + allowedPRAuthors: buildPRAuthorAllowlist([]string{"renovate[bot]"}), + } + pr := &gogithub.PullRequest{User: &gogithub.User{Login: gogithub.Ptr("renovate[bot]")}} + + result, err := enforcePRAuthorAllowlist(context.Background(), deps, "owner", "repo", 1, pr) + + require.NoError(t, err) + require.Nil(t, result) + assert.Zero(t, calls) +} + +func TestMergePullRequest_PRAuthorDenied(t *testing.T) { + serverTool := MergePullRequest(translations.NullTranslationHelper) + client := gogithub.NewClient(MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepoByPullNumber: mockResponse(t, http.StatusOK, &gogithub.PullRequest{ + User: &gogithub.User{Login: gogithub.Ptr("alice")}, + }), + PutReposPullsMergeByOwnerByRepoByPullNumber: func(w http.ResponseWriter, _ *http.Request) { + t.Fatal("merge endpoint should not be called when PR author is denied") + }, + })) + deps := BaseDeps{ + Client: client, + allowedPRAuthors: buildPRAuthorAllowlist([]string{"renovate[bot]"}), + } + handler := serverTool.Handler(deps) + request := createMCPRequest(map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + }) + + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.True(t, result.IsError) + assert.Contains(t, getErrorResult(t, result).Text, `pull request author "alice" is not in --allowed-pr-authors`) +} diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 0065b25a92..e5095a7c87 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -838,6 +838,10 @@ func UpdatePullRequest(t translations.TranslationHelperFunc) inventory.ServerToo return utils.NewToolResultError("No update parameters provided."), nil, nil } + if result, err := enforcePRAuthorAllowlist(ctx, deps, owner, repo, pullNumber, nil); result != nil || err != nil { + return result, nil, err + } + // Handle REST API updates (title, body, state, base, maintainer_can_modify) if restUpdateNeeded { client, err := deps.GetClient(ctx) @@ -1060,6 +1064,10 @@ func AddReplyToPullRequestComment(t translations.TranslationHelperFunc) inventor return utils.NewToolResultError(err.Error()), nil, nil } + if result, err := enforcePRAuthorAllowlist(ctx, deps, owner, repo, pullNumber, nil); result != nil || err != nil { + return result, nil, err + } + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil @@ -1311,6 +1319,10 @@ func MergePullRequest(t translations.TranslationHelperFunc) inventory.ServerTool return utils.NewToolResultError(err.Error()), nil, nil } + if result, err := enforcePRAuthorAllowlist(ctx, deps, owner, repo, pullNumber, nil); result != nil || err != nil { + return result, nil, err + } + options := &github.PullRequestOptions{ CommitTitle: commitTitle, MergeMethod: mergeMethod, @@ -1468,6 +1480,10 @@ func UpdatePullRequestBranch(t translations.TranslationHelperFunc) inventory.Ser opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA) } + if result, err := enforcePRAuthorAllowlist(ctx, deps, owner, repo, pullNumber, nil); result != nil || err != nil { + return result, nil, err + } + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil @@ -1586,6 +1602,13 @@ Available methods: return utils.NewToolResultError(err.Error()), nil, nil } + switch params.Method { + case "create", "submit_pending", "delete_pending": + if result, err := enforcePRAuthorAllowlist(ctx, deps, params.Owner, params.Repo, int(params.PullNumber), nil); result != nil || err != nil { + return result, nil, err + } + } + // Given our owner, repo and PR number, lookup the GQL ID of the PR. client, err := deps.GetGQLClient(ctx) if err != nil { @@ -2090,6 +2113,10 @@ func AddCommentToPendingReview(t translations.TranslationHelperFunc) inventory.S return utils.NewToolResultError(err.Error()), nil, nil } + if result, err := enforcePRAuthorAllowlist(ctx, deps, params.Owner, params.Repo, int(params.PullNumber), nil); result != nil || err != nil { + return result, nil, err + } + client, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil diff --git a/pkg/github/pullrequests_granular.go b/pkg/github/pullrequests_granular.go index 4a616f1b25..e897f5ca0b 100644 --- a/pkg/github/pullrequests_granular.go +++ b/pkg/github/pullrequests_granular.go @@ -77,6 +77,10 @@ func prUpdateTool( return utils.NewToolResultError(err.Error()), nil, nil } + if result, err := enforcePRAuthorAllowlist(ctx, deps, owner, repo, pullNumber, nil); result != nil || err != nil { + return result, nil, err + } + prReq, err := buildRequest(args) if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -218,6 +222,10 @@ func GranularUpdatePullRequestDraftState(t translations.TranslationHelperFunc) i return utils.NewToolResultError(err.Error()), nil, nil } + if result, err := enforcePRAuthorAllowlist(ctx, deps, owner, repo, pullNumber, nil); result != nil || err != nil { + return result, nil, err + } + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub GraphQL client", err), nil, nil @@ -326,6 +334,10 @@ func GranularRequestPullRequestReviewers(t translations.TranslationHelperFunc) i return utils.NewToolResultError("missing required parameter: reviewers"), nil, nil } + if result, err := enforcePRAuthorAllowlist(ctx, deps, owner, repo, pullNumber, nil); result != nil || err != nil { + return result, nil, err + } + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil @@ -395,6 +407,10 @@ func GranularCreatePullRequestReview(t translations.TranslationHelperFunc) inven event, _ := OptionalParam[string](args, "event") commitID, _ := OptionalParam[string](args, "commitID") + if result, err := enforcePRAuthorAllowlist(ctx, deps, owner, repo, pullNumber, nil); result != nil || err != nil { + return result, nil, err + } + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub GraphQL client", err), nil, nil @@ -465,6 +481,10 @@ func GranularSubmitPendingPullRequestReview(t translations.TranslationHelperFunc } body, _ := OptionalParam[string](args, "body") + if result, err := enforcePRAuthorAllowlist(ctx, deps, owner, repo, pullNumber, nil); result != nil || err != nil { + return result, nil, err + } + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub GraphQL client", err), nil, nil @@ -522,6 +542,10 @@ func GranularDeletePendingPullRequestReview(t translations.TranslationHelperFunc return utils.NewToolResultError(err.Error()), nil, nil } + if result, err := enforcePRAuthorAllowlist(ctx, deps, owner, repo, pullNumber, nil); result != nil || err != nil { + return result, nil, err + } + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub GraphQL client", err), nil, nil @@ -606,6 +630,10 @@ func GranularAddPullRequestReviewComment(t translations.TranslationHelperFunc) i } startSide, _ := OptionalParam[string](args, "startSide") + if result, err := enforcePRAuthorAllowlist(ctx, deps, owner, repo, pullNumber, nil); result != nil || err != nil { + return result, nil, err + } + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub GraphQL client", err), nil, nil diff --git a/pkg/github/server.go b/pkg/github/server.go index ee41e90e9e..6beae3028e 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -54,6 +54,10 @@ type MCPServerConfig struct { // LockdownMode indicates if we should enable lockdown mode LockdownMode bool + // AllowedPRAuthors restricts mutating pull request tools to PRs authored by + // one of these GitHub logins. Empty means unrestricted. + AllowedPRAuthors []string + // InsidersMode indicates if we should enable experimental features InsidersMode bool diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 264ffa50fe..2979ee9e9c 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -71,6 +71,9 @@ func (s stubDeps) Logger(_ context.Context) *slog.Logger { func (s stubDeps) Metrics(ctx context.Context) metrics.Metrics { return s.obsv.Metrics(ctx) } +func (s stubDeps) IsPRAuthorAllowed(_ string) (bool, bool) { + return true, false +} // Helper functions to create stub client functions for error testing diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 559088f6d6..96364f1084 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -217,6 +217,7 @@ func AllTools(t translations.TranslationHelperFunc) []inventory.ServerTool { // Organization tools SearchOrgs(t), + CheckOrgMembership(t), // Pull request tools PullRequestRead(t), diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 1ae4713216..a6e868adfa 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -127,7 +127,7 @@ func NewHTTPMcpHandler( func (h *Handler) RegisterMiddleware(r chi.Router) { r.Use( - middleware.ExtractUserToken(h.oauthCfg), + middleware.ExtractUserToken(h.oauthCfg, h.config.Token), middleware.WithRequestConfig, middleware.WithMCPParse(), middleware.WithPATScopes(h.logger, h.scopeFetcher), diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go index 012bbabef2..d802dc6971 100644 --- a/pkg/http/middleware/token.go +++ b/pkg/http/middleware/token.go @@ -10,7 +10,7 @@ import ( "github.com/github/github-mcp-server/pkg/utils" ) -func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handler { +func ExtractUserToken(oauthCfg *oauth.Config, defaultToken ...string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -27,6 +27,20 @@ func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handl if err != nil { // For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec if errors.Is(err, utils.ErrMissingAuthorizationHeader) { + if len(defaultToken) > 0 && defaultToken[0] != "" { + tokenType, err := utils.ParseToken(defaultToken[0]) + if err != nil { + http.Error(w, fmt.Sprintf("default token is invalid: %v", err), http.StatusInternalServerError) + return + } + + ctx = ghcontext.WithTokenInfo(ctx, &ghcontext.TokenInfo{ + Token: defaultToken[0], + TokenType: tokenType, + }) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } sendAuthChallenge(w, r, oauthCfg) return } diff --git a/pkg/http/middleware/token_test.go b/pkg/http/middleware/token_test.go index fa8f0ee98e..c4f68c20fc 100644 --- a/pkg/http/middleware/token_test.go +++ b/pkg/http/middleware/token_test.go @@ -232,6 +232,55 @@ func TestExtractUserToken_NilOAuthConfig(t *testing.T) { assert.Equal(t, utils.TokenTypePersonalAccessToken, capturedTokenInfo.TokenType) } +func TestExtractUserToken_DefaultTokenFallback(t *testing.T) { + var capturedTokenInfo *ghcontext.TokenInfo + var tokenInfoCaptured bool + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedTokenInfo, tokenInfoCaptured = ghcontext.GetTokenInfo(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + middleware := ExtractUserToken(nil, "ghp_defaulttokenxxxxxxxxxxxxxxxxxxxxxxxx") + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + require.True(t, tokenInfoCaptured) + require.NotNil(t, capturedTokenInfo) + assert.Equal(t, utils.TokenTypePersonalAccessToken, capturedTokenInfo.TokenType) + assert.Equal(t, "ghp_defaulttokenxxxxxxxxxxxxxxxxxxxxxxxx", capturedTokenInfo.Token) +} + +func TestExtractUserToken_RequestTokenOverridesDefaultToken(t *testing.T) { + var capturedTokenInfo *ghcontext.TokenInfo + var tokenInfoCaptured bool + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedTokenInfo, tokenInfoCaptured = ghcontext.GetTokenInfo(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + middleware := ExtractUserToken(nil, "ghp_defaulttokenxxxxxxxxxxxxxxxxxxxxxxxx") + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set(headers.AuthorizationHeader, "Bearer gho_requesttokenxxxxxxxxxxxxxxxxxxxxxxxx") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + require.True(t, tokenInfoCaptured) + require.NotNil(t, capturedTokenInfo) + assert.Equal(t, utils.TokenTypeOAuthAccessToken, capturedTokenInfo.TokenType) + assert.Equal(t, "gho_requesttokenxxxxxxxxxxxxxxxxxxxxxxxx", capturedTokenInfo.Token) +} + func TestExtractUserToken_MissingAuthHeader_WWWAuthenticateFormat(t *testing.T) { oauthCfg := &oauth.Config{ BaseURL: "https://api.example.com", diff --git a/pkg/http/server.go b/pkg/http/server.go index f7cdaf9093..030906aa43 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -32,6 +32,10 @@ type ServerConfig struct { // GitHub Host to target for API requests (e.g. github.com or github.enterprise.com) Host string + // GitHub Token to use for requests that do not provide Authorization. + // If empty, HTTP requests must provide their own Authorization header. + Token string + // Port to listen on (default: 8082) Port int @@ -59,6 +63,10 @@ type ServerConfig struct { // LockdownMode indicates if we should enable lockdown mode LockdownMode bool + // AllowedPRAuthors restricts mutating pull request tools to PRs authored by + // one of these GitHub logins. Empty means unrestricted. + AllowedPRAuthors []string + // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. RepoAccessCacheTTL *time.Duration @@ -111,6 +119,12 @@ func RunHTTPServer(cfg ServerConfig) error { } logger := slog.New(slogHandler) logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "lockdownEnabled", cfg.LockdownMode, "readOnly", cfg.ReadOnly, "insidersMode", cfg.InsidersMode) + if len(cfg.AllowedPRAuthors) > 0 { + logger.Info("PR author allowlist enforced", "authors", cfg.AllowedPRAuthors) + } + if cfg.Token != "" { + logger.Warn("HTTP default token fallback enabled; requests without Authorization use the server token. Only expose this endpoint on a trusted network.") + } apiHost, err := utils.NewAPIHost(cfg.Host) if err != nil { @@ -140,6 +154,7 @@ func RunHTTPServer(cfg ServerConfig) error { cfg.ContentWindowSize, featureChecker, obs, + cfg.AllowedPRAuthors, ) // Initialize the global tool scope map diff --git a/pkg/utils/token.go b/pkg/utils/token.go index 8933fb0bda..97331bce76 100644 --- a/pkg/utils/token.go +++ b/pkg/utils/token.go @@ -60,16 +60,22 @@ func ParseAuthorizationHeader(req *http.Request) (tokenType TokenType, token str } } + tokenType, err := ParseToken(token) + return tokenType, token, err +} + +// ParseToken identifies a GitHub API token type from the token value. +func ParseToken(token string) (TokenType, error) { for prefix, tokenType := range supportedGitHubPrefixes { if strings.HasPrefix(token, prefix) { - return tokenType, token, nil + return tokenType, nil } } matchesOldTokenPattern := oldPatternRegexp.MatchString(token) if matchesOldTokenPattern { - return TokenTypePersonalAccessToken, token, nil + return TokenTypePersonalAccessToken, nil } - return 0, "", ErrBadAuthorizationHeader + return 0, ErrBadAuthorizationHeader }