diff --git a/README.md b/README.md index ee1765c..c511e71 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Freebuff2API is an OpenAI-compatible proxy server for [Freebuff](https://freebuf ## Features - **OpenAI Compatible API** — Standard OpenAI endpoints; works with any compatible client out of the box. -- **Stealth Request Handling** — Dynamic, randomized client fingerprints that mimic official Freebuff SDK behavior. +- **Freebuff Session Compatibility** — Preserves the current Freebuff waiting-room/session contract, including model-bound session selection and OpenAI-compatible request metadata. - **Multi-Token Rotation** — Cycle through multiple auth tokens with automatic periodic rotation. - **HTTP Proxy Support** — Route all outbound traffic through a configurable upstream proxy. diff --git a/free_session.go b/free_session.go index 4f8bece..888ee15 100644 --- a/free_session.go +++ b/free_session.go @@ -28,6 +28,7 @@ const ( type freeSessionResponse struct { Status string `json:"status"` InstanceID string `json:"instanceId"` + Model string `json:"model"` Position int `json:"position"` QueueDepth int `json:"queueDepth"` QueuedAt string `json:"queuedAt"` @@ -41,6 +42,7 @@ type freeSessionResponse struct { type cachedSession struct { status sessionStatus instanceID string + model string expiresAt time.Time position int queueDepth int @@ -48,17 +50,25 @@ type cachedSession struct { retryAfter time.Duration } -func (p *tokenPool) ensureSession(ctx context.Context) (string, error) { +func (p *tokenPool) ensureSession(ctx context.Context, model string) (string, error) { + model = strings.TrimSpace(model) for { p.mu.Lock() - if instanceID, ready := p.readySessionLocked(time.Now()); ready { + if instanceID, ready := p.readySessionLocked(time.Now(), model); ready { p.mu.Unlock() return instanceID, nil } - if waitingErr := waitingRoomErrorFromSession(p.name, p.session, time.Now()); waitingErr != nil { + if waitingErr := waitingRoomErrorFromSession(p.name, p.session, time.Now()); waitingErr != nil && p.sessionMatchesModelLocked(model) { p.mu.Unlock() return "", waitingErr } + if p.session != nil && !p.sessionMatchesModelLocked(model) { + p.mu.Unlock() + if err := p.prepareModel(ctx, model); err != nil { + return "", err + } + continue + } if ch := p.sessionRefreshCh; ch != nil { p.mu.Unlock() select { @@ -72,7 +82,7 @@ func (p *tokenPool) ensureSession(ctx context.Context) (string, error) { p.sessionRefreshCh = ch p.mu.Unlock() - session, instanceID, err := p.refreshSession(ctx) + session, instanceID, err := p.refreshSession(ctx, model) p.mu.Lock() if session != nil { @@ -99,10 +109,13 @@ func (p *tokenPool) ensureSession(ctx context.Context) (string, error) { } } -func (p *tokenPool) readySessionLocked(now time.Time) (string, bool) { +func (p *tokenPool) readySessionLocked(now time.Time, model string) (string, bool) { if p.session == nil { return "", false } + if !p.sessionMatchesModelLocked(model) { + return "", false + } switch p.session.status { case sessionStatusDisabled: return "", true @@ -117,7 +130,8 @@ func (p *tokenPool) readySessionLocked(now time.Time) (string, bool) { return "", false } -func (p *tokenPool) refreshSession(ctx context.Context) (*cachedSession, string, error) { +func (p *tokenPool) refreshSession(ctx context.Context, model string) (*cachedSession, string, error) { + model = strings.TrimSpace(model) p.mu.Lock() current := p.session p.mu.Unlock() @@ -132,7 +146,7 @@ func (p *tokenPool) refreshSession(ctx context.Context) (*cachedSession, string, return nil, "", fmt.Errorf("poll free session: %w", err) } } else { - state, err = p.client.CreateOrRefreshSession(ctx, p.token) + state, err = p.client.CreateOrRefreshSession(ctx, p.token, model) if err != nil { return nil, "", fmt.Errorf("start free session: %w", err) } @@ -141,7 +155,7 @@ func (p *tokenPool) refreshSession(ctx context.Context) (*cachedSession, string, for { switch sessionStatus(strings.TrimSpace(state.Status)) { case sessionStatusDisabled: - return &cachedSession{status: sessionStatusDisabled}, "", nil + return &cachedSession{status: sessionStatusDisabled, model: model}, "", nil case sessionStatusActive: instanceID := strings.TrimSpace(state.InstanceID) if instanceID == "" { @@ -154,6 +168,7 @@ func (p *tokenPool) refreshSession(ctx context.Context) (*cachedSession, string, return &cachedSession{ status: sessionStatusActive, instanceID: instanceID, + model: firstNonEmptyTrimmedString(strings.TrimSpace(state.Model), model), expiresAt: expiresAt, }, instanceID, nil case sessionStatusQueued: @@ -166,13 +181,14 @@ func (p *tokenPool) refreshSession(ctx context.Context) (*cachedSession, string, return &cachedSession{ status: sessionStatusQueued, instanceID: instanceID, + model: firstNonEmptyTrimmedString(strings.TrimSpace(state.Model), model), position: maxInt(state.Position, 1), queueDepth: maxInt(state.QueueDepth, maxInt(state.Position, 1)), pollAt: time.Now().Add(delay), retryAfter: delay, }, "", nil case sessionStatusNone, sessionStatusEnded, sessionStatusSuperseded: - state, err = p.client.CreateOrRefreshSession(ctx, p.token) + state, err = p.client.CreateOrRefreshSession(ctx, p.token, model) if err != nil { return nil, "", fmt.Errorf("refresh free session: %w", err) } @@ -200,6 +216,92 @@ func (p *tokenPool) currentSessionInstanceID() string { return p.session.instanceID } +func (p *tokenPool) currentSessionModel() string { + p.mu.Lock() + defer p.mu.Unlock() + if p.session == nil { + return "" + } + return p.session.model +} + +func (p *tokenPool) sessionMatchesModelLocked(model string) bool { + if p.session == nil { + return false + } + model = strings.TrimSpace(model) + if model == "" || strings.TrimSpace(p.session.model) == "" { + return true + } + return p.session.model == model +} + +func (p *tokenPool) prepareModel(ctx context.Context, model string) error { + model = strings.TrimSpace(model) + if model == "" { + return nil + } + + p.mu.Lock() + currentModel := "" + if p.session != nil { + currentModel = strings.TrimSpace(p.session.model) + } + if currentModel == "" { + for _, run := range p.runs { + if strings.TrimSpace(run.model) != "" { + currentModel = strings.TrimSpace(run.model) + break + } + } + } + if currentModel == "" || currentModel == model { + p.mu.Unlock() + return nil + } + + for _, run := range p.runs { + if run.inflight > 0 { + p.mu.Unlock() + return fmt.Errorf("token is busy with model %s", currentModel) + } + } + for _, run := range p.draining { + if run.inflight > 0 { + p.mu.Unlock() + return fmt.Errorf("token is busy with model %s", currentModel) + } + } + + session := p.session + var allRuns []*managedRun + for _, run := range p.runs { + allRuns = append(allRuns, run) + } + allRuns = append(allRuns, p.draining...) + p.runs = make(map[string]*managedRun) + p.draining = nil + p.session = nil + p.lastError = "" + p.mu.Unlock() + + var errs []string + for _, run := range allRuns { + if err := p.client.FinishRun(ctx, p.token, run.id, run.requestCount); err != nil { + errs = append(errs, err.Error()) + } + } + if session != nil && session.status != sessionStatusDisabled && session.instanceID != "" { + if err := p.client.EndSession(ctx, p.token); err != nil { + errs = append(errs, err.Error()) + } + } + if len(errs) > 0 { + return fmt.Errorf("switch token from model %s to %s: %s", currentModel, model, strings.Join(errs, "; ")) + } + return nil +} + func waitingRoomErrorFromSession(token string, session *cachedSession, now time.Time) *waitingRoomError { if session == nil || session.status != sessionStatusQueued { return nil @@ -286,12 +388,12 @@ func (p *tokenPool) endSession(ctx context.Context) error { return nil } -func (c *UpstreamClient) CreateOrRefreshSession(ctx context.Context, authToken string) (freeSessionResponse, error) { - return c.doSessionRequest(ctx, http.MethodPost, authToken, "") +func (c *UpstreamClient) CreateOrRefreshSession(ctx context.Context, authToken, model string) (freeSessionResponse, error) { + return c.doSessionRequest(ctx, http.MethodPost, authToken, "", model) } func (c *UpstreamClient) GetSession(ctx context.Context, authToken, instanceID string) (freeSessionResponse, error) { - return c.doSessionRequest(ctx, http.MethodGet, authToken, instanceID) + return c.doSessionRequest(ctx, http.MethodGet, authToken, instanceID, "") } func (c *UpstreamClient) EndSession(ctx context.Context, authToken string) error { @@ -324,7 +426,7 @@ func (c *UpstreamClient) EndSession(ctx context.Context, authToken string) error return nil } -func (c *UpstreamClient) doSessionRequest(ctx context.Context, method, authToken, instanceID string) (freeSessionResponse, error) { +func (c *UpstreamClient) doSessionRequest(ctx context.Context, method, authToken, instanceID, model string) (freeSessionResponse, error) { requestURL, err := url.JoinPath(c.baseURL, "/api/v1/freebuff/session") if err != nil { return freeSessionResponse{}, fmt.Errorf("build free session url: %w", err) @@ -344,6 +446,9 @@ func (c *UpstreamClient) doSessionRequest(ctx context.Context, method, authToken req.Header.Set("User-Agent", c.userAgent) if method == http.MethodPost { req.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(model) != "" { + req.Header.Set("x-freebuff-model", strings.TrimSpace(model)) + } } if method == http.MethodGet && instanceID != "" { req.Header.Set("x-freebuff-instance-id", instanceID) @@ -412,3 +517,12 @@ func sleepWithContext(ctx context.Context, delay time.Duration) error { return nil } } + +func firstNonEmptyTrimmedString(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + return "" +} diff --git a/run_manager.go b/run_manager.go index 97361fc..eebec6d 100644 --- a/run_manager.go +++ b/run_manager.go @@ -40,6 +40,7 @@ type tokenPool struct { type managedRun struct { id string agentID string + model string startedAt time.Time inflight int requestCount int @@ -55,6 +56,7 @@ type tokenSnapshot struct { Name string `json:"name"` Runs []runSnapshot `json:"runs"` DrainingRuns int `json:"draining_runs"` + SessionModel string `json:"session_model,omitempty"` SessionStatus string `json:"session_status,omitempty"` SessionInstanceID string `json:"session_instance_id,omitempty"` SessionExpiresAt time.Time `json:"session_expires_at,omitempty"` @@ -67,6 +69,7 @@ type tokenSnapshot struct { type runSnapshot struct { AgentID string `json:"agent_id"` + Model string `json:"model,omitempty"` RunID string `json:"run_id"` StartedAt time.Time `json:"started_at"` Inflight int `json:"inflight"` @@ -157,11 +160,8 @@ func (m *RunManager) prewarm(agentIDs []string) { defer cancel() for _, pool := range m.pools { - if _, err := pool.ensureSession(ctx); err != nil { - m.logger.Printf("%s: free session prewarm failed: %v", pool.name, err) - } for _, agentID := range agentIDs { - if err := pool.rotateAgent(ctx, agentID); err != nil { + if err := pool.rotateAgent(ctx, agentID, ""); err != nil { m.logger.Printf("%s: prewarm %s failed: %v", pool.name, agentID, err) } else { m.logger.Printf("%s: prewarmed %s", pool.name, agentID) @@ -180,7 +180,7 @@ func (m *RunManager) Close(ctx context.Context) { } } -func (m *RunManager) Acquire(ctx context.Context, agentID string) (*runLease, error) { +func (m *RunManager) Acquire(ctx context.Context, agentID, model string) (*runLease, error) { if len(m.pools) == 0 { return nil, errors.New("no auth tokens configured") } @@ -190,7 +190,7 @@ func (m *RunManager) Acquire(ctx context.Context, agentID string) (*runLease, er var waiting []*waitingRoomError for offset := 0; offset < len(m.pools); offset++ { pool := m.pools[(startIndex+offset)%len(m.pools)] - lease, err := pool.acquire(ctx, agentID) + lease, err := pool.acquire(ctx, agentID, model) if err == nil { return lease, nil } @@ -245,7 +245,11 @@ func (m *RunManager) Snapshots() []tokenSnapshot { return snapshots } -func (p *tokenPool) acquire(ctx context.Context, agentID string) (*runLease, error) { +func (p *tokenPool) acquire(ctx context.Context, agentID, model string) (*runLease, error) { + if err := p.prepareModel(ctx, model); err != nil { + return nil, err + } + p.mu.Lock() if now := time.Now(); now.Before(p.cooldownUntil) { cooldownUntil := p.cooldownUntil @@ -253,16 +257,16 @@ func (p *tokenPool) acquire(ctx context.Context, agentID string) (*runLease, err return nil, fmt.Errorf("token cooling down until %s", cooldownUntil.Format(time.RFC3339)) } run := p.runs[agentID] - needsRotate := run == nil || time.Since(run.startedAt) >= p.cfg.RotationInterval + needsRotate := run == nil || run.model != model || time.Since(run.startedAt) >= p.cfg.RotationInterval p.mu.Unlock() if needsRotate { - if err := p.rotateAgent(ctx, agentID); err != nil { + if err := p.rotateAgent(ctx, agentID, model); err != nil { return nil, err } } - if _, err := p.ensureSession(ctx); err != nil { + if _, err := p.ensureSession(ctx, model); err != nil { return nil, err } @@ -278,8 +282,10 @@ func (p *tokenPool) acquire(ctx context.Context, agentID string) (*runLease, err } func (p *tokenPool) maintain(ctx context.Context) error { - if _, err := p.ensureSession(ctx); err != nil { - p.logger.Printf("%s: refresh free session failed: %v", p.name, err) + if model := p.currentSessionModel(); model != "" { + if _, err := p.ensureSession(ctx, model); err != nil { + p.logger.Printf("%s: refresh free session failed: %v", p.name, err) + } } p.mu.Lock() @@ -293,7 +299,13 @@ func (p *tokenPool) maintain(ctx context.Context) error { p.mu.Unlock() for _, agentID := range toRotate { - if err := p.rotateAgent(ctx, agentID); err != nil { + model := "" + p.mu.Lock() + if run := p.runs[agentID]; run != nil { + model = run.model + } + p.mu.Unlock() + if err := p.rotateAgent(ctx, agentID, model); err != nil { p.logger.Printf("%s: rotate agent %s failed: %v", p.name, agentID, err) } } @@ -332,7 +344,7 @@ func (p *tokenPool) shutdown(ctx context.Context) error { return nil } -func (p *tokenPool) rotateAgent(ctx context.Context, agentID string) error { +func (p *tokenPool) rotateAgent(ctx context.Context, agentID, model string) error { p.mu.Lock() if now := time.Now(); now.Before(p.cooldownUntil) { cooldownUntil := p.cooldownUntil @@ -354,6 +366,7 @@ func (p *tokenPool) rotateAgent(ctx context.Context, agentID string) error { p.runs[agentID] = &managedRun{ id: runID, agentID: agentID, + model: model, startedAt: time.Now(), } p.lastError = "" @@ -469,6 +482,7 @@ func (p *tokenPool) snapshot() tokenSnapshot { LastError: p.lastError, } if p.session != nil { + snapshot.SessionModel = p.session.model snapshot.SessionStatus = string(p.session.status) snapshot.SessionInstanceID = p.session.instanceID snapshot.SessionExpiresAt = p.session.expiresAt @@ -479,6 +493,7 @@ func (p *tokenPool) snapshot() tokenSnapshot { for agentID, run := range p.runs { snapshot.Runs = append(snapshot.Runs, runSnapshot{ AgentID: agentID, + Model: run.model, RunID: run.id, StartedAt: run.startedAt, Inflight: run.inflight, diff --git a/server.go b/server.go index 9fad1fc..12505f7 100644 --- a/server.go +++ b/server.go @@ -264,8 +264,13 @@ func (s *Server) proxyChatRequest( return } - for attempt := 0; attempt < 2; attempt++ { - lease, err := s.runs.Acquire(r.Context(), agentID) + maxAttempts := len(s.cfg.AuthTokens) + 1 + if maxAttempts < 2 { + maxAttempts = 2 + } + + for attempt := 0; attempt < maxAttempts; attempt++ { + lease, err := s.runs.Acquire(r.Context(), agentID, requestedModel) if err != nil { var waitingErr *waitingRoomError if errors.As(err, &waitingErr) { @@ -281,7 +286,7 @@ func (s *Server) proxyChatRequest( s.logger.Printf("[%s] Routing request (model: %s) via run: %s", lease.pool.name, requestedModel, lease.run.id) - sessionInstanceID, err := lease.pool.ensureSession(r.Context()) + sessionInstanceID, err := lease.pool.ensureSession(r.Context(), requestedModel) if err != nil { s.runs.Release(lease) var waitingErr *waitingRoomError @@ -320,6 +325,15 @@ func (s *Server) proxyChatRequest( return } + message, _, code := extractUpstreamError(errorBody) + if strings.TrimSpace(code) == "session_model_mismatch" { + s.logger.Printf("%s: session model mismatch on run %s, rotating run and refreshing session", lease.pool.name, lease.run.id) + lease.pool.invalidateSession(strings.TrimSpace(message)) + s.runs.Invalidate(lease, strings.TrimSpace(message)) + s.runs.Release(lease) + continue + } + if isSessionInvalid(resp.StatusCode, errorBody) { s.logger.Printf("%s: free session invalid, refreshing and retrying", lease.pool.name) lease.pool.invalidateSession(strings.TrimSpace(string(errorBody))) @@ -388,14 +402,9 @@ func isSessionInvalid(statusCode int, errorBody []byte) bool { if statusCode < 400 { return false } - var payload struct { - Error string `json:"error"` - } - if err := json.Unmarshal(errorBody, &payload); err != nil { - return false - } - switch strings.TrimSpace(payload.Error) { - case "freebuff_update_required", "waiting_room_required", "waiting_room_queued", "session_superseded", "session_expired": + _, _, code := extractUpstreamError(errorBody) + switch strings.TrimSpace(code) { + case "freebuff_update_required", "waiting_room_required", "waiting_room_queued", "session_superseded", "session_expired", "session_model_mismatch": return true default: return false