From 00d67273ed4684a1c2d5f284a296aa0410c948b7 Mon Sep 17 00:00:00 2001 From: timothyF95 Date: Wed, 25 Mar 2026 14:58:15 +0000 Subject: [PATCH 1/2] Refactor: extract oauth from login --- cmd/login/login.go | 193 ++---------------- cmd/login/login_test.go | 18 +- internal/oauth/browser.go | 20 ++ internal/oauth/exchange.go | 59 ++++++ internal/oauth/exchange_test.go | 50 +++++ .../oauth}/htmlPages/error.html | 0 .../oauth}/htmlPages/output.css | 0 .../oauth}/htmlPages/success.html | 0 .../oauth}/htmlPages/waiting.html | 0 internal/oauth/pages.go | 87 ++++++++ internal/oauth/pkce.go | 20 ++ internal/oauth/pkce_test.go | 22 ++ internal/oauth/server.go | 24 +++ internal/oauth/state.go | 17 ++ 14 files changed, 323 insertions(+), 187 deletions(-) create mode 100644 internal/oauth/browser.go create mode 100644 internal/oauth/exchange.go create mode 100644 internal/oauth/exchange_test.go rename {cmd/login => internal/oauth}/htmlPages/error.html (100%) rename {cmd/login => internal/oauth}/htmlPages/output.css (100%) rename {cmd/login => internal/oauth}/htmlPages/success.html (100%) rename {cmd/login => internal/oauth}/htmlPages/waiting.html (100%) create mode 100644 internal/oauth/pages.go create mode 100644 internal/oauth/pkce.go create mode 100644 internal/oauth/pkce_test.go create mode 100644 internal/oauth/server.go create mode 100644 internal/oauth/state.go diff --git a/cmd/login/login.go b/cmd/login/login.go index 9251c78b..0ee07280 100644 --- a/cmd/login/login.go +++ b/cmd/login/login.go @@ -2,17 +2,10 @@ package login import ( "context" - "crypto/rand" - "crypto/sha256" - "embed" - "encoding/base64" - "encoding/json" "fmt" - "io" "net" "net/http" "net/url" - "os/exec" rt "runtime" "strings" "time" @@ -24,28 +17,19 @@ import ( "github.com/smartcontractkit/cre-cli/internal/constants" "github.com/smartcontractkit/cre-cli/internal/credentials" "github.com/smartcontractkit/cre-cli/internal/environments" + "github.com/smartcontractkit/cre-cli/internal/oauth" "github.com/smartcontractkit/cre-cli/internal/runtime" "github.com/smartcontractkit/cre-cli/internal/tenantctx" "github.com/smartcontractkit/cre-cli/internal/ui" ) var ( - httpClient = &http.Client{Timeout: 10 * time.Second} - errorPage = "htmlPages/error.html" - successPage = "htmlPages/success.html" - waitingPage = "htmlPages/waiting.html" - stylePage = "htmlPages/output.css" - // OrgMembershipErrorSubstring is the error message substring returned by Auth0 // when a user doesn't belong to any organization during the auth flow. // This typically happens during sign-up when the organization hasn't been created yet. OrgMembershipErrorSubstring = "user does not belong to any organization" ) -//go:embed htmlPages/*.html -//go:embed htmlPages/*.css -var htmlFiles embed.FS - func New(runtimeCtx *runtime.Context) *cobra.Command { cmd := &cobra.Command{ Use: "login", @@ -102,7 +86,7 @@ func (h *handler) execute() error { // Use spinner for the token exchange h.spinner.Start("Exchanging authorization code...") - tokenSet, err := h.exchangeCodeForTokens(context.Background(), code) + tokenSet, err := oauth.ExchangeAuthorizationCode(context.Background(), nil, h.environmentSet, code, h.lastPKCEVerifier) if err != nil { h.spinner.StopAll() h.log.Error().Err(err).Msg("code exchange failed") @@ -162,13 +146,13 @@ func (h *handler) startAuthFlow() (string, error) { } }() - verifier, challenge, err := generatePKCE() + verifier, challenge, err := oauth.GeneratePKCE() if err != nil { h.spinner.Stop() return "", err } h.lastPKCEVerifier = verifier - h.lastState = randomState() + h.lastState = oauth.RandomState() authURL := h.buildAuthURL(challenge, h.lastState) @@ -180,7 +164,7 @@ func (h *handler) startAuthFlow() (string, error) { ui.URL(authURL) ui.Line() - if err := openBrowser(authURL, rt.GOOS); err != nil { + if err := oauth.OpenBrowser(authURL, rt.GOOS); err != nil { ui.Warning("Could not open browser automatically") ui.Dim("Please open the URL above in your browser") ui.Line() @@ -199,19 +183,7 @@ func (h *handler) startAuthFlow() (string, error) { } func (h *handler) setupServer(codeCh chan string) (*http.Server, net.Listener, error) { - mux := http.NewServeMux() - mux.HandleFunc("/callback", h.callbackHandler(codeCh)) - - // TODO: Add a fallback port in case the default port is in use - listener, err := net.Listen("tcp", constants.AuthListenAddr) - if err != nil { - return nil, nil, fmt.Errorf("failed to listen on %s: %w", constants.AuthListenAddr, err) - } - - return &http.Server{ - Handler: mux, - ReadHeaderTimeout: 5 * time.Second, - }, listener, nil + return oauth.NewCallbackHTTPServer(constants.AuthListenAddr, h.callbackHandler(codeCh)) } func (h *handler) callbackHandler(codeCh chan string) http.HandlerFunc { @@ -225,120 +197,52 @@ func (h *handler) callbackHandler(codeCh chan string) http.HandlerFunc { if strings.Contains(errorDesc, OrgMembershipErrorSubstring) { if h.retryCount >= maxOrgNotFoundRetries { h.log.Error().Int("retries", h.retryCount).Msg("organization setup timed out after maximum retries") - h.serveEmbeddedHTML(w, errorPage, http.StatusBadRequest) + oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusBadRequest) return } // Generate new authentication credentials for the retry - verifier, challenge, err := generatePKCE() + verifier, challenge, err := oauth.GeneratePKCE() if err != nil { h.log.Error().Err(err).Msg("failed to prepare authentication retry") - h.serveEmbeddedHTML(w, errorPage, http.StatusInternalServerError) + oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusInternalServerError) return } h.lastPKCEVerifier = verifier - h.lastState = randomState() + h.lastState = oauth.RandomState() h.retryCount++ // Build the new auth URL for redirect authURL := h.buildAuthURL(challenge, h.lastState) h.log.Debug().Int("attempt", h.retryCount).Int("max", maxOrgNotFoundRetries).Msg("organization setup in progress, retrying") - h.serveWaitingPage(w, authURL) + oauth.ServeWaitingPage(h.log, w, authURL) return } // Generic Auth0 error h.log.Error().Str("error", errorParam).Str("description", errorDesc).Msg("auth error in callback") - h.serveEmbeddedHTML(w, errorPage, http.StatusBadRequest) + oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusBadRequest) return } if st := r.URL.Query().Get("state"); st == "" || h.lastState == "" || st != h.lastState { h.log.Error().Msg("invalid state in response") - h.serveEmbeddedHTML(w, errorPage, http.StatusBadRequest) + oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusBadRequest) return } code := r.URL.Query().Get("code") if code == "" { h.log.Error().Msg("no code in response") - h.serveEmbeddedHTML(w, errorPage, http.StatusBadRequest) + oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusBadRequest) return } - h.serveEmbeddedHTML(w, successPage, http.StatusOK) + oauth.ServeEmbeddedHTML(h.log, w, oauth.PageSuccess, http.StatusOK) codeCh <- code } } -func (h *handler) serveEmbeddedHTML(w http.ResponseWriter, filePath string, status int) { - htmlContent, err := htmlFiles.ReadFile(filePath) - if err != nil { - h.log.Error().Err(err).Str("file", filePath).Msg("failed to read embedded HTML file") - h.sendHTTPError(w) - return - } - - cssContent, err := htmlFiles.ReadFile(stylePage) - if err != nil { - h.log.Error().Err(err).Str("file", stylePage).Msg("failed to read embedded CSS file") - h.sendHTTPError(w) - return - } - - modified := strings.Replace( - string(htmlContent), - ``, - fmt.Sprintf("", string(cssContent)), - 1, - ) - - w.Header().Set("Content-Type", "text/html") - w.WriteHeader(status) - if _, err := w.Write([]byte(modified)); err != nil { - h.log.Error().Err(err).Msg("failed to write HTML response") - } -} - -// serveWaitingPage serves the waiting page with the redirect URL injected. -// This is used when handling organization membership errors during sign-up flow. -func (h *handler) serveWaitingPage(w http.ResponseWriter, redirectURL string) { - htmlContent, err := htmlFiles.ReadFile(waitingPage) - if err != nil { - h.log.Error().Err(err).Str("file", waitingPage).Msg("failed to read waiting page HTML file") - h.sendHTTPError(w) - return - } - - cssContent, err := htmlFiles.ReadFile(stylePage) - if err != nil { - h.log.Error().Err(err).Str("file", stylePage).Msg("failed to read embedded CSS file") - h.sendHTTPError(w) - return - } - - // Inject CSS inline - modified := strings.Replace( - string(htmlContent), - ``, - fmt.Sprintf("", string(cssContent)), - 1, - ) - - // Inject the redirect URL - modified = strings.Replace(modified, "{{REDIRECT_URL}}", redirectURL, 1) - - w.Header().Set("Content-Type", "text/html") - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte(modified)); err != nil { - h.log.Error().Err(err).Msg("failed to write waiting page response") - } -} - -func (h *handler) sendHTTPError(w http.ResponseWriter) { - http.Error(w, "Internal Server Error", http.StatusInternalServerError) -} - func (h *handler) buildAuthURL(codeChallenge, state string) string { params := url.Values{} params.Set("client_id", h.environmentSet.ClientID) @@ -355,41 +259,6 @@ func (h *handler) buildAuthURL(codeChallenge, state string) string { return h.environmentSet.AuthBase + constants.AuthAuthorizePath + "?" + params.Encode() } -func (h *handler) exchangeCodeForTokens(ctx context.Context, code string) (*credentials.CreLoginTokenSet, error) { - form := url.Values{} - form.Set("grant_type", "authorization_code") - form.Set("client_id", h.environmentSet.ClientID) - form.Set("code", code) - form.Set("redirect_uri", constants.AuthRedirectURI) - form.Set("code_verifier", h.lastPKCEVerifier) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.environmentSet.AuthBase+constants.AuthTokenPath, strings.NewReader(form.Encode())) - if err != nil { - return nil, fmt.Errorf("create request: %w", err) - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err := httpClient.Do(req) // #nosec G704 -- URL is from trusted environment config - if err != nil { - return nil, fmt.Errorf("perform request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) - if err != nil { - return nil, fmt.Errorf("read response: %w", err) - } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body) - } - - var tokenSet credentials.CreLoginTokenSet - if err := json.Unmarshal(body, &tokenSet); err != nil { - return nil, fmt.Errorf("unmarshal token set: %w", err) - } - return &tokenSet, nil -} - func (h *handler) fetchTenantConfig(tokenSet *credentials.CreLoginTokenSet) error { creds := &credentials.Credentials{ Tokens: tokenSet, @@ -404,35 +273,3 @@ func (h *handler) fetchTenantConfig(tokenSet *credentials.CreLoginTokenSet) erro return tenantctx.FetchAndWriteContext(context.Background(), gqlClient, envName, h.log) } - -func openBrowser(urlStr string, goos string) error { - switch goos { - case "darwin": - return exec.Command("open", urlStr).Start() - case "linux": - return exec.Command("xdg-open", urlStr).Start() - case "windows": - return exec.Command("rundll32", "url.dll,FileProtocolHandler", urlStr).Start() - default: - return fmt.Errorf("unsupported OS: %s", goos) - } -} - -func generatePKCE() (verifier, challenge string, err error) { - b := make([]byte, 32) - if _, err = rand.Read(b); err != nil { - return "", "", err - } - verifier = base64.RawURLEncoding.EncodeToString(b) - sum := sha256.Sum256([]byte(verifier)) - challenge = base64.RawURLEncoding.EncodeToString(sum[:]) - return verifier, challenge, nil -} - -func randomState() string { - b := make([]byte, 16) - if _, err := rand.Read(b); err != nil { - return fmt.Sprintf("%d", time.Now().UnixNano()) - } - return base64.RawURLEncoding.EncodeToString(b) -} diff --git a/cmd/login/login_test.go b/cmd/login/login_test.go index 782f2d18..90ce4c66 100644 --- a/cmd/login/login_test.go +++ b/cmd/login/login_test.go @@ -14,6 +14,7 @@ import ( "github.com/smartcontractkit/cre-cli/internal/credentials" "github.com/smartcontractkit/cre-cli/internal/environments" + "github.com/smartcontractkit/cre-cli/internal/oauth" "github.com/smartcontractkit/cre-cli/internal/ui" ) @@ -51,9 +52,9 @@ func TestSaveCredentials_WritesYAML(t *testing.T) { } func TestGeneratePKCE_ReturnsValidChallenge(t *testing.T) { - verifier, challenge, err := generatePKCE() + verifier, challenge, err := oauth.GeneratePKCE() if err != nil { - t.Fatalf("generatePKCE error: %v", err) + t.Fatalf("GeneratePKCE error: %v", err) } if verifier == "" || challenge == "" { t.Error("PKCE verifier or challenge is empty") @@ -61,8 +62,8 @@ func TestGeneratePKCE_ReturnsValidChallenge(t *testing.T) { } func TestRandomState_IsRandomAndNonEmpty(t *testing.T) { - state1 := randomState() - state2 := randomState() + state1 := oauth.RandomState() + state2 := oauth.RandomState() if state1 == "" || state2 == "" { t.Error("randomState returned empty string") } @@ -72,16 +73,16 @@ func TestRandomState_IsRandomAndNonEmpty(t *testing.T) { } func TestOpenBrowser_UnsupportedOS(t *testing.T) { - err := openBrowser("http://example.com", "plan9") + err := oauth.OpenBrowser("http://example.com", "plan9") if err == nil || !strings.Contains(err.Error(), "unsupported OS") { t.Errorf("expected unsupported OS error, got %v", err) } } func TestServeEmbeddedHTML_ErrorOnMissingFile(t *testing.T) { - h := &handler{log: &zerolog.Logger{}, spinner: ui.NewSpinner()} + log := zerolog.Nop() w := httptest.NewRecorder() - h.serveEmbeddedHTML(w, "htmlPages/doesnotexist.html", http.StatusOK) + oauth.ServeEmbeddedHTML(&log, w, "htmlPages/doesnotexist.html", http.StatusOK) resp := w.Result() if resp.StatusCode != http.StatusInternalServerError { t.Errorf("expected 500 error, got %d", resp.StatusCode) @@ -274,12 +275,11 @@ func TestCallbackHandler_GenericAuth0Error(t *testing.T) { func TestServeWaitingPage(t *testing.T) { logger := zerolog.Nop() - h := &handler{log: &logger, spinner: ui.NewSpinner()} w := httptest.NewRecorder() redirectURL := "https://auth.example.com/authorize?client_id=test&state=abc123" - h.serveWaitingPage(w, redirectURL) + oauth.ServeWaitingPage(&logger, w, redirectURL) resp := w.Result() body, _ := io.ReadAll(resp.Body) diff --git a/internal/oauth/browser.go b/internal/oauth/browser.go new file mode 100644 index 00000000..99e13424 --- /dev/null +++ b/internal/oauth/browser.go @@ -0,0 +1,20 @@ +package oauth + +import ( + "fmt" + "os/exec" +) + +// OpenBrowser opens urlStr in the default browser for the given GOOS value. +func OpenBrowser(urlStr string, goos string) error { + switch goos { + case "darwin": + return exec.Command("open", urlStr).Start() + case "linux": + return exec.Command("xdg-open", urlStr).Start() + case "windows": + return exec.Command("rundll32", "url.dll,FileProtocolHandler", urlStr).Start() + default: + return fmt.Errorf("unsupported OS: %s", goos) + } +} diff --git a/internal/oauth/exchange.go b/internal/oauth/exchange.go new file mode 100644 index 00000000..3af69e3d --- /dev/null +++ b/internal/oauth/exchange.go @@ -0,0 +1,59 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/smartcontractkit/cre-cli/internal/constants" + "github.com/smartcontractkit/cre-cli/internal/credentials" + "github.com/smartcontractkit/cre-cli/internal/environments" +) + +// DefaultHTTPClient is used for token exchange when no client is supplied. +var DefaultHTTPClient = &http.Client{Timeout: 10 * time.Second} + +// ExchangeAuthorizationCode exchanges an OAuth authorization code for tokens using +// environment credentials (AuthBase, ClientID) and PKCE code_verifier. +func ExchangeAuthorizationCode(ctx context.Context, httpClient *http.Client, env *environments.EnvironmentSet, code, codeVerifier string) (*credentials.CreLoginTokenSet, error) { + if httpClient == nil { + httpClient = DefaultHTTPClient + } + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", env.ClientID) + form.Set("code", code) + form.Set("redirect_uri", constants.AuthRedirectURI) + form.Set("code_verifier", codeVerifier) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, env.AuthBase+constants.AuthTokenPath, strings.NewReader(form.Encode())) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := httpClient.Do(req) // #nosec G704 -- URL is from trusted environment config + if err != nil { + return nil, fmt.Errorf("perform request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body) + } + + var tokenSet credentials.CreLoginTokenSet + if err := json.Unmarshal(body, &tokenSet); err != nil { + return nil, fmt.Errorf("unmarshal token set: %w", err) + } + return &tokenSet, nil +} diff --git a/internal/oauth/exchange_test.go b/internal/oauth/exchange_test.go new file mode 100644 index 00000000..fcf19faf --- /dev/null +++ b/internal/oauth/exchange_test.go @@ -0,0 +1,50 @@ +package oauth + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/cre-cli/internal/constants" + "github.com/smartcontractkit/cre-cli/internal/credentials" + "github.com/smartcontractkit/cre-cli/internal/environments" +) + +func TestExchangeAuthorizationCode(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method", http.StatusMethodNotAllowed) + return + } + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + assert.Equal(t, "authorization_code", r.Form.Get("grant_type")) + assert.Equal(t, "cid", r.Form.Get("client_id")) + assert.Equal(t, "auth-code", r.Form.Get("code")) + assert.Equal(t, constants.AuthRedirectURI, r.Form.Get("redirect_uri")) + assert.Equal(t, "verifier", r.Form.Get("code_verifier")) + + _ = json.NewEncoder(w).Encode(credentials.CreLoginTokenSet{ + AccessToken: "a", + TokenType: "Bearer", + }) + })) + defer ts.Close() + + env := &environments.EnvironmentSet{ + AuthBase: ts.URL, + ClientID: "cid", + } + + tok, err := ExchangeAuthorizationCode(context.Background(), ts.Client(), env, "auth-code", "verifier") + require.NoError(t, err) + require.NotNil(t, tok) + assert.Equal(t, "a", tok.AccessToken) +} diff --git a/cmd/login/htmlPages/error.html b/internal/oauth/htmlPages/error.html similarity index 100% rename from cmd/login/htmlPages/error.html rename to internal/oauth/htmlPages/error.html diff --git a/cmd/login/htmlPages/output.css b/internal/oauth/htmlPages/output.css similarity index 100% rename from cmd/login/htmlPages/output.css rename to internal/oauth/htmlPages/output.css diff --git a/cmd/login/htmlPages/success.html b/internal/oauth/htmlPages/success.html similarity index 100% rename from cmd/login/htmlPages/success.html rename to internal/oauth/htmlPages/success.html diff --git a/cmd/login/htmlPages/waiting.html b/internal/oauth/htmlPages/waiting.html similarity index 100% rename from cmd/login/htmlPages/waiting.html rename to internal/oauth/htmlPages/waiting.html diff --git a/internal/oauth/pages.go b/internal/oauth/pages.go new file mode 100644 index 00000000..31d07220 --- /dev/null +++ b/internal/oauth/pages.go @@ -0,0 +1,87 @@ +package oauth + +import ( + "embed" + "fmt" + "net/http" + "strings" + + "github.com/rs/zerolog" +) + +const ( + PageError = "htmlPages/error.html" + PageSuccess = "htmlPages/success.html" + PageWaiting = "htmlPages/waiting.html" + StylePage = "htmlPages/output.css" +) + +//go:embed htmlPages/*.html +//go:embed htmlPages/*.css +var htmlFiles embed.FS + +// ServeEmbeddedHTML serves an embedded HTML page with inline CSS. +func ServeEmbeddedHTML(log *zerolog.Logger, w http.ResponseWriter, filePath string, status int) { + htmlContent, err := htmlFiles.ReadFile(filePath) + if err != nil { + log.Error().Err(err).Str("file", filePath).Msg("failed to read embedded HTML file") + sendHTTPError(w) + return + } + + cssContent, err := htmlFiles.ReadFile(StylePage) + if err != nil { + log.Error().Err(err).Str("file", StylePage).Msg("failed to read embedded CSS file") + sendHTTPError(w) + return + } + + modified := strings.Replace( + string(htmlContent), + ``, + fmt.Sprintf("", string(cssContent)), + 1, + ) + + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(status) + if _, err := w.Write([]byte(modified)); err != nil { + log.Error().Err(err).Msg("failed to write HTML response") + } +} + +// ServeWaitingPage serves the waiting page with the redirect URL injected. +func ServeWaitingPage(log *zerolog.Logger, w http.ResponseWriter, redirectURL string) { + htmlContent, err := htmlFiles.ReadFile(PageWaiting) + if err != nil { + log.Error().Err(err).Str("file", PageWaiting).Msg("failed to read waiting page HTML file") + sendHTTPError(w) + return + } + + cssContent, err := htmlFiles.ReadFile(StylePage) + if err != nil { + log.Error().Err(err).Str("file", StylePage).Msg("failed to read embedded CSS file") + sendHTTPError(w) + return + } + + modified := strings.Replace( + string(htmlContent), + ``, + fmt.Sprintf("", string(cssContent)), + 1, + ) + + modified = strings.Replace(modified, "{{REDIRECT_URL}}", redirectURL, 1) + + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(modified)); err != nil { + log.Error().Err(err).Msg("failed to write waiting page response") + } +} + +func sendHTTPError(w http.ResponseWriter) { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) +} diff --git a/internal/oauth/pkce.go b/internal/oauth/pkce.go new file mode 100644 index 00000000..0ed0e8a0 --- /dev/null +++ b/internal/oauth/pkce.go @@ -0,0 +1,20 @@ +package oauth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" +) + +// GeneratePKCE returns an RFC 7636 S256 code verifier and code challenge. +func GeneratePKCE() (verifier, challenge string, err error) { + b := make([]byte, 32) + if _, err = rand.Read(b); err != nil { + return "", "", fmt.Errorf("pkce random: %w", err) + } + verifier = base64.RawURLEncoding.EncodeToString(b) + sum := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(sum[:]) + return verifier, challenge, nil +} diff --git a/internal/oauth/pkce_test.go b/internal/oauth/pkce_test.go new file mode 100644 index 00000000..50b7c376 --- /dev/null +++ b/internal/oauth/pkce_test.go @@ -0,0 +1,22 @@ +package oauth + +import ( + "crypto/sha256" + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGeneratePKCE_S256(t *testing.T) { + verifier, challenge, err := GeneratePKCE() + require.NoError(t, err) + require.NotEmpty(t, verifier) + require.NotEmpty(t, challenge) + + sum := sha256.Sum256([]byte(verifier)) + decoded, err := base64.RawURLEncoding.DecodeString(challenge) + require.NoError(t, err) + assert.Equal(t, sum[:], decoded) +} diff --git a/internal/oauth/server.go b/internal/oauth/server.go new file mode 100644 index 00000000..4ca2abe5 --- /dev/null +++ b/internal/oauth/server.go @@ -0,0 +1,24 @@ +package oauth + +import ( + "fmt" + "net" + "net/http" + "time" +) + +// NewCallbackHTTPServer listens on listenAddr and serves callback on /callback. +func NewCallbackHTTPServer(listenAddr string, callback http.HandlerFunc) (*http.Server, net.Listener, error) { + mux := http.NewServeMux() + mux.HandleFunc("/callback", callback) + + listener, err := net.Listen("tcp", listenAddr) + if err != nil { + return nil, nil, fmt.Errorf("failed to listen on %s: %w", listenAddr, err) + } + + return &http.Server{ + Handler: mux, + ReadHeaderTimeout: 5 * time.Second, + }, listener, nil +} diff --git a/internal/oauth/state.go b/internal/oauth/state.go new file mode 100644 index 00000000..4019de61 --- /dev/null +++ b/internal/oauth/state.go @@ -0,0 +1,17 @@ +package oauth + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "time" +) + +// RandomState returns a random OAuth state value for CSRF protection. +func RandomState() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return fmt.Sprintf("%d", time.Now().UnixNano()) + } + return base64.RawURLEncoding.EncodeToString(b) +} From fd957ca24da181c451c0450be069bc54c29f53c6 Mon Sep 17 00:00:00 2001 From: timothyF95 Date: Wed, 25 Mar 2026 15:18:43 +0000 Subject: [PATCH 2/2] Lint --- internal/oauth/exchange_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/oauth/exchange_test.go b/internal/oauth/exchange_test.go index fcf19faf..a7923a36 100644 --- a/internal/oauth/exchange_test.go +++ b/internal/oauth/exchange_test.go @@ -21,6 +21,7 @@ func TestExchangeAuthorizationCode(t *testing.T) { http.Error(w, "method", http.StatusMethodNotAllowed) return } + r.Body = http.MaxBytesReader(w, r.Body, 1<<20) if err := r.ParseForm(); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -32,7 +33,7 @@ func TestExchangeAuthorizationCode(t *testing.T) { assert.Equal(t, "verifier", r.Form.Get("code_verifier")) _ = json.NewEncoder(w).Encode(credentials.CreLoginTokenSet{ - AccessToken: "a", + AccessToken: "a", // #nosec G101 G117 -- test fixture, not a real credential TokenType: "Bearer", }) }))