diff --git a/free_session.go b/free_session.go index 4f8bece..3ad4902 100644 --- a/free_session.go +++ b/free_session.go @@ -81,6 +81,9 @@ func (p *tokenPool) ensureSession(ctx context.Context) (string, error) { if err != nil { p.session = nil p.lastError = err.Error() + if isBannedErrorMessage(err.Error()) { + p.disabled = true + } } else if waitingErr := waitingRoomErrorFromSession(p.name, session, time.Now()); waitingErr != nil { p.lastError = waitingErr.Error() } else { diff --git a/models.go b/models.go index a1bd1b3..1abba4c 100644 --- a/models.go +++ b/models.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "log" - "math/rand" "net/http" "regexp" "sort" @@ -202,7 +201,7 @@ func parseAllFreeModels(source string) map[string][]string { } // buildModelMapping creates the model→agent reverse mapping and deduplicated model list. -// When a model appears in multiple agents, one is chosen at random. +// When a model appears in multiple agents, pick the least-used agent to spread traffic. func buildModelMapping(agentModels map[string][]string) (map[string]string, []string) { modelAgents := make(map[string][]string) for agentID, models := range agentModels { @@ -213,10 +212,25 @@ func buildModelMapping(agentModels map[string][]string) (map[string]string, []st modelToAgent := make(map[string]string, len(modelAgents)) allModels := make([]string, 0, len(modelAgents)) - for model, agents := range modelAgents { - modelToAgent[model] = agents[rand.Intn(len(agents))] + for model := range modelAgents { allModels = append(allModels, model) } sort.Strings(allModels) + + agentUseCount := make(map[string]int, len(agentModels)) + for _, model := range allModels { + agents := append([]string(nil), modelAgents[model]...) 
+ sort.Strings(agents) + chosen := agents[0] + bestCount := agentUseCount[chosen] + for _, agentID := range agents[1:] { + if count := agentUseCount[agentID]; count < bestCount { + chosen = agentID + bestCount = count + } + } + modelToAgent[model] = chosen + agentUseCount[chosen]++ + } return modelToAgent, allModels } diff --git a/run_manager.go b/run_manager.go index 97361fc..3010204 100644 --- a/run_manager.go +++ b/run_manager.go @@ -35,6 +35,7 @@ type tokenPool struct { sessionRefreshCh chan struct{} lastError string cooldownUntil time.Time + disabled bool } type managedRun struct { @@ -63,6 +64,7 @@ type tokenSnapshot struct { SessionPollAt time.Time `json:"session_poll_at,omitempty"` CooldownUntil time.Time `json:"cooldown_until,omitempty"` LastError string `json:"last_error,omitempty"` + Disabled bool `json:"disabled,omitempty"` } type runSnapshot struct { @@ -157,9 +159,15 @@ func (m *RunManager) prewarm(agentIDs []string) { defer cancel() for _, pool := range m.pools { + if pool.isDisabled() { + continue + } if _, err := pool.ensureSession(ctx); err != nil { m.logger.Printf("%s: free session prewarm failed: %v", pool.name, err) } + if pool.isDisabled() { + continue + } for _, agentID := range agentIDs { if err := pool.rotateAgent(ctx, agentID); err != nil { m.logger.Printf("%s: prewarm %s failed: %v", pool.name, agentID, err) @@ -247,6 +255,14 @@ func (m *RunManager) Snapshots() []tokenSnapshot { func (p *tokenPool) acquire(ctx context.Context, agentID string) (*runLease, error) { p.mu.Lock() + if p.disabled { + lastError := p.lastError + p.mu.Unlock() + if lastError == "" { + lastError = "token disabled" + } + return nil, errors.New(lastError) + } if now := time.Now(); now.Before(p.cooldownUntil) { cooldownUntil := p.cooldownUntil p.mu.Unlock() @@ -278,9 +294,15 @@ func (p *tokenPool) acquire(ctx context.Context, agentID string) (*runLease, err } func (p *tokenPool) maintain(ctx context.Context) error { + if p.isDisabled() { + return nil + } if _, err := 
// isBannedErrorMessage reports whether message — an error string or a raw
// upstream error body — contains a JSON-style "status" field whose value
// begins with "banned".
//
// Matching is case-insensitive. Any amount of JSON whitespace (space, tab,
// CR, LF) is accepted on either side of the colon, so compact, single-space
// and pretty-printed bodies are all recognised. As before, only the closing
// quote of the key and the opening quote of the value are required, so every
// input the previous substring checks accepted is still accepted.
func isBannedErrorMessage(message string) bool {
	const (
		key        = `status"`
		value      = `"banned`
		whitespace = " \t\r\n"
	)

	rest := strings.ToLower(message)
	for {
		idx := strings.Index(rest, key)
		if idx < 0 {
			return false
		}
		// Always advance past this occurrence so the loop terminates even
		// when the key is not followed by the expected value.
		rest = rest[idx+len(key):]

		after := strings.TrimLeft(rest, whitespace)
		if !strings.HasPrefix(after, ":") {
			continue
		}
		after = strings.TrimLeft(after[1:], whitespace)
		if strings.HasPrefix(after, value) {
			return true
		}
	}
}
attempt := 0; attempt < 2; attempt++ { + maxAttempts := len(s.cfg.AuthTokens) + 1 + if maxAttempts < 2 { + maxAttempts = 2 + } + + for attempt := 0; attempt < maxAttempts; attempt++ { lease, err := s.runs.Acquire(r.Context(), agentID) if err != nil { var waitingErr *waitingRoomError @@ -320,6 +325,21 @@ func (s *Server) proxyChatRequest( return } + message, _, code := extractUpstreamError(errorBody) + if isBannedErrorMessage(string(errorBody)) { + s.logger.Printf("%s: upstream token banned, disabling token", lease.pool.name) + lease.pool.disable("upstream token banned") + s.runs.Release(lease) + continue + } + if strings.TrimSpace(code) == "session_model_mismatch" { + s.logger.Printf("%s: session model mismatch on run %s, rotating run and refreshing session", lease.pool.name, lease.run.id) + lease.pool.invalidateSession(strings.TrimSpace(message)) + s.runs.Invalidate(lease, strings.TrimSpace(message)) + s.runs.Release(lease) + continue + } + if isSessionInvalid(resp.StatusCode, errorBody) { s.logger.Printf("%s: free session invalid, refreshing and retrying", lease.pool.name) lease.pool.invalidateSession(strings.TrimSpace(string(errorBody))) @@ -388,14 +408,9 @@ func isSessionInvalid(statusCode int, errorBody []byte) bool { if statusCode < 400 { return false } - var payload struct { - Error string `json:"error"` - } - if err := json.Unmarshal(errorBody, &payload); err != nil { - return false - } - switch strings.TrimSpace(payload.Error) { - case "freebuff_update_required", "waiting_room_required", "waiting_room_queued", "session_superseded", "session_expired": + _, _, code := extractUpstreamError(errorBody) + switch strings.TrimSpace(code) { + case "freebuff_update_required", "waiting_room_required", "waiting_room_queued", "session_superseded", "session_expired", "session_model_mismatch": return true default: return false