Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 15 additions & 178 deletions cmd/login/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,10 @@ package login

import (
"context"
"crypto/rand"
"crypto/sha256"
"embed"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os/exec"
rt "runtime"
"strings"
"time"
Expand All @@ -24,28 +17,19 @@ import (
"github.com/smartcontractkit/cre-cli/internal/constants"
"github.com/smartcontractkit/cre-cli/internal/credentials"
"github.com/smartcontractkit/cre-cli/internal/environments"
"github.com/smartcontractkit/cre-cli/internal/oauth"
"github.com/smartcontractkit/cre-cli/internal/runtime"
"github.com/smartcontractkit/cre-cli/internal/tenantctx"
"github.com/smartcontractkit/cre-cli/internal/ui"
)

var (
httpClient = &http.Client{Timeout: 10 * time.Second}
errorPage = "htmlPages/error.html"
successPage = "htmlPages/success.html"
waitingPage = "htmlPages/waiting.html"
stylePage = "htmlPages/output.css"

// OrgMembershipErrorSubstring is the error message substring returned by Auth0
// when a user doesn't belong to any organization during the auth flow.
// This typically happens during sign-up when the organization hasn't been created yet.
OrgMembershipErrorSubstring = "user does not belong to any organization"
)

//go:embed htmlPages/*.html
//go:embed htmlPages/*.css
var htmlFiles embed.FS

func New(runtimeCtx *runtime.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "login",
Expand Down Expand Up @@ -102,7 +86,7 @@ func (h *handler) execute() error {

// Use spinner for the token exchange
h.spinner.Start("Exchanging authorization code...")
tokenSet, err := h.exchangeCodeForTokens(context.Background(), code)
tokenSet, err := oauth.ExchangeAuthorizationCode(context.Background(), nil, h.environmentSet, code, h.lastPKCEVerifier)
if err != nil {
h.spinner.StopAll()
h.log.Error().Err(err).Msg("code exchange failed")
Expand Down Expand Up @@ -162,13 +146,13 @@ func (h *handler) startAuthFlow() (string, error) {
}
}()

verifier, challenge, err := generatePKCE()
verifier, challenge, err := oauth.GeneratePKCE()
if err != nil {
h.spinner.Stop()
return "", err
}
h.lastPKCEVerifier = verifier
h.lastState = randomState()
h.lastState = oauth.RandomState()

authURL := h.buildAuthURL(challenge, h.lastState)

Expand All @@ -180,7 +164,7 @@ func (h *handler) startAuthFlow() (string, error) {
ui.URL(authURL)
ui.Line()

if err := openBrowser(authURL, rt.GOOS); err != nil {
if err := oauth.OpenBrowser(authURL, rt.GOOS); err != nil {
ui.Warning("Could not open browser automatically")
ui.Dim("Please open the URL above in your browser")
ui.Line()
Expand All @@ -199,19 +183,7 @@ func (h *handler) startAuthFlow() (string, error) {
}

func (h *handler) setupServer(codeCh chan string) (*http.Server, net.Listener, error) {
mux := http.NewServeMux()
mux.HandleFunc("/callback", h.callbackHandler(codeCh))

// TODO: Add a fallback port in case the default port is in use
listener, err := net.Listen("tcp", constants.AuthListenAddr)
if err != nil {
return nil, nil, fmt.Errorf("failed to listen on %s: %w", constants.AuthListenAddr, err)
}

return &http.Server{
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
}, listener, nil
return oauth.NewCallbackHTTPServer(constants.AuthListenAddr, h.callbackHandler(codeCh))
}

func (h *handler) callbackHandler(codeCh chan string) http.HandlerFunc {
Expand All @@ -225,120 +197,52 @@ func (h *handler) callbackHandler(codeCh chan string) http.HandlerFunc {
if strings.Contains(errorDesc, OrgMembershipErrorSubstring) {
if h.retryCount >= maxOrgNotFoundRetries {
h.log.Error().Int("retries", h.retryCount).Msg("organization setup timed out after maximum retries")
h.serveEmbeddedHTML(w, errorPage, http.StatusBadRequest)
oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusBadRequest)
return
}

// Generate new authentication credentials for the retry
verifier, challenge, err := generatePKCE()
verifier, challenge, err := oauth.GeneratePKCE()
if err != nil {
h.log.Error().Err(err).Msg("failed to prepare authentication retry")
h.serveEmbeddedHTML(w, errorPage, http.StatusInternalServerError)
oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusInternalServerError)
return
}
h.lastPKCEVerifier = verifier
h.lastState = randomState()
h.lastState = oauth.RandomState()
h.retryCount++

// Build the new auth URL for redirect
authURL := h.buildAuthURL(challenge, h.lastState)

h.log.Debug().Int("attempt", h.retryCount).Int("max", maxOrgNotFoundRetries).Msg("organization setup in progress, retrying")
h.serveWaitingPage(w, authURL)
oauth.ServeWaitingPage(h.log, w, authURL)
return
}

// Generic Auth0 error
h.log.Error().Str("error", errorParam).Str("description", errorDesc).Msg("auth error in callback")
h.serveEmbeddedHTML(w, errorPage, http.StatusBadRequest)
oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusBadRequest)
return
}

if st := r.URL.Query().Get("state"); st == "" || h.lastState == "" || st != h.lastState {
h.log.Error().Msg("invalid state in response")
h.serveEmbeddedHTML(w, errorPage, http.StatusBadRequest)
oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusBadRequest)
return
}
code := r.URL.Query().Get("code")
if code == "" {
h.log.Error().Msg("no code in response")
h.serveEmbeddedHTML(w, errorPage, http.StatusBadRequest)
oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusBadRequest)
return
}

h.serveEmbeddedHTML(w, successPage, http.StatusOK)
oauth.ServeEmbeddedHTML(h.log, w, oauth.PageSuccess, http.StatusOK)
codeCh <- code
}
}

func (h *handler) serveEmbeddedHTML(w http.ResponseWriter, filePath string, status int) {
htmlContent, err := htmlFiles.ReadFile(filePath)
if err != nil {
h.log.Error().Err(err).Str("file", filePath).Msg("failed to read embedded HTML file")
h.sendHTTPError(w)
return
}

cssContent, err := htmlFiles.ReadFile(stylePage)
if err != nil {
h.log.Error().Err(err).Str("file", stylePage).Msg("failed to read embedded CSS file")
h.sendHTTPError(w)
return
}

modified := strings.Replace(
string(htmlContent),
`<link rel="stylesheet" href="./output.css" />`,
fmt.Sprintf("<style>%s</style>", string(cssContent)),
1,
)

w.Header().Set("Content-Type", "text/html")
w.WriteHeader(status)
if _, err := w.Write([]byte(modified)); err != nil {
h.log.Error().Err(err).Msg("failed to write HTML response")
}
}

// serveWaitingPage serves the waiting page with the redirect URL injected.
// This is used when handling organization membership errors during sign-up flow.
func (h *handler) serveWaitingPage(w http.ResponseWriter, redirectURL string) {
htmlContent, err := htmlFiles.ReadFile(waitingPage)
if err != nil {
h.log.Error().Err(err).Str("file", waitingPage).Msg("failed to read waiting page HTML file")
h.sendHTTPError(w)
return
}

cssContent, err := htmlFiles.ReadFile(stylePage)
if err != nil {
h.log.Error().Err(err).Str("file", stylePage).Msg("failed to read embedded CSS file")
h.sendHTTPError(w)
return
}

// Inject CSS inline
modified := strings.Replace(
string(htmlContent),
`<link rel="stylesheet" href="./output.css" />`,
fmt.Sprintf("<style>%s</style>", string(cssContent)),
1,
)

// Inject the redirect URL
modified = strings.Replace(modified, "{{REDIRECT_URL}}", redirectURL, 1)

w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
if _, err := w.Write([]byte(modified)); err != nil {
h.log.Error().Err(err).Msg("failed to write waiting page response")
}
}

func (h *handler) sendHTTPError(w http.ResponseWriter) {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}

func (h *handler) buildAuthURL(codeChallenge, state string) string {
params := url.Values{}
params.Set("client_id", h.environmentSet.ClientID)
Expand All @@ -355,41 +259,6 @@ func (h *handler) buildAuthURL(codeChallenge, state string) string {
return h.environmentSet.AuthBase + constants.AuthAuthorizePath + "?" + params.Encode()
}

func (h *handler) exchangeCodeForTokens(ctx context.Context, code string) (*credentials.CreLoginTokenSet, error) {
form := url.Values{}
form.Set("grant_type", "authorization_code")
form.Set("client_id", h.environmentSet.ClientID)
form.Set("code", code)
form.Set("redirect_uri", constants.AuthRedirectURI)
form.Set("code_verifier", h.lastPKCEVerifier)

req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.environmentSet.AuthBase+constants.AuthTokenPath, strings.NewReader(form.Encode()))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := httpClient.Do(req) // #nosec G704 -- URL is from trusted environment config
if err != nil {
return nil, fmt.Errorf("perform request: %w", err)
}
defer resp.Body.Close()

body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
}

var tokenSet credentials.CreLoginTokenSet
if err := json.Unmarshal(body, &tokenSet); err != nil {
return nil, fmt.Errorf("unmarshal token set: %w", err)
}
return &tokenSet, nil
}

func (h *handler) fetchTenantConfig(tokenSet *credentials.CreLoginTokenSet) error {
creds := &credentials.Credentials{
Tokens: tokenSet,
Expand All @@ -404,35 +273,3 @@ func (h *handler) fetchTenantConfig(tokenSet *credentials.CreLoginTokenSet) erro

return tenantctx.FetchAndWriteContext(context.Background(), gqlClient, envName, h.log)
}

func openBrowser(urlStr string, goos string) error {
switch goos {
case "darwin":
return exec.Command("open", urlStr).Start()
case "linux":
return exec.Command("xdg-open", urlStr).Start()
case "windows":
return exec.Command("rundll32", "url.dll,FileProtocolHandler", urlStr).Start()
default:
return fmt.Errorf("unsupported OS: %s", goos)
}
}

func generatePKCE() (verifier, challenge string, err error) {
b := make([]byte, 32)
if _, err = rand.Read(b); err != nil {
return "", "", err
}
verifier = base64.RawURLEncoding.EncodeToString(b)
sum := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(sum[:])
return verifier, challenge, nil
}

func randomState() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return fmt.Sprintf("%d", time.Now().UnixNano())
}
return base64.RawURLEncoding.EncodeToString(b)
}
18 changes: 9 additions & 9 deletions cmd/login/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/smartcontractkit/cre-cli/internal/credentials"
"github.com/smartcontractkit/cre-cli/internal/environments"
"github.com/smartcontractkit/cre-cli/internal/oauth"
"github.com/smartcontractkit/cre-cli/internal/ui"
)

Expand Down Expand Up @@ -51,18 +52,18 @@ func TestSaveCredentials_WritesYAML(t *testing.T) {
}

func TestGeneratePKCE_ReturnsValidChallenge(t *testing.T) {
verifier, challenge, err := generatePKCE()
verifier, challenge, err := oauth.GeneratePKCE()
if err != nil {
t.Fatalf("generatePKCE error: %v", err)
t.Fatalf("GeneratePKCE error: %v", err)
}
if verifier == "" || challenge == "" {
t.Error("PKCE verifier or challenge is empty")
}
}

func TestRandomState_IsRandomAndNonEmpty(t *testing.T) {
state1 := randomState()
state2 := randomState()
state1 := oauth.RandomState()
state2 := oauth.RandomState()
if state1 == "" || state2 == "" {
t.Error("randomState returned empty string")
}
Expand All @@ -72,16 +73,16 @@ func TestRandomState_IsRandomAndNonEmpty(t *testing.T) {
}

func TestOpenBrowser_UnsupportedOS(t *testing.T) {
err := openBrowser("http://example.com", "plan9")
err := oauth.OpenBrowser("http://example.com", "plan9")
if err == nil || !strings.Contains(err.Error(), "unsupported OS") {
t.Errorf("expected unsupported OS error, got %v", err)
}
}

func TestServeEmbeddedHTML_ErrorOnMissingFile(t *testing.T) {
h := &handler{log: &zerolog.Logger{}, spinner: ui.NewSpinner()}
log := zerolog.Nop()
w := httptest.NewRecorder()
h.serveEmbeddedHTML(w, "htmlPages/doesnotexist.html", http.StatusOK)
oauth.ServeEmbeddedHTML(&log, w, "htmlPages/doesnotexist.html", http.StatusOK)
resp := w.Result()
if resp.StatusCode != http.StatusInternalServerError {
t.Errorf("expected 500 error, got %d", resp.StatusCode)
Expand Down Expand Up @@ -274,12 +275,11 @@ func TestCallbackHandler_GenericAuth0Error(t *testing.T) {

func TestServeWaitingPage(t *testing.T) {
logger := zerolog.Nop()
h := &handler{log: &logger, spinner: ui.NewSpinner()}

w := httptest.NewRecorder()
redirectURL := "https://auth.example.com/authorize?client_id=test&state=abc123"

h.serveWaitingPage(w, redirectURL)
oauth.ServeWaitingPage(&logger, w, redirectURL)

resp := w.Result()
body, _ := io.ReadAll(resp.Body)
Expand Down
20 changes: 20 additions & 0 deletions internal/oauth/browser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package oauth

import (
"fmt"
"os/exec"
)

// OpenBrowser opens urlStr in the default browser for the given GOOS value.
func OpenBrowser(urlStr string, goos string) error {
switch goos {
case "darwin":
return exec.Command("open", urlStr).Start()
case "linux":
return exec.Command("xdg-open", urlStr).Start()
case "windows":
return exec.Command("rundll32", "url.dll,FileProtocolHandler", urlStr).Start()
default:
return fmt.Errorf("unsupported OS: %s", goos)
}
}
Loading
Loading