Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func refreshAccessToken(
cfg *AppConfig,
refreshToken string,
) (*credstore.Token, error) {
ctx, cancel := context.WithTimeout(ctx, refreshTokenTimeout)
ctx, cancel := context.WithTimeout(ctx, cfg.RefreshTokenTimeout)
defer cancel()

data := url.Values{}
Expand All @@ -74,7 +74,7 @@ func refreshAccessToken(
data.Set("client_secret", cfg.ClientSecret)
}

tokenResp, err := doTokenExchange(ctx, cfg, cfg.ServerURL+"/oauth/token", data,
tokenResp, err := doTokenExchange(ctx, cfg, cfg.Endpoints.TokenURL, data,
func(errResp ErrorResponse, _ []byte) error {
if errResp.Error == "invalid_grant" || errResp.Error == "invalid_token" {
return ErrRefreshTokenExpired
Expand All @@ -101,18 +101,18 @@ func refreshAccessToken(

// verifyToken verifies an access token with the OAuth server.
func verifyToken(ctx context.Context, cfg *AppConfig, accessToken string) (string, error) {
ctx, cancel := context.WithTimeout(ctx, tokenVerificationTimeout)
ctx, cancel := context.WithTimeout(ctx, cfg.TokenVerificationTimeout)
defer cancel()

resp, err := cfg.RetryClient.Get(ctx, cfg.ServerURL+"/oauth/tokeninfo",
resp, err := cfg.RetryClient.Get(ctx, cfg.Endpoints.TokenInfoURL,
retry.WithHeader("Authorization", "Bearer "+accessToken),
)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()

body, err := readResponseBody(resp)
body, err := readResponseBody(resp, cfg.MaxResponseBodySize)
if err != nil {
return "", err
}
Expand All @@ -131,7 +131,7 @@ func makeAPICallWithAutoRefresh(
storage *credstore.Token,
ui tui.Manager,
) error {
resp, err := cfg.RetryClient.Get(ctx, cfg.ServerURL+"/oauth/tokeninfo",
resp, err := cfg.RetryClient.Get(ctx, cfg.Endpoints.TokenInfoURL,
retry.WithHeader("Authorization", "Bearer "+storage.AccessToken),
)
if err != nil {
Expand All @@ -156,7 +156,7 @@ func makeAPICallWithAutoRefresh(

ui.ShowStatus(tui.StatusUpdate{Event: tui.EventTokenRefreshedRetrying})

resp, err = cfg.RetryClient.Get(ctx, cfg.ServerURL+"/oauth/tokeninfo",
resp, err = cfg.RetryClient.Get(ctx, cfg.Endpoints.TokenInfoURL,
retry.WithHeader("Authorization", "Bearer "+storage.AccessToken),
)
if err != nil {
Expand All @@ -165,7 +165,7 @@ func makeAPICallWithAutoRefresh(
defer resp.Body.Close()
}

body, err := readResponseBody(resp)
body, err := readResponseBody(resp, cfg.MaxResponseBodySize)
if err != nil {
return err
}
Expand Down
12 changes: 6 additions & 6 deletions browser_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func buildAuthURL(cfg *AppConfig, state string, pkce *PKCEParams) string {
params.Set("state", state)
params.Set("code_challenge", pkce.Challenge)
params.Set("code_challenge_method", pkce.Method)
return cfg.ServerURL + "/oauth/authorize?" + params.Encode()
return cfg.Endpoints.AuthorizeURL + "?" + params.Encode()
}

// exchangeCode exchanges an authorization code for access + refresh tokens.
Expand All @@ -30,7 +30,7 @@ func exchangeCode(
cfg *AppConfig,
code, codeVerifier string,
) (*credstore.Token, error) {
ctx, cancel := context.WithTimeout(ctx, tokenExchangeTimeout)
ctx, cancel := context.WithTimeout(ctx, cfg.TokenExchangeTimeout)
defer cancel()

data := url.Values{}
Expand All @@ -44,7 +44,7 @@ func exchangeCode(
data.Set("client_secret", cfg.ClientSecret)
}

tokenResp, err := doTokenExchange(ctx, cfg, cfg.ServerURL+"/oauth/token", data, nil)
tokenResp, err := doTokenExchange(ctx, cfg, cfg.Endpoints.TokenURL, data, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -144,7 +144,7 @@ func performBrowserFlowWithUpdates(
return
case <-ticker.C:
elapsed := time.Since(startTime)
progress := float64(elapsed) / float64(callbackTimeout)
progress := float64(elapsed) / float64(cfg.CallbackTimeout)
if progress > 1.0 {
progress = 1.0
}
Expand All @@ -153,7 +153,7 @@ func performBrowserFlowWithUpdates(
Progress: progress,
Data: map[string]any{
"elapsed": elapsed,
"timeout": callbackTimeout,
"timeout": cfg.CallbackTimeout,
},
}
select {
Expand All @@ -167,7 +167,7 @@ func performBrowserFlowWithUpdates(
}
}()

storage, err := startCallbackServer(ctx, cfg.CallbackPort, state,
storage, err := startCallbackServer(ctx, cfg.CallbackPort, state, cfg.CallbackTimeout,
func(callbackCtx context.Context, code string) (*credstore.Token, error) {
updates <- tui.FlowUpdate{
Type: tui.StepStart,
Expand Down
12 changes: 4 additions & 8 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,7 @@ func sanitizeTokenExchangeError(_ error) string {
return "Token exchange failed. Please try again."
}

const (
// callbackTimeout is how long we wait for the browser to deliver the code.
callbackTimeout = 2 * time.Minute
)

// ErrCallbackTimeout is returned when no browser callback is received within callbackTimeout.
// ErrCallbackTimeout is returned when no browser callback is received within the callback timeout.
// Callers can use errors.Is to distinguish a timeout from other authorization errors
// and decide whether to fall back to Device Code Flow.
var ErrCallbackTimeout = errors.New("browser authorization timed out")
Expand All @@ -76,6 +71,7 @@ type callbackResult struct {
//
// The server shuts itself down after the first request.
func startCallbackServer(ctx context.Context, port int, expectedState string,
cbTimeout time.Duration,
exchangeFn func(context.Context, string) (*credstore.Token, error),
) (*credstore.Token, error) {
resultCh := make(chan callbackResult, 1)
Expand Down Expand Up @@ -158,7 +154,7 @@ func startCallbackServer(ctx context.Context, port int, expectedState string,
_ = srv.Shutdown(shutdownCtx)
}()

timer := time.NewTimer(callbackTimeout)
timer := time.NewTimer(cbTimeout)
defer timer.Stop()

select {
Expand All @@ -172,7 +168,7 @@ func startCallbackServer(ctx context.Context, port int, expectedState string,
return result.Storage, nil

case <-timer.C:
return nil, fmt.Errorf("%w after %s", ErrCallbackTimeout, callbackTimeout)
return nil, fmt.Errorf("%w after %s", ErrCallbackTimeout, cbTimeout)

case <-ctx.Done():
return nil, ctx.Err()
Expand Down
2 changes: 1 addition & 1 deletion callback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func startCallbackServerAsync(
t.Helper()
ch := make(chan callbackServerResult, 1)
go func() {
storage, err := startCallbackServer(ctx, port, state, exchangeFn)
storage, err := startCallbackServer(ctx, port, state, defaultCallbackTimeout, exchangeFn)
ch <- callbackServerResult{storage, err}
}()
// Give the server a moment to bind.
Expand Down
Loading
Loading