diff --git a/README.md b/README.md index ee1765c..e7088eb 100644 --- a/README.md +++ b/README.md @@ -2,93 +2,106 @@ [English](README.md) | [简体中文](README_zh.md) -Freebuff2API is an OpenAI-compatible proxy server for [Freebuff](https://freebuff.com). It translates standard OpenAI API requests into Freebuff's backend format, allowing you to use Freebuff's free models with any OpenAI-compatible client, SDK, or CLI tool. +Freebuff2API is a compatibility-focused proxy for [Freebuff](https://freebuff.com). It translates client requests into the current Freebuff backend contract so you can expose a stable API to OpenAI-compatible clients, Claude-compatible clients, and tools that expect the OpenAI Responses API. ## Features -- **OpenAI Compatible API** — Standard OpenAI endpoints; works with any compatible client out of the box. -- **Stealth Request Handling** — Dynamic, randomized client fingerprints that mimic official Freebuff SDK behavior. -- **Multi-Token Rotation** — Cycle through multiple auth tokens with automatic periodic rotation. -- **HTTP Proxy Support** — Route all outbound traffic through a configurable upstream proxy. +- OpenAI-compatible `POST /v1/chat/completions` +- OpenAI-compatible `POST /v1/responses` +- Claude-compatible `POST /v1/messages` +- Claude-compatible `POST /v1/messages/count_tokens` +- `GET /v1/models` model discovery +- Freebuff waiting-room and model-bound session handling +- Stable retryable proxy errors such as `waiting_room_queued`, `session_switch_in_progress`, and `token_pool_unavailable` +- Automatic token disabling when upstream reports a banned token +- YAML/JSON config loading with runtime hot reload +- Token directory loading via `AUTH_TOKEN_DIR` +- Runtime diagnostics via `GET /healthz` and `GET /status` +- Optional outbound HTTP proxy support -## Getting Auth Tokens +## Auth Tokens -Freebuff2API requires one or more Freebuff **auth tokens**. There are two ways to obtain one: +Freebuff2API needs one or more Freebuff auth tokens. -### Method 1 — Web (Recommended) +### Method 1: Web -Visit **[https://freebuff.llm.pm](https://freebuff.llm.pm)**, log in with your Freebuff account, and your auth token will be displayed directly on the page. Copy it as your **AUTH_TOKENS** — no local installation required. +Visit **[https://freebuff.llm.pm](https://freebuff.llm.pm)**, sign in with your Freebuff account, and copy the displayed auth token. -### Method 2 — Freebuff CLI +### Method 2: Freebuff CLI -Install the Freebuff CLI: +Install the CLI: ```bash npm i -g freebuff ``` -Run `freebuff` in your terminal — on first launch it will guide you through login. - -After logging in, your token is saved to a local credentials file: +Run `freebuff` and finish the login flow. The token is then stored locally: | OS | Credentials Path | |---|---| | Windows | `C:\Users\\.config\manicode\credentials.json` | | Linux / macOS | `~/.config/manicode/credentials.json` | -The file looks like: +Example: ```json { "default": { - "id": "user_10293847", - "name": "Zhang San", - "email": "zhangsan@example.com", - "authToken": "fa82b5c1-e39d-4c7a-961f-d2b3c4e5f6a7", - ... + "authToken": "fa82b5c1-e39d-4c7a-961f-d2b3c4e5f6a7" } } ``` -Only the `authToken` value is needed — copy it as your **AUTH_TOKENS**. - -> **Tip:** Log in with multiple accounts and configure all their tokens for higher throughput. +Only the `authToken` value is required. ## Configuration -Configuration is managed via a JSON file and/or environment variables. The JSON keys and environment variable names are identical. By default the app looks for `config.json` in the working directory; use `-config` to specify another path. - -```json -{ - "LISTEN_ADDR": ":8080", - "UPSTREAM_BASE_URL": "https://codebuff.com", - "AUTH_TOKENS": ["eyJhb..."], - "ROTATION_INTERVAL": "6h", - "REQUEST_TIMEOUT": "15m", - "API_KEYS": [], - "HTTP_PROXY": "" -} +The server accepts YAML or JSON config files. By default it looks for `config.yaml`, then `config.yml`, then `config.json` in the working directory. You can also pass a path with `-config`. + +Example: + +```yaml +LISTEN_ADDR: ":8080" +UPSTREAM_BASE_URL: "https://www.codebuff.com" +AUTH_TOKENS: + - "token-1" + - "token-2" +AUTH_TOKEN_DIR: "tokens.d" +ROTATION_INTERVAL: "6h" +REQUEST_TIMEOUT: "15m" +API_KEYS: [] +HTTP_PROXY: "" ``` ### Reference | Key / Env Var | Description | |---|---| -| `LISTEN_ADDR` | Proxy listen address (default `:8080`) | -| `UPSTREAM_BASE_URL` | Freebuff backend URL (default `https://codebuff.com`) | -| `AUTH_TOKENS` | Freebuff auth tokens (JSON array or comma-separated env var) | -| `ROTATION_INTERVAL` | Run rotation interval (default `6h`) | -| `REQUEST_TIMEOUT` | Upstream request timeout (default `15m`) | -| `API_KEYS` | Client API keys for proxy auth (empty = open access) | -| `HTTP_PROXY` | HTTP proxy for outbound requests | +| `LISTEN_ADDR` | Proxy listen address. Default: `:8080` | +| `UPSTREAM_BASE_URL` | Upstream Freebuff backend URL. Default: `https://www.codebuff.com` | +| `AUTH_TOKENS` | Inline auth tokens. JSON array in files, comma-separated in env | +| `AUTH_TOKEN_DIR` | Optional directory of token files. Plain text, JSON, and YAML token blobs are supported | +| `ROTATION_INTERVAL` | Run rotation interval. Default: `6h` | +| `REQUEST_TIMEOUT` | Upstream request timeout. Default: `15m` | +| `API_KEYS` | Optional client-facing API keys. Empty means open access | +| `HTTP_PROXY` | Optional outbound HTTP proxy | + +Notes: + +- Environment variables provide startup defaults. +- If a config file is loaded, runtime reloads use the file as the source of truth. +- `LISTEN_ADDR` still requires a process restart because the HTTP listener is already bound. -Environment variables override JSON values when both are set. +## Runtime Status + +- `GET /healthz`: lightweight readiness summary +- `GET /status`: full token/session snapshot, active config summary, and available models ## Deployment ### Docker -Pre-built multi-arch images are available on GHCR: +Simple env-based run: ```bash docker run -d --name Freebuff2API \ @@ -97,6 +110,30 @@ docker run -d --name Freebuff2API \ ghcr.io/quorinex/freebuff2api:latest ``` +Recommended hot-reload setup: + +```bash +mkdir -p runtime/tokens.d +cat > runtime/config.yaml <<'EOF' +LISTEN_ADDR: ":8080" +UPSTREAM_BASE_URL: "https://www.codebuff.com" +AUTH_TOKEN_DIR: "/runtime/tokens.d" +ROTATION_INTERVAL: "6h" +REQUEST_TIMEOUT: "15m" +API_KEYS: [] +HTTP_PROXY: "" +EOF + +printf '%s\n' 'token-1' > runtime/tokens.d/token-1.txt +printf '%s\n' 'token-2' > runtime/tokens.d/token-2.txt + +docker run -d --name Freebuff2API \ + -p 8080:8080 \ + -v "$(pwd)/runtime:/runtime" \ + ghcr.io/quorinex/freebuff2api:latest \ + -config /runtime/config.yaml +``` + Build from source: ```bash @@ -106,24 +143,102 @@ docker run -d -p 8080:8080 -e AUTH_TOKENS="token1,token2" Freebuff2API ### Build from Source -**Requirements:** Go 1.23+ +Requirements: Go 1.23+ ```bash git clone https://github.com/Quorinex/Freebuff2API.git cd Freebuff2API go build -o Freebuff2API . -./Freebuff2API -config config.json +./Freebuff2API -config config.yaml +``` + +## Codex CLI + +Freebuff2API can be used as a custom provider for Codex CLI via the OpenAI `Responses API`. + +Add a dedicated profile to `~/.codex/config.toml`: + +```toml +[profiles.freebuff] +model = "your-model-id" +model_provider = "freebuff" +model_reasoning_effort = "high" +model_reasoning_summary = "none" +model_verbosity = "medium" +model_catalog_json = "C:\\Users\\\\.codex\\freebuff-model-catalog.json" + +[model_providers.freebuff] +name = "Freebuff" +base_url = "https://your-gateway.example/v1" +wire_api = "responses" +experimental_bearer_token = "your-client-api-key" +``` + +Create `~/.codex/freebuff-model-catalog.json` and register the models exposed by your gateway. At minimum, include the same model id you set in the profile. + +Codex CLI currently expects full model metadata for custom providers, not just a list of model ids. The most reliable approach is: + +1. Run `codex debug models` +2. Copy a model entry with similar capabilities +3. Replace fields such as `slug`, `display_name`, and any capability metadata that should differ for your gateway model +4. Save the resulting `models` array to `freebuff-model-catalog.json` + +Notes: + +- `base_url` should point to your gateway's `/v1` root. +- `wire_api` must be `responses`. +- A custom `model_catalog_json` avoids Codex CLI fallback metadata warnings for non-OpenAI model ids. +- If your server enforces `API_KEYS`, replace `experimental_bearer_token` with a real client key. +- Keep the profile `model` and the catalog entry `slug` in sync with whatever model ids your gateway currently exposes. + +Then launch Codex with: + +```bash +codex -p freebuff +``` + +## Claude Code + +Freebuff2API can also be used as a Claude Code gateway through the Anthropic-compatible endpoints. + +Example `~/.claude/settings.json`: + +```json +{ + "$schema": "https://json.schemastore.org/claude-code-settings.json", + "env": { + "ANTHROPIC_API_KEY": "your-client-api-key", + "ANTHROPIC_BASE_URL": "https://your-gateway.example", + "ANTHROPIC_DEFAULT_SONNET_MODEL": "your-sonnet-model-id", + "ANTHROPIC_DEFAULT_SONNET_MODEL_NAME": "Sonnet via gateway", + "ANTHROPIC_DEFAULT_OPUS_MODEL": "your-opus-model-id", + "ANTHROPIC_DEFAULT_OPUS_MODEL_NAME": "Opus via gateway", + "ANTHROPIC_DEFAULT_HAIKU_MODEL": "your-haiku-model-id", + "ANTHROPIC_DEFAULT_HAIKU_MODEL_NAME": "Haiku via gateway", + "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1", + "ENABLE_TOOL_SEARCH": "true", + "NO_PROXY": "localhost" + }, + "permissions": { + "defaultMode": "bypassPermissions", + "skipDangerousModePermissionPrompt": true + }, + "effortLevel": "high" +} ``` -## Links +Notes: -- [linux.do](https://linux.do) +- `ANTHROPIC_BASE_URL` should be the gateway root and should not include `/v1`. +- Map the `ANTHROPIC_DEFAULT_*_MODEL` variables to whatever model ids your gateway currently exposes. +- Keep `skipDangerousModePermissionPrompt` inside `permissions`; the top-level key is unnecessary. +- If your gateway requires client auth, use a real key instead of the placeholder value. ## Disclaimer -This project has no official affiliation with OpenAI, Codebuff, or Freebuff. All related trademarks and copyrights belong to their respective owners. +This project is not affiliated with OpenAI, Codebuff, or Freebuff. All related trademarks belong to their respective owners. -All contents within this repository are provided solely for communication, experimentation, and learning, and do not constitute production-ready services or professional advice. This project is provided on an "As-Is" basis, and users must use it at their own risk. The author assumes no liability for any direct or indirect damages resulting from the use, modification, or distribution of this project, nor provides any warranties of any kind, express or implied. +This repository is provided for communication, experimentation, and learning. It is not production advice and is provided on an "as-is" basis. Use it at your own risk. ## License diff --git a/README_zh.md b/README_zh.md index 8c84b39..f03f3b6 100644 --- a/README_zh.md +++ b/README_zh.md @@ -2,93 +2,106 @@ [English](README.md) | [简体中文](README_zh.md) -Freebuff2API 是 [Freebuff](https://freebuff.com) 的 OpenAI 兼容代理服务器。本项目将标准 OpenAI API 请求转化为 Freebuff 后端格式,让你能在任何 OpenAI 兼容客户端、SDK 或命令行工具中直接使用 Freebuff 的免费模型。 - -## 核心特性 - -- **OpenAI 兼容 API** — 标准 OpenAI 端点,开箱即用,支持任意兼容客户端。 -- **高隐匿性请求处理** — 动态随机客户端特征标识,模拟官方 Freebuff SDK 行为模式。 -- **多 Token 轮换** — 支持多个认证 Token,内置定期自动轮换机制。 -- **HTTP 代理支持** — 可为所有外部请求配置上游 HTTP 代理。 +Freebuff2API 是一个以兼容性和使用体验为重点的 [Freebuff](https://freebuff.com) 代理服务。它会把客户端请求转换成当前 Freebuff 后端要求的协议格式,对外提供稳定的 OpenAI 兼容接口、Claude 兼容接口,以及 OpenAI Responses API 兼容接口。 + +## 功能特性 + +- OpenAI 兼容 `POST /v1/chat/completions` +- OpenAI 兼容 `POST /v1/responses` +- Claude 兼容 `POST /v1/messages` +- Claude 兼容 `POST /v1/messages/count_tokens` +- `GET /v1/models` 模型发现接口 +- 兼容 Freebuff 当前 waiting-room 和按模型绑定的 session 协议 +- 返回稳定的可重试错误码,例如 `waiting_room_queued`、`session_switch_in_progress`、`token_pool_unavailable` +- 上游返回 banned token 时自动禁用对应 token +- 支持 YAML / JSON 配置文件和运行时热加载 +- 支持通过 `AUTH_TOKEN_DIR` 从目录加载 token +- 提供 `GET /healthz` 和 `GET /status` 运行状态接口 +- 支持上游 HTTP 代理 ## 获取 Auth Token -Freebuff2API 需要至少一个 Freebuff **Auth Token**。目前有以下两种获取方式: +Freebuff2API 需要一个或多个 Freebuff auth token。 -### 方式一 — 网页获取(推荐) +### 方式一:网页获取 -访问 **[https://freebuff.llm.pm](https://freebuff.llm.pm)**,使用你的 Freebuff 账号登录后,页面会直接显示你的 Auth Token。复制该值即可作为 **AUTH_TOKENS** 使用,无需在本地安装任何工具。 +访问 **[https://freebuff.llm.pm](https://freebuff.llm.pm)**,登录你的 Freebuff 账号后,复制页面展示的 auth token。 -### 方式二 — Freebuff CLI +### 方式二:Freebuff CLI -安装 Freebuff CLI 并完成登录: +安装 CLI: ```bash npm i -g freebuff ``` -安装完成后,在终端执行 `freebuff`,首次启动时会自动引导你完成登录。 - -登录后,Token 会自动保存到本地凭证文件中: +运行 `freebuff` 并完成登录流程。登录后 token 会保存在本地凭证文件中: -| 系统 | 凭证文件路径 | +| 系统 | 凭证路径 | |---|---| -| Windows | `C:\Users\<用户名>\.config\manicode\credentials.json` | +| Windows | `C:\Users\\.config\manicode\credentials.json` | | Linux / macOS | `~/.config/manicode/credentials.json` | -文件结构如下: +示例: ```json { "default": { - "id": "user_10293847", - "name": "张三", - "email": "zhangsan@example.com", - "authToken": "fa82b5c1-e39d-4c7a-961f-d2b3c4e5f6a7", - ... + "authToken": "fa82b5c1-e39d-4c7a-961f-d2b3c4e5f6a7" } } ``` -将 `authToken` 的值复制出来,即为所需的 **AUTH_TOKENS**。 +只需要取出 `authToken` 的值即可。 -> **提示:** 可登录多个账号并配置所有 Token,以提升并发吞吐量。 +## 配置说明 -## 配置指南 +程序支持 YAML 或 JSON 配置文件。默认会按顺序查找当前目录下的 `config.yaml`、`config.yml`、`config.json`。也可以通过 `-config` 显式指定路径。 -支持 JSON 文件和环境变量两种配置方式。JSON 属性名与环境变量名一致。默认在当前目录查找 `config.json`,可通过 `-config` 参数指定其他路径。 +示例: -```json -{ - "LISTEN_ADDR": ":8080", - "UPSTREAM_BASE_URL": "https://codebuff.com", - "AUTH_TOKENS": ["token"], - "ROTATION_INTERVAL": "6h", - "REQUEST_TIMEOUT": "15m", - "API_KEYS": [], - "HTTP_PROXY": "" -} +```yaml +LISTEN_ADDR: ":8080" +UPSTREAM_BASE_URL: "https://www.codebuff.com" +AUTH_TOKENS: + - "token-1" + - "token-2" +AUTH_TOKEN_DIR: "tokens.d" +ROTATION_INTERVAL: "6h" +REQUEST_TIMEOUT: "15m" +API_KEYS: [] +HTTP_PROXY: "" ``` -### 配置参考 +### 配置项 -| 属性 / 环境变量 | 说明 | +| 配置项 / 环境变量 | 说明 | |---|---| -| `LISTEN_ADDR` | 代理监听地址(默认 `:8080`) | -| `UPSTREAM_BASE_URL` | Freebuff 后端地址(默认 `https://codebuff.com`) | -| `AUTH_TOKENS` | Freebuff Auth Token(JSON 数组或逗号分隔的环境变量) | -| `ROTATION_INTERVAL` | Run 自动轮换间隔(默认 `6h`) | -| `REQUEST_TIMEOUT` | 上游请求超时时间(默认 `15m`) | -| `API_KEYS` | 客户端鉴权 API Key(留空则无需鉴权) | -| `HTTP_PROXY` | 上游 HTTP 代理地址 | +| `LISTEN_ADDR` | 服务监听地址,默认 `:8080` | +| `UPSTREAM_BASE_URL` | 上游 Freebuff 地址,默认 `https://www.codebuff.com` | +| `AUTH_TOKENS` | 直接写在配置中的 token;文件中是数组,环境变量中用逗号分隔 | +| `AUTH_TOKEN_DIR` | 可选 token 目录,支持纯文本、JSON、YAML 三种文件格式 | +| `ROTATION_INTERVAL` | run 轮换间隔,默认 `6h` | +| `REQUEST_TIMEOUT` | 上游请求超时时间,默认 `15m` | +| `API_KEYS` | 对外暴露给客户端的 API Key;留空表示不鉴权 | +| `HTTP_PROXY` | 可选的上游 HTTP 代理 | + +补充说明: + +- 环境变量用于提供启动时默认值。 +- 如果加载了配置文件,运行时热更新会以配置文件内容为准。 +- `LISTEN_ADDR` 修改后仍然需要重启进程,因为监听端口已经绑定。 -同时设置时,环境变量优先于 JSON 配置文件。 +## 运行状态接口 -## 部署运行 +- `GET /healthz`:轻量级健康摘要 +- `GET /status`:完整 token / session 状态、当前配置摘要、可用模型列表 -### Docker 部署 +## 部署方式 -预构建多架构镜像已发布至 GHCR: +### Docker + +最简单的环境变量启动方式: ```bash docker run -d --name Freebuff2API \ @@ -97,7 +110,31 @@ docker run -d --name Freebuff2API \ ghcr.io/quorinex/freebuff2api:latest ``` -手动构建: +推荐的热加载目录挂载方式: + +```bash +mkdir -p runtime/tokens.d +cat > runtime/config.yaml <<'EOF' +LISTEN_ADDR: ":8080" +UPSTREAM_BASE_URL: "https://www.codebuff.com" +AUTH_TOKEN_DIR: "/runtime/tokens.d" +ROTATION_INTERVAL: "6h" +REQUEST_TIMEOUT: "15m" +API_KEYS: [] +HTTP_PROXY: "" +EOF + +printf '%s\n' 'token-1' > runtime/tokens.d/token-1.txt +printf '%s\n' 'token-2' > runtime/tokens.d/token-2.txt + +docker run -d --name Freebuff2API \ + -p 8080:8080 \ + -v "$(pwd)/runtime:/runtime" \ + ghcr.io/quorinex/freebuff2api:latest \ + -config /runtime/config.yaml +``` + +从源码构建镜像: ```bash docker build -t Freebuff2API . @@ -106,25 +143,103 @@ docker run -d -p 8080:8080 -e AUTH_TOKENS="token1,token2" Freebuff2API ### 源码编译 -**环境要求:** Go 1.23+ +要求:Go 1.23+ ```bash git clone https://github.com/Quorinex/Freebuff2API.git cd Freebuff2API go build -o Freebuff2API . -./Freebuff2API -config config.json +./Freebuff2API -config config.yaml +``` + +## Codex CLI 配置 + +Freebuff2API 可以通过 OpenAI `Responses API` 作为 Codex CLI 的自定义 provider 使用。 + +在 `~/.codex/config.toml` 中增加一个独立 profile: + +```toml +[profiles.freebuff] +model = "your-model-id" +model_provider = "freebuff" +model_reasoning_effort = "high" +model_reasoning_summary = "none" +model_verbosity = "medium" +model_catalog_json = "C:\\Users\\\\.codex\\freebuff-model-catalog.json" + +[model_providers.freebuff] +name = "Freebuff" +base_url = "https://your-gateway.example/v1" +wire_api = "responses" +experimental_bearer_token = "your-client-api-key" +``` + +同时创建 `~/.codex/freebuff-model-catalog.json`,把网关当前暴露出来的模型写进去。至少要包含 profile 里设置的同一个模型 id。 + +目前 Codex CLI 对自定义 provider 的 model catalog 要求是完整 metadata,不只是模型 id 列表。更稳的做法是: + +1. 运行 `codex debug models` +2. 复制一个能力相近的模型条目 +3. 替换其中的 `slug`、`display_name` 以及需要调整的能力字段 +4. 把生成后的 `models` 数组保存为 `freebuff-model-catalog.json` + +说明: + +- `base_url` 要写到网关的 `/v1` +- `wire_api` 必须是 `responses` +- `model_catalog_json` 用来给非 OpenAI 官方模型补 metadata +- 如果服务端启用了 `API_KEYS`,把 `experimental_bearer_token` 换成真实的客户端 key +- profile 里的 `model` 和 catalog 里的 `slug` 必须始终与网关当前实际暴露的模型 id 保持一致 + +启动方式: + +```bash +codex -p freebuff +``` + +## Claude Code 配置 + +Freebuff2API 也可以作为 Claude Code 的网关,通过 Anthropic 兼容接口提供服务。 + +`~/.claude/settings.json` 示例: + +```json +{ + "$schema": "https://json.schemastore.org/claude-code-settings.json", + "env": { + "ANTHROPIC_API_KEY": "your-client-api-key", + "ANTHROPIC_BASE_URL": "https://your-gateway.example", + "ANTHROPIC_DEFAULT_SONNET_MODEL": "your-sonnet-model-id", + "ANTHROPIC_DEFAULT_SONNET_MODEL_NAME": "Sonnet via gateway", + "ANTHROPIC_DEFAULT_OPUS_MODEL": "your-opus-model-id", + "ANTHROPIC_DEFAULT_OPUS_MODEL_NAME": "Opus via gateway", + "ANTHROPIC_DEFAULT_HAIKU_MODEL": "your-haiku-model-id", + "ANTHROPIC_DEFAULT_HAIKU_MODEL_NAME": "Haiku via gateway", + "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1", + "ENABLE_TOOL_SEARCH": "true", + "NO_PROXY": "localhost" + }, + "permissions": { + "defaultMode": "bypassPermissions", + "skipDangerousModePermissionPrompt": true + }, + "effortLevel": "high" +} ``` -## 友情链接 +说明: -- [linux.do](https://linux.do) +- `ANTHROPIC_BASE_URL` 要写网关根地址,不要带 `/v1` +- `ANTHROPIC_DEFAULT_*_MODEL` 要映射到网关当前实际暴露的模型 id +- `skipDangerousModePermissionPrompt` 只需要保留在 `permissions` 里,不要再写一份顶层字段 +- 如果网关启用了客户端鉴权,请把占位值换成真实的 key ## 免责声明 -本项目与 OpenAI、Codebuff 或 Freebuff 无任何官方关联,相关商标和版权均归其各自所有者所有。 +本项目与 OpenAI、Codebuff、Freebuff 没有任何官方关联,相关商标归各自所有者所有。 -本仓库的所有内容仅供交流、实验和学习使用,不构成任何生产环境服务或专业建议。本项目按“原样(As-Is)”提供,使用者需自行承担使用风险。作者不对因使用、修改或分发本项目而导致的任何直接或间接损失承担责任,亦不提供任何形式的明示或暗示保证。 +本仓库仅用于交流、实验和学习,不构成生产建议。本项目按 “as-is” 方式提供,使用风险由使用者自行承担。 -## 开源协议 +## License MIT diff --git a/anthropic.go b/anthropic.go index 2793050..38aecf5 100644 --- a/anthropic.go +++ b/anthropic.go @@ -57,7 +57,9 @@ func convertClaudeMessagesRequestToOpenAI(body []byte) (map[string]any, string, } if reasoningEffort, ok := mapClaudeThinkingToReasoningEffort(root); ok { - out["reasoning_effort"] = reasoningEffort + if normalized, normalizedOK := normalizeReasoningEffort(reasoningEffort); normalizedOK { + out["reasoning_effort"] = normalized + } } messages := make([]any, 0, 8) @@ -1356,14 +1358,18 @@ func mapOpenAIFinishReasonToClaude(reason string) string { func writeClaudePassthroughError(w http.ResponseWriter, statusCode int, body []byte) { trimmed := bytes.TrimSpace(body) if len(trimmed) > 0 && json.Valid(trimmed) { - message, errorType, _ := extractUpstreamError(trimmed) - writeClaudeError(w, statusCode, message, normalizeClaudeErrorType(statusCode, errorType)) + message, errorType, code := extractUpstreamError(trimmed) + writeClaudeErrorDetailed(w, statusCode, message, normalizeClaudeErrorType(statusCode, errorType), code) return } - writeClaudeError(w, statusCode, strings.TrimSpace(string(trimmed)), normalizeClaudeErrorType(statusCode, "")) + writeClaudeErrorDetailed(w, statusCode, strings.TrimSpace(string(trimmed)), normalizeClaudeErrorType(statusCode, ""), "") } func writeClaudeError(w http.ResponseWriter, statusCode int, message, errorType string) { + writeClaudeErrorDetailed(w, statusCode, message, errorType, "") +} + +func writeClaudeErrorDetailed(w http.ResponseWriter, statusCode int, message, errorType, code string) { if strings.TrimSpace(message) == "" { message = http.StatusText(statusCode) } @@ -1371,13 +1377,17 @@ func writeClaudeError(w http.ResponseWriter, statusCode int, message, errorType errorType = normalizeClaudeErrorType(statusCode, "") } - writeJSON(w, statusCode, map[string]any{ + payload := map[string]any{ "type": "error", "error": map[string]any{ "type": errorType, "message": message, }, - }) + } + if strings.TrimSpace(code) != "" { + payload["error"].(map[string]any)["code"] = strings.TrimSpace(code) + } + writeJSON(w, statusCode, payload) } func normalizeClaudeErrorType(statusCode int, upstreamType string) string { diff --git a/config.example.json b/config.example.json index 776ac58..98d5016 100644 --- a/config.example.json +++ b/config.example.json @@ -2,6 +2,7 @@ "LISTEN_ADDR": ":8080", "UPSTREAM_BASE_URL": "https://www.codebuff.com", "AUTH_TOKENS": [], + "AUTH_TOKEN_DIR": "", "ROTATION_INTERVAL": "6h", "REQUEST_TIMEOUT": "15m", "API_KEYS": [], diff --git a/config.example.yaml b/config.example.yaml new file mode 100644 index 0000000..53b4105 --- /dev/null +++ b/config.example.yaml @@ -0,0 +1,8 @@ +LISTEN_ADDR: ":8080" +UPSTREAM_BASE_URL: "https://www.codebuff.com" +AUTH_TOKENS: [] +AUTH_TOKEN_DIR: "" +ROTATION_INTERVAL: "6h" +REQUEST_TIMEOUT: "15m" +API_KEYS: [] +HTTP_PROXY: "" diff --git a/config.go b/config.go index 856c657..65ceebf 100644 --- a/config.go +++ b/config.go @@ -9,28 +9,80 @@ import ( "os" "path/filepath" "strings" + "sync/atomic" "time" + + "gopkg.in/yaml.v3" ) type Config struct { ListenAddr string UpstreamBaseURL string AuthTokens []string + AuthTokenDir string RotationInterval time.Duration RequestTimeout time.Duration UserAgent string APIKeys []string HTTPProxy string + ConfigPath string + ConfigFormat string + LoadedAt time.Time } type rawConfig struct { - ListenAddr string `json:"LISTEN_ADDR"` - UpstreamBaseURL string `json:"UPSTREAM_BASE_URL"` - AuthTokens []string `json:"AUTH_TOKENS"` - RotationInterval string `json:"ROTATION_INTERVAL"` - RequestTimeout string `json:"REQUEST_TIMEOUT"` - APIKeys []string `json:"API_KEYS"` - HTTPProxy string `json:"HTTP_PROXY"` + ListenAddr string `json:"LISTEN_ADDR" yaml:"LISTEN_ADDR"` + UpstreamBaseURL string `json:"UPSTREAM_BASE_URL" yaml:"UPSTREAM_BASE_URL"` + AuthTokens []string `json:"AUTH_TOKENS" yaml:"AUTH_TOKENS"` + AuthTokenDir string `json:"AUTH_TOKEN_DIR" yaml:"AUTH_TOKEN_DIR"` + RotationInterval string `json:"ROTATION_INTERVAL" yaml:"ROTATION_INTERVAL"` + RequestTimeout string `json:"REQUEST_TIMEOUT" yaml:"REQUEST_TIMEOUT"` + APIKeys []string `json:"API_KEYS" yaml:"API_KEYS"` + HTTPProxy string `json:"HTTP_PROXY" yaml:"HTTP_PROXY"` +} + +type ConfigStore struct { + current atomic.Value +} + +func NewConfigStore(cfg Config) *ConfigStore { + store := &ConfigStore{} + store.Update(cfg) + return store +} + +func (s *ConfigStore) Current() Config { + value := s.current.Load() + if value == nil { + return Config{} + } + return value.(Config) +} + +func (s *ConfigStore) Update(cfg Config) { + s.current.Store(cfg) +} + +func resolveConfigPath(configPath string) (string, error) { + if strings.TrimSpace(configPath) != "" { + resolved, err := filepath.Abs(strings.TrimSpace(configPath)) + if err != nil { + return "", fmt.Errorf("resolve config path: %w", err) + } + return resolved, nil + } + + for _, candidate := range []string{"config.yaml", "config.yml", "config.json"} { + if _, err := os.Stat(candidate); err == nil { + resolved, err := filepath.Abs(candidate) + if err != nil { + return "", fmt.Errorf("resolve config path: %w", err) + } + return resolved, nil + } + } + + return "", nil } func loadConfig(configPath string) (Config, error) { @@ -39,14 +91,6 @@ func loadConfig(configPath string) (Config, error) { return Config{}, err } - overrideString(&cfg.ListenAddr, "LISTEN_ADDR") - overrideString(&cfg.UpstreamBaseURL, "UPSTREAM_BASE_URL") - overrideString(&cfg.RotationInterval, "ROTATION_INTERVAL") - overrideString(&cfg.RequestTimeout, "REQUEST_TIMEOUT") - overrideCSV(&cfg.AuthTokens, "AUTH_TOKENS") - overrideCSV(&cfg.APIKeys, "API_KEYS") - overrideString(&cfg.HTTPProxy, "HTTP_PROXY") - rotationInterval, err := time.ParseDuration(strings.TrimSpace(cfg.RotationInterval)) if err != nil { return Config{}, fmt.Errorf("parse rotation interval: %w", err) @@ -57,15 +101,37 @@ func loadConfig(configPath string) (Config, error) { return Config{}, fmt.Errorf("parse request timeout: %w", err) } + authTokenDir := strings.TrimSpace(cfg.AuthTokenDir) + if authTokenDir != "" && !filepath.IsAbs(authTokenDir) { + baseDir := "." + if strings.TrimSpace(configPath) != "" { + baseDir = filepath.Dir(configPath) + } + authTokenDir = filepath.Clean(filepath.Join(baseDir, authTokenDir)) + } + + authTokens := dedupeStrings(cfg.AuthTokens) + if tokenDir := authTokenDir; tokenDir != "" { + tokensFromDir, err := loadAuthTokensFromDir(tokenDir) + if err != nil { + return Config{}, fmt.Errorf("load auth tokens from dir: %w", err) + } + authTokens = dedupeStrings(append(authTokens, tokensFromDir...)) + } + finalCfg := Config{ ListenAddr: strings.TrimSpace(cfg.ListenAddr), UpstreamBaseURL: normalizeUpstreamBaseURL(cfg.UpstreamBaseURL), - AuthTokens: dedupeStrings(cfg.AuthTokens), + AuthTokens: authTokens, + AuthTokenDir: authTokenDir, RotationInterval: rotationInterval, RequestTimeout: requestTimeout, UserAgent: generateUserAgent(), APIKeys: dedupeStrings(cfg.APIKeys), HTTPProxy: strings.TrimSpace(cfg.HTTPProxy), + ConfigPath: strings.TrimSpace(configPath), + ConfigFormat: configFormat(configPath), + LoadedAt: time.Now().UTC(), } switch { @@ -110,15 +176,36 @@ func loadRawConfig(configPath string) (rawConfig, error) { RequestTimeout: "15m", } - if configPath != "" { - path, err := filepath.Abs(configPath) + applyEnvConfig(&cfg) + + if strings.TrimSpace(configPath) != "" { + overlay, err := parseConfigFile(configPath) if err != nil { - return rawConfig{}, fmt.Errorf("resolve config path: %w", err) + return rawConfig{}, err } - data, err := os.ReadFile(path) - if err != nil { - return rawConfig{}, fmt.Errorf("read config file: %w", err) + mergeRawConfig(&cfg, overlay) + } + + return cfg, nil +} + +func parseConfigFile(configPath string) (rawConfig, error) { + path, err := filepath.Abs(configPath) + if err != nil { + return rawConfig{}, fmt.Errorf("resolve config path: %w", err) + } + data, err := os.ReadFile(path) + if err != nil { + return rawConfig{}, fmt.Errorf("read config file: %w", err) + } + + var cfg rawConfig + switch strings.ToLower(filepath.Ext(path)) { + case ".yaml", ".yml": + if err := yaml.Unmarshal(data, &cfg); err != nil { + return rawConfig{}, fmt.Errorf("parse config file: %w", err) } + default: if err := json.Unmarshal(data, &cfg); err != nil { return rawConfig{}, fmt.Errorf("parse config file: %w", err) } @@ -127,6 +214,161 @@ func loadRawConfig(configPath string) (rawConfig, error) { return cfg, nil } +func mergeRawConfig(dst *rawConfig, src rawConfig) { + if strings.TrimSpace(src.ListenAddr) != "" { + dst.ListenAddr = src.ListenAddr + } + if strings.TrimSpace(src.UpstreamBaseURL) != "" { + dst.UpstreamBaseURL = src.UpstreamBaseURL + } + if len(src.AuthTokens) > 0 { + dst.AuthTokens = src.AuthTokens + } + if strings.TrimSpace(src.AuthTokenDir) != "" { + dst.AuthTokenDir = src.AuthTokenDir + } + if strings.TrimSpace(src.RotationInterval) != "" { + dst.RotationInterval = src.RotationInterval + } + if strings.TrimSpace(src.RequestTimeout) != "" { + dst.RequestTimeout = src.RequestTimeout + } + if len(src.APIKeys) > 0 { + dst.APIKeys = src.APIKeys + } + if strings.TrimSpace(src.HTTPProxy) != "" { + dst.HTTPProxy = src.HTTPProxy + } +} + +func applyEnvConfig(cfg *rawConfig) { + overrideString(&cfg.ListenAddr, "LISTEN_ADDR") + overrideString(&cfg.UpstreamBaseURL, "UPSTREAM_BASE_URL") + overrideString(&cfg.AuthTokenDir, "AUTH_TOKEN_DIR") + overrideString(&cfg.RotationInterval, "ROTATION_INTERVAL") + overrideString(&cfg.RequestTimeout, "REQUEST_TIMEOUT") + overrideCSV(&cfg.AuthTokens, "AUTH_TOKENS") + overrideCSV(&cfg.APIKeys, "API_KEYS") + overrideString(&cfg.HTTPProxy, "HTTP_PROXY") +} + +func loadAuthTokensFromDir(dir string) ([]string, error) { + dir = strings.TrimSpace(dir) + if dir == "" { + return nil, nil + } + resolved, err := filepath.Abs(dir) + if err != nil { + return nil, fmt.Errorf("resolve auth token dir: %w", err) + } + entries, err := os.ReadDir(resolved) + if err != nil { + return nil, fmt.Errorf("read auth token dir: %w", err) + } + + tokens := make([]string, 0, len(entries)) + for _, entry := range entries { + if entry.IsDir() { + continue + } + path := filepath.Join(resolved, entry.Name()) + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read auth token file %s: %w", path, err) + } + tokens = append(tokens, extractTokensFromBlob(path, data)...) + } + + return dedupeStrings(tokens), nil +} + +func extractTokensFromBlob(path string, data []byte) []string { + var decoded any + ext := strings.ToLower(filepath.Ext(path)) + + switch ext { + case ".yaml", ".yml": + if err := yaml.Unmarshal(data, &decoded); err == nil { + if tokens := collectAuthTokens(decoded); len(tokens) > 0 { + return dedupeStrings(tokens) + } + } + default: + if err := json.Unmarshal(data, &decoded); err == nil { + if tokens := collectAuthTokens(decoded); len(tokens) > 0 { + return dedupeStrings(tokens) + } + } + if err := yaml.Unmarshal(data, &decoded); err == nil { + if tokens := collectAuthTokens(decoded); len(tokens) > 0 { + return dedupeStrings(tokens) + } + } + } + + return splitList(string(data)) +} + +func collectAuthTokens(value any) []string { + var tokens []string + collectAuthTokensInto(value, "", &tokens) + return dedupeStrings(tokens) +} + +func collectAuthTokensInto(value any, key string, tokens *[]string) { + switch typed := value.(type) { + case map[string]any: + for childKey, childValue := range typed { + collectAuthTokensInto(childValue, childKey, tokens) + } + case map[any]any: + for rawKey, childValue := range typed { + collectAuthTokensInto(childValue, fmt.Sprint(rawKey), tokens) + } + case []any: + for _, childValue := range typed { + collectAuthTokensInto(childValue, key, tokens) + } + case []string: + if isAuthTokenListKey(key) { + *tokens = append(*tokens, compactStrings(typed)...) + } + case string: + if isAuthTokenScalarKey(key) { + *tokens = append(*tokens, strings.TrimSpace(typed)) + } + } +} + +func isAuthTokenScalarKey(key string) bool { + switch strings.ToLower(strings.TrimSpace(key)) { + case "authtoken", "auth_token", "token": + return true + default: + return false + } +} + +func isAuthTokenListKey(key string) bool { + switch strings.ToLower(strings.TrimSpace(key)) { + case "authtokens", "auth_tokens", "tokens": + return true + default: + return false + } +} + +func configFormat(configPath string) string { + switch strings.ToLower(filepath.Ext(strings.TrimSpace(configPath))) { + case ".yaml", ".yml": + return "yaml" + case ".json": + return "json" + default: + return "" + } +} + func overrideString(target *string, envName string) { if value := strings.TrimSpace(os.Getenv(envName)); value != "" { *target = value @@ -187,7 +429,7 @@ func generateUserAgent() string { } // generateClientSessionId generates a per-request session ID matching the -// official SDK: Math.random().toString(36).substring(2, 15) — a ~13-char +// official SDK: Math.random().toString(36).substring(2, 15) -> a ~13-char // base-36 alphanumeric string. func generateClientSessionId() string { buf := make([]byte, 10) diff --git a/config_runtime.go b/config_runtime.go new file mode 100644 index 0000000..58a281b --- /dev/null +++ b/config_runtime.go @@ -0,0 +1,192 @@ +package main + +import ( + "log" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/fsnotify/fsnotify" +) + +const configReloadPollInterval = 15 * time.Second + +type ConfigReloader struct { + configPath string + logger *log.Logger + apply func(Config) + + mu sync.Mutex + current Config + stopCh chan struct{} + wg sync.WaitGroup +} + +func NewConfigReloader(initial Config, logger *log.Logger, apply func(Config)) *ConfigReloader { + return &ConfigReloader{ + configPath: initial.ConfigPath, + logger: logger, + apply: apply, + current: initial, + stopCh: make(chan struct{}), + } +} + +func (r *ConfigReloader) Start() { + if strings.TrimSpace(r.configPath) == "" && strings.TrimSpace(r.current.AuthTokenDir) == "" { + return + } + + r.wg.Add(1) + go func() { + defer r.wg.Done() + r.run() + }() +} + +func (r *ConfigReloader) Stop() { + close(r.stopCh) + r.wg.Wait() +} + +func (r *ConfigReloader) run() { + watcher, err := fsnotify.NewWatcher() + if err != nil { + r.logger.Printf("config reloader disabled: %v", err) + return + } + defer watcher.Close() + + r.watchCurrentPaths(watcher) + + ticker := time.NewTicker(configReloadPollInterval) + defer ticker.Stop() + + for { + select { + case event := <-watcher.Events: + if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) || event.Has(fsnotify.Remove) || event.Has(fsnotify.Rename) { + r.reloadIfChanged(watcher, event.Name) + } + case err := <-watcher.Errors: + if err != nil { + r.logger.Printf("config watcher error: %v", err) + } + case <-ticker.C: + r.reloadIfChanged(watcher, "") + case <-r.stopCh: + return + } + } +} + +func (r *ConfigReloader) watchCurrentPaths(watcher *fsnotify.Watcher) { + r.mu.Lock() + cfg := r.current + r.mu.Unlock() + + seen := make(map[string]struct{}) + for _, path := range []string{cfg.ConfigPath, cfg.AuthTokenDir} { + path = strings.TrimSpace(path) + if path == "" { + continue + } + + target := path + info, err := os.Stat(path) + if err == nil && !info.IsDir() { + target = filepath.Dir(path) + } + target, err = filepath.Abs(target) + if err != nil { + continue + } + if _, exists := seen[target]; exists { + continue + } + if err := watcher.Add(target); err != nil { + r.logger.Printf("config watcher add %s failed: %v", target, err) + continue + } + seen[target] = struct{}{} + } +} + +func (r *ConfigReloader) reloadIfChanged(watcher *fsnotify.Watcher, changedPath string) { + r.mu.Lock() + current := r.current + r.mu.Unlock() + + if !r.isRelevantChange(current, changedPath) { + return + } + + cfg, err := loadConfig(current.ConfigPath) + if err != nil { + r.logger.Printf("config reload failed: %v", err) + return + } + if configSignature(cfg) == configSignature(current) { + return + } + + r.mu.Lock() + r.current = cfg + r.mu.Unlock() + r.watchCurrentPaths(watcher) + r.apply(cfg) + r.logger.Printf("config reloaded from %s (%d auth tokens, %d api keys)", displayConfigPath(cfg.ConfigPath), len(cfg.AuthTokens), len(cfg.APIKeys)) +} + +func (r *ConfigReloader) isRelevantChange(cfg Config, changedPath string) bool { + changedPath = strings.TrimSpace(changedPath) + if changedPath == "" { + return true + } + + changedAbs, err := filepath.Abs(changedPath) + if err != nil { + return true + } + + if cfg.ConfigPath != "" { + configAbs, err := filepath.Abs(cfg.ConfigPath) + if err == nil && strings.EqualFold(configAbs, changedAbs) { + return true + } + } + + if cfg.AuthTokenDir != "" { + tokenDirAbs, err := filepath.Abs(cfg.AuthTokenDir) + if err == nil { + rel, relErr := filepath.Rel(tokenDirAbs, changedAbs) + if relErr == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return true + } + } + } + + return false +} + +func configSignature(cfg Config) string { + return strings.Join([]string{ + cfg.ListenAddr, + cfg.UpstreamBaseURL, + cfg.AuthTokenDir, + cfg.RotationInterval.String(), + cfg.RequestTimeout.String(), + cfg.HTTPProxy, + strings.Join(cfg.AuthTokens, ","), + strings.Join(cfg.APIKeys, ","), + }, "|") +} + +func displayConfigPath(path string) string { + if strings.TrimSpace(path) == "" { + return "environment" + } + return path +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..90f0eac --- /dev/null +++ b/config_test.go @@ -0,0 +1,58 @@ +package main + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadConfigSupportsYAMLAndTokenDir(t *testing.T) { + tempDir := t.TempDir() + tokenDir := filepath.Join(tempDir, "tokens.d") + if err := os.MkdirAll(tokenDir, 0o755); err != nil { + t.Fatalf("mkdir token dir: %v", err) + } + + if err := os.WriteFile(filepath.Join(tokenDir, "plain.txt"), []byte("token-from-text\n"), 0o644); err != nil { + t.Fatalf("write plain token: %v", err) + } + if err := os.WriteFile(filepath.Join(tokenDir, "json.json"), []byte(`{"authToken":"token-from-json"}`), 0o644); err != nil { + t.Fatalf("write json token: %v", err) + } + if err := os.WriteFile(filepath.Join(tokenDir, "yaml.yaml"), []byte("default:\n authToken: token-from-yaml\n"), 0o644); err != nil { + t.Fatalf("write yaml token: %v", err) + } + + configPath := filepath.Join(tempDir, "config.yaml") + configBody := []byte("" + + "LISTEN_ADDR: \":18080\"\n" + + "UPSTREAM_BASE_URL: \"https://codebuff.com\"\n" + + "AUTH_TOKENS:\n" + + " - inline-token\n" + + "AUTH_TOKEN_DIR: \"tokens.d\"\n" + + "ROTATION_INTERVAL: \"2h\"\n" + + "REQUEST_TIMEOUT: \"45s\"\n" + + "API_KEYS:\n" + + " - key-1\n") + if err := os.WriteFile(configPath, configBody, 0o644); err != nil { + t.Fatalf("write config: %v", err) + } + + cfg, err := loadConfig(configPath) + if err != nil { + t.Fatalf("loadConfig returned error: %v", err) + } + + if got := cfg.UpstreamBaseURL; got != "https://www.codebuff.com" { + t.Fatalf("expected normalized base url, got %q", got) + } + if got := cfg.AuthTokenDir; got != tokenDir { + t.Fatalf("expected absolute token dir %q, got %q", tokenDir, got) + } + if len(cfg.AuthTokens) != 4 { + t.Fatalf("expected 4 auth tokens, got %d (%v)", len(cfg.AuthTokens), cfg.AuthTokens) + } + if !containsString(cfg.AuthTokens, "inline-token") || !containsString(cfg.AuthTokens, "token-from-json") || !containsString(cfg.AuthTokens, "token-from-yaml") || !containsString(cfg.AuthTokens, "token-from-text") { + t.Fatalf("unexpected auth tokens: %v", cfg.AuthTokens) + } +} diff --git a/free_session.go b/free_session.go index 4f8bece..54ec97e 100644 --- a/free_session.go +++ b/free_session.go @@ -28,6 +28,7 @@ const ( type freeSessionResponse struct { Status string `json:"status"` InstanceID string `json:"instanceId"` + Model string `json:"model"` Position int `json:"position"` QueueDepth int `json:"queueDepth"` QueuedAt string `json:"queuedAt"` @@ -41,6 +42,7 @@ type freeSessionResponse struct { type cachedSession struct { status sessionStatus instanceID string + model string expiresAt time.Time position int queueDepth int @@ -48,17 +50,41 @@ type cachedSession struct { retryAfter time.Duration } -func (p *tokenPool) ensureSession(ctx context.Context) (string, error) { +type modelSwitchError struct { + CurrentModel string + TargetModel string + RetryAfter time.Duration +} + +func (e *modelSwitchError) Error() string { + if e == nil { + return "session switch in progress" + } + if e.CurrentModel == "" || e.TargetModel == "" { + return "session switch in progress" + } + return fmt.Sprintf("token is switching from %s to %s", e.CurrentModel, e.TargetModel) +} + +func (p *tokenPool) ensureSession(ctx context.Context, model string) (string, error) { + model = strings.TrimSpace(model) for { p.mu.Lock() - if instanceID, ready := p.readySessionLocked(time.Now()); ready { + if instanceID, ready := p.readySessionLocked(time.Now(), model); ready { p.mu.Unlock() return instanceID, nil } - if waitingErr := waitingRoomErrorFromSession(p.name, p.session, time.Now()); waitingErr != nil { + if waitingErr := waitingRoomErrorFromSession(p.name, p.session, time.Now()); waitingErr != nil && p.sessionMatchesModelLocked(model) { p.mu.Unlock() return "", waitingErr } + if p.session != nil && !p.sessionMatchesModelLocked(model) { + p.mu.Unlock() + if err := p.prepareModel(ctx, model); err != nil { + return "", err + } + continue + } if ch := p.sessionRefreshCh; ch != nil { p.mu.Unlock() select { @@ -72,7 +98,7 @@ func (p *tokenPool) ensureSession(ctx context.Context) (string, error) { p.sessionRefreshCh = ch p.mu.Unlock() - session, instanceID, err := p.refreshSession(ctx) + session, instanceID, err := p.refreshSession(ctx, model) p.mu.Lock() if session != nil { @@ -81,6 +107,9 @@ func (p *tokenPool) ensureSession(ctx context.Context) (string, error) { if err != nil { p.session = nil p.lastError = err.Error() + if isBannedErrorMessage(err.Error()) { + p.disabled = true + } } else if waitingErr := waitingRoomErrorFromSession(p.name, session, time.Now()); waitingErr != nil { p.lastError = waitingErr.Error() } else { @@ -99,10 +128,13 @@ func (p *tokenPool) ensureSession(ctx context.Context) (string, error) { } } -func (p *tokenPool) readySessionLocked(now time.Time) (string, bool) { +func (p *tokenPool) readySessionLocked(now time.Time, model string) (string, bool) { if p.session == nil { return "", false } + if !p.sessionMatchesModelLocked(model) { + return "", false + } switch p.session.status { case sessionStatusDisabled: return "", true @@ -117,7 +149,8 @@ func (p *tokenPool) readySessionLocked(now time.Time) (string, bool) { return "", false } -func (p *tokenPool) refreshSession(ctx context.Context) (*cachedSession, string, error) { +func (p *tokenPool) refreshSession(ctx context.Context, model string) (*cachedSession, string, error) { + model = strings.TrimSpace(model) p.mu.Lock() current := p.session p.mu.Unlock() @@ -132,7 +165,7 @@ func (p *tokenPool) refreshSession(ctx context.Context) (*cachedSession, string, return nil, "", fmt.Errorf("poll free session: %w", err) } } else { - state, err = p.client.CreateOrRefreshSession(ctx, p.token) + state, err = p.client.CreateOrRefreshSession(ctx, p.token, model) if err != nil { return nil, "", fmt.Errorf("start free session: %w", err) } @@ -141,7 +174,7 @@ func (p *tokenPool) refreshSession(ctx context.Context) (*cachedSession, string, for { switch sessionStatus(strings.TrimSpace(state.Status)) { case sessionStatusDisabled: - return &cachedSession{status: sessionStatusDisabled}, "", nil + return &cachedSession{status: sessionStatusDisabled, model: model}, "", nil case sessionStatusActive: instanceID := strings.TrimSpace(state.InstanceID) if instanceID == "" { @@ -154,6 +187,7 @@ func (p *tokenPool) refreshSession(ctx context.Context) (*cachedSession, string, return &cachedSession{ status: sessionStatusActive, instanceID: instanceID, + model: firstNonEmptyTrimmedString(strings.TrimSpace(state.Model), model), expiresAt: expiresAt, }, instanceID, nil case sessionStatusQueued: @@ -166,13 +200,14 @@ func (p *tokenPool) refreshSession(ctx context.Context) (*cachedSession, string, return &cachedSession{ status: sessionStatusQueued, instanceID: instanceID, + model: firstNonEmptyTrimmedString(strings.TrimSpace(state.Model), model), position: maxInt(state.Position, 1), queueDepth: maxInt(state.QueueDepth, maxInt(state.Position, 1)), pollAt: time.Now().Add(delay), retryAfter: delay, }, "", nil case sessionStatusNone, sessionStatusEnded, sessionStatusSuperseded: - state, err = p.client.CreateOrRefreshSession(ctx, p.token) + state, err = p.client.CreateOrRefreshSession(ctx, p.token, model) if err != nil { return nil, "", fmt.Errorf("refresh free session: %w", err) } @@ -200,6 +235,101 @@ func (p *tokenPool) currentSessionInstanceID() string { return p.session.instanceID } +func (p *tokenPool) currentSessionModel() string { + p.mu.Lock() + defer p.mu.Unlock() + if p.session == nil { + return "" + } + return p.session.model +} + +func (p *tokenPool) sessionMatchesModelLocked(model string) bool { + if p.session == nil { + return false + } + model = strings.TrimSpace(model) + if model == "" || strings.TrimSpace(p.session.model) == "" { + return true + } + return p.session.model == model +} + +func (p *tokenPool) prepareModel(ctx context.Context, model string) error { + model = strings.TrimSpace(model) + if model == "" { + return nil + } + + p.mu.Lock() + currentModel := "" + if p.session != nil { + currentModel = strings.TrimSpace(p.session.model) + } + if currentModel == "" { + for _, run := range p.runs { + if strings.TrimSpace(run.model) != "" { + currentModel = strings.TrimSpace(run.model) + break + } + } + } + if currentModel == "" || currentModel == model { + p.mu.Unlock() + return nil + } + + for _, run := range p.runs { + if run.inflight > 0 { + p.mu.Unlock() + return &modelSwitchError{ + CurrentModel: currentModel, + TargetModel: model, + RetryAfter: 3 * time.Second, + } + } + } + for _, run := range p.draining { + if run.inflight > 0 { + p.mu.Unlock() + return &modelSwitchError{ + CurrentModel: currentModel, + TargetModel: model, + RetryAfter: 3 * time.Second, + } + } + } + + session := p.session + var allRuns []*managedRun + for _, run := range p.runs { + allRuns = append(allRuns, run) + } + allRuns = append(allRuns, p.draining...) + p.runs = make(map[string]*managedRun) + p.draining = nil + p.session = nil + p.lastError = "" + p.lastModelSwitch = time.Now() + p.mu.Unlock() + + var errs []string + for _, run := range allRuns { + if err := p.client.FinishRun(ctx, p.token, run.id, run.requestCount); err != nil { + errs = append(errs, err.Error()) + } + } + if session != nil && session.status != sessionStatusDisabled && session.instanceID != "" { + if err := p.client.EndSession(ctx, p.token); err != nil { + errs = append(errs, err.Error()) + } + } + if len(errs) > 0 { + return fmt.Errorf("switch token from model %s to %s: %s", currentModel, model, strings.Join(errs, "; ")) + } + return nil +} + func waitingRoomErrorFromSession(token string, session *cachedSession, now time.Time) *waitingRoomError { if session == nil || session.status != sessionStatusQueued { return nil @@ -286,27 +416,31 @@ func (p *tokenPool) endSession(ctx context.Context) error { return nil } -func (c *UpstreamClient) CreateOrRefreshSession(ctx context.Context, authToken string) (freeSessionResponse, error) { - return c.doSessionRequest(ctx, http.MethodPost, authToken, "") +func (c *UpstreamClient) CreateOrRefreshSession(ctx context.Context, authToken, model string) (freeSessionResponse, error) { + return c.doSessionRequest(ctx, http.MethodPost, authToken, "", model) } func (c *UpstreamClient) GetSession(ctx context.Context, authToken, instanceID string) (freeSessionResponse, error) { - return c.doSessionRequest(ctx, http.MethodGet, authToken, instanceID) + return c.doSessionRequest(ctx, http.MethodGet, authToken, instanceID, "") } func (c *UpstreamClient) EndSession(ctx context.Context, authToken string) error { - requestURL, err := url.JoinPath(c.baseURL, "/api/v1/freebuff/session") + cfg := c.cfgStore.Current() + requestURL, err := url.JoinPath(cfg.UpstreamBaseURL, "/api/v1/freebuff/session") if err != nil { return fmt.Errorf("build free session url: %w", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, requestURL, nil) + requestCtx, cancel := c.requestContext(ctx) + defer cancel() + + req, err := http.NewRequestWithContext(requestCtx, http.MethodDelete, requestURL, nil) if err != nil { return fmt.Errorf("create free session delete request: %w", err) } req.Header.Set("Authorization", "Bearer "+authToken) req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", c.userAgent) + req.Header.Set("User-Agent", cfg.UserAgent) resp, err := c.httpClient.Do(req) if err != nil { @@ -324,8 +458,9 @@ func (c *UpstreamClient) EndSession(ctx context.Context, authToken string) error return nil } -func (c *UpstreamClient) doSessionRequest(ctx context.Context, method, authToken, instanceID string) (freeSessionResponse, error) { - requestURL, err := url.JoinPath(c.baseURL, "/api/v1/freebuff/session") +func (c *UpstreamClient) doSessionRequest(ctx context.Context, method, authToken, instanceID, model string) (freeSessionResponse, error) { + cfg := c.cfgStore.Current() + requestURL, err := url.JoinPath(cfg.UpstreamBaseURL, "/api/v1/freebuff/session") if err != nil { return freeSessionResponse{}, fmt.Errorf("build free session url: %w", err) } @@ -335,15 +470,21 @@ func (c *UpstreamClient) doSessionRequest(ctx context.Context, method, authToken body = bytes.NewReader([]byte("{}")) } - req, err := http.NewRequestWithContext(ctx, method, requestURL, body) + requestCtx, cancel := c.requestContext(ctx) + defer cancel() + + req, err := http.NewRequestWithContext(requestCtx, method, requestURL, body) if err != nil { return freeSessionResponse{}, fmt.Errorf("create free session request: %w", err) } req.Header.Set("Authorization", "Bearer "+authToken) req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", c.userAgent) + req.Header.Set("User-Agent", cfg.UserAgent) if method == http.MethodPost { req.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(model) != "" { + req.Header.Set("x-freebuff-model", strings.TrimSpace(model)) + } } if method == http.MethodGet && instanceID != "" { req.Header.Set("x-freebuff-instance-id", instanceID) @@ -412,3 +553,12 @@ func sleepWithContext(ctx context.Context, delay time.Duration) error { return nil } } + +func firstNonEmptyTrimmedString(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + return "" +} diff --git a/go.mod b/go.mod index c21551e..ddad63b 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,13 @@ module github.com/Quorinex/Freebuff2API go 1.23.0 +require ( + github.com/fsnotify/fsnotify v1.9.0 + github.com/tiktoken-go/tokenizer v0.7.0 + gopkg.in/yaml.v3 v3.0.1 +) + require ( github.com/dlclark/regexp2 v1.11.5 // indirect - github.com/google/uuid v1.6.0 // indirect - github.com/tiktoken-go/tokenizer v0.7.0 // indirect + golang.org/x/sys v0.13.0 // indirect ) diff --git a/go.sum b/go.sum index 5a0f650..6896e09 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,12 @@ github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/tiktoken-go/tokenizer v0.7.0 h1:VMu6MPT0bXFDHr7UPh9uii7CNItVt3X9K90omxL54vw= github.com/tiktoken-go/tokenizer v0.7.0/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index 6c3fd2f..280d1fe 100644 --- a/main.go +++ b/main.go @@ -13,19 +13,17 @@ import ( ) func main() { - configPath := flag.String("config", "", "path to a JSON config file (default: config.json if present)") + configPath := flag.String("config", "", "path to a YAML or JSON config file") flag.Parse() logger := log.New(os.Stdout, "[Freebuff2API] ", log.LstdFlags|log.Lmsgprefix) - // Auto-detect config.json in CWD when no flag is given - if *configPath == "" { - if _, err := os.Stat("config.json"); err == nil { - *configPath = "config.json" - } + resolvedConfigPath, err := resolveConfigPath(*configPath) + if err != nil { + logger.Fatalf("resolve config path: %v", err) } - cfg, err := loadConfig(*configPath) + cfg, err := loadConfig(resolvedConfigPath) if err != nil { logger.Fatalf("load config: %v", err) } @@ -36,7 +34,7 @@ func main() { transport.Proxy = http.ProxyURL(importURL) } httpClient := &http.Client{Transport: transport, Timeout: 15 * time.Second} - + registry := NewModelRegistry(httpClient, logger) registry.Start(context.Background()) defer registry.Stop() @@ -46,6 +44,10 @@ func main() { defer cancelRun() server.Start(runCtx) + reloader := NewConfigReloader(cfg, logger, server.ApplyConfig) + reloader.Start() + defer reloader.Stop() + httpServer := &http.Server{ Addr: cfg.ListenAddr, Handler: server.Handler(), diff --git a/models.go b/models.go index a1bd1b3..1abba4c 100644 --- a/models.go +++ b/models.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "log" - "math/rand" "net/http" "regexp" "sort" @@ -202,7 +201,7 @@ func parseAllFreeModels(source string) map[string][]string { } // buildModelMapping creates the model→agent reverse mapping and deduplicated model list. -// When a model appears in multiple agents, one is chosen at random. +// When a model appears in multiple agents, pick the least-used agent to spread traffic. func buildModelMapping(agentModels map[string][]string) (map[string]string, []string) { modelAgents := make(map[string][]string) for agentID, models := range agentModels { @@ -213,10 +212,25 @@ func buildModelMapping(agentModels map[string][]string) (map[string]string, []st modelToAgent := make(map[string]string, len(modelAgents)) allModels := make([]string, 0, len(modelAgents)) - for model, agents := range modelAgents { - modelToAgent[model] = agents[rand.Intn(len(agents))] + for model := range modelAgents { allModels = append(allModels, model) } sort.Strings(allModels) + + agentUseCount := make(map[string]int, len(agentModels)) + for _, model := range allModels { + agents := append([]string(nil), modelAgents[model]...) + sort.Strings(agents) + chosen := agents[0] + bestCount := agentUseCount[chosen] + for _, agentID := range agents[1:] { + if count := agentUseCount[agentID]; count < bestCount { + chosen = agentID + bestCount = count + } + } + modelToAgent[model] = chosen + agentUseCount[chosen]++ + } return modelToAgent, allModels } diff --git a/reasoning.go b/reasoning.go new file mode 100644 index 0000000..eaf1003 --- /dev/null +++ b/reasoning.go @@ -0,0 +1,23 @@ +package main + +import "strings" + +func normalizeReasoningEffort(raw string) (string, bool) { + effort := strings.ToLower(strings.TrimSpace(raw)) + switch effort { + case "": + return "", false + case "none", "false", "disabled", "off": + return "none", true + case "minimal": + return "low", true + case "low", "medium", "high": + return effort, true + case "xhigh", "max": + return "high", true + case "auto", "true", "enabled", "on": + return "", false + default: + return "", false + } +} diff --git a/reasoning_test.go b/reasoning_test.go new file mode 100644 index 0000000..4a02d0f --- /dev/null +++ b/reasoning_test.go @@ -0,0 +1,64 @@ +package main + +import "testing" + +func TestNormalizeReasoningEffort(t *testing.T) { + tests := []struct { + name string + input string + want string + ok bool + }{ + {name: "none", input: "none", want: "none", ok: true}, + {name: "minimal downgraded", input: "minimal", want: "low", ok: true}, + {name: "xhigh downgraded", input: "xhigh", want: "high", ok: true}, + {name: "max downgraded", input: "max", want: "high", ok: true}, + {name: "auto omitted", input: "auto", want: "", ok: false}, + {name: "enabled omitted", input: "enabled", want: "", ok: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := normalizeReasoningEffort(tt.input) + if ok != tt.ok || got != tt.want { + t.Fatalf("normalizeReasoningEffort(%q) = (%q, %v), want (%q, %v)", tt.input, got, ok, tt.want, tt.ok) + } + }) + } +} + +func TestConvertClaudeMessagesRequestToOpenAINormalizesReasoningEffort(t *testing.T) { + body := []byte(`{ + "model": "z-ai/glm-5.1", + "thinking": {"type":"adaptive"}, + "output_config": {"effort":"max"}, + "messages": [{"role":"user","content":"hello"}] + }`) + + payload, _, _, err := convertClaudeMessagesRequestToOpenAI(body) + if err != nil { + t.Fatalf("convertClaudeMessagesRequestToOpenAI returned error: %v", err) + } + + if got := payload["reasoning_effort"]; got != "high" { + t.Fatalf("expected reasoning_effort=high, got %#v", got) + } +} + +func TestConvertResponsesCreateRequestToOpenAINormalizesReasoningEffort(t *testing.T) { + store := newResponseStore() + body := []byte(`{ + "model": "z-ai/glm-5.1", + "reasoning": {"effort":"xhigh"}, + "input": [{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}] + }`) + + payload, _, _, _, err := convertResponsesCreateRequestToOpenAI(body, store) + if err != nil { + t.Fatalf("convertResponsesCreateRequestToOpenAI returned error: %v", err) + } + + if got := payload["reasoning_effort"]; got != "high" { + t.Fatalf("expected reasoning_effort=high, got %#v", got) + } +} diff --git a/responses.go b/responses.go new file mode 100644 index 0000000..4074721 --- /dev/null +++ b/responses.go @@ -0,0 +1,1236 @@ +package main + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" +) + +type responseStore struct { + mu sync.RWMutex + items map[string]storedResponse +} + +type storedResponse struct { + Conversation []map[string]any +} + +func newResponseStore() *responseStore { + return &responseStore{ + items: make(map[string]storedResponse), + } +} + +func (s *responseStore) GetConversation(id string) ([]map[string]any, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + entry, ok := s.items[strings.TrimSpace(id)] + if !ok { + return nil, false + } + return cloneResponseItems(entry.Conversation), true +} + +func (s *responseStore) Put(id string, conversation []map[string]any) { + id = strings.TrimSpace(id) + if id == "" { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + s.items[id] = storedResponse{Conversation: cloneResponseItems(conversation)} +} + +func convertResponsesCreateRequestToOpenAI(body []byte, store *responseStore) (map[string]any, string, bool, []map[string]any, error) { + var root map[string]any + if err := json.Unmarshal(body, &root); err != nil { + return nil, "", false, nil, fmt.Errorf("request body must be valid JSON") + } + + modelName := strings.TrimSpace(stringValue(root["model"])) + if modelName == "" { + return nil, "", false, nil, fmt.Errorf("model is required") + } + + conversation := make([]map[string]any, 0, 16) + if previousID := strings.TrimSpace(stringValue(root["previous_response_id"])); previousID != "" { + previousConversation, ok := store.GetConversation(previousID) + if !ok { + return nil, "", false, nil, fmt.Errorf("unknown previous_response_id %q", previousID) + } + conversation = append(conversation, previousConversation...) + } + + if instructions := normalizeResponsesInstructions(root["instructions"]); instructions != nil { + conversation = append(conversation, instructions) + } + + inputItems, err := normalizeResponsesInputItems(root["input"]) + if err != nil { + return nil, "", false, nil, err + } + if len(inputItems) == 0 { + return nil, "", false, nil, fmt.Errorf("input is required") + } + conversation = append(conversation, inputItems...) + + messages := make([]any, 0, len(conversation)+2) + for _, item := range conversation { + var appendErr error + messages, appendErr = appendResponsesItemToOpenAIMessages(messages, item) + if appendErr != nil { + return nil, "", false, nil, appendErr + } + } + + payload := map[string]any{ + "model": modelName, + "messages": messages, + } + + stream := boolValue(root["stream"]) + payload["stream"] = stream + + if maxOutputTokens, ok := intValue(root["max_output_tokens"]); ok && maxOutputTokens > 0 { + payload["max_tokens"] = maxOutputTokens + } else if maxTokens, ok := intValue(root["max_completion_tokens"]); ok && maxTokens > 0 { + payload["max_tokens"] = maxTokens + } else if maxTokens, ok := intValue(root["max_tokens"]); ok && maxTokens > 0 { + payload["max_tokens"] = maxTokens + } + + if temperature, ok := floatValue(root["temperature"]); ok { + payload["temperature"] = temperature + } + if topP, ok := floatValue(root["top_p"]); ok { + payload["top_p"] = topP + } + if userValue := strings.TrimSpace(stringValue(root["user"])); userValue != "" { + payload["user"] = userValue + } + + if reasoning := mapValue(root["reasoning"]); reasoning != nil { + if effort := strings.TrimSpace(stringValue(reasoning["effort"])); effort != "" { + if normalized, ok := normalizeReasoningEffort(effort); ok { + payload["reasoning_effort"] = normalized + } + } + } + + if tools := convertResponsesToolsToOpenAI(root["tools"]); len(tools) > 0 { + payload["tools"] = tools + } + if toolChoice, ok := convertResponsesToolChoiceToOpenAI(root["tool_choice"]); ok { + payload["tool_choice"] = toolChoice + } + + return payload, modelName, stream, conversation, nil +} + +func normalizeResponsesInstructions(value any) map[string]any { + text := strings.TrimSpace(stringValue(value)) + if text == "" { + return nil + } + return map[string]any{ + "type": "message", + "role": "system", + "content": []any{ + map[string]any{ + "type": "input_text", + "text": text, + }, + }, + } +} + +func normalizeResponsesInputItems(value any) ([]map[string]any, error) { + switch typed := value.(type) { + case nil: + return nil, nil + case string: + text := strings.TrimSpace(typed) + if text == "" { + return nil, nil + } + return []map[string]any{ + { + "type": "message", + "role": "user", + "content": []any{ + map[string]any{ + "type": "input_text", + "text": text, + }, + }, + }, + }, nil + case []any: + items := make([]map[string]any, 0, len(typed)) + for _, rawItem := range typed { + item, err := normalizeResponsesInputItem(rawItem) + if err != nil { + return nil, err + } + if item != nil { + items = append(items, item) + } + } + return items, nil + case map[string]any: + item, err := normalizeResponsesInputItem(typed) + if err != nil { + return nil, err + } + if item == nil { + return nil, nil + } + return []map[string]any{item}, nil + default: + text := strings.TrimSpace(fmt.Sprint(value)) + if text == "" { + return nil, nil + } + return []map[string]any{ + { + "type": "message", + "role": "user", + "content": []any{ + map[string]any{ + "type": "input_text", + "text": text, + }, + }, + }, + }, nil + } +} + +func normalizeResponsesInputItem(value any) (map[string]any, error) { + item := mapValue(value) + if item == nil { + if text, ok := value.(string); ok { + text = strings.TrimSpace(text) + if text == "" { + return nil, nil + } + return map[string]any{ + "type": "message", + "role": "user", + "content": []any{ + map[string]any{"type": "input_text", "text": text}, + }, + }, nil + } + return nil, fmt.Errorf("input items must be objects") + } + + itemType := strings.TrimSpace(stringValue(item["type"])) + if itemType == "" && strings.TrimSpace(stringValue(item["role"])) != "" { + itemType = "message" + } + + switch itemType { + case "message": + role := strings.ToLower(strings.TrimSpace(stringValue(item["role"]))) + if role == "" { + role = "user" + } + if role != "user" && role != "assistant" && role != "system" && role != "developer" { + return nil, fmt.Errorf("unsupported input role %q", role) + } + content := normalizeResponsesContent(item["content"], role) + if len(content) == 0 { + return nil, nil + } + return map[string]any{ + "type": "message", + "role": role, + "content": content, + }, nil + + case "function_call": + name := strings.TrimSpace(stringValue(item["name"])) + if name == "" { + return nil, fmt.Errorf("function_call name is required") + } + callID := strings.TrimSpace(stringValue(item["call_id"])) + itemID := strings.TrimSpace(stringValue(item["id"])) + if callID == "" && itemID != "" { + callID = responseCallIDFromValue(itemID) + } + if callID == "" { + callID = responseCallIDFromValue(name) + } + if itemID == "" { + itemID = responseFunctionIDFromCallID(callID) + } + return map[string]any{ + "type": "function_call", + "id": itemID, + "call_id": callID, + "name": name, + "arguments": normalizeJSONString(item["arguments"]), + }, nil + + case "function_call_output": + callID := strings.TrimSpace(stringValue(item["call_id"])) + if callID == "" { + return nil, fmt.Errorf("function_call_output call_id is required") + } + return map[string]any{ + "type": "function_call_output", + "call_id": callID, + "output": normalizeResponseOutputValue(item["output"]), + }, nil + + default: + return nil, fmt.Errorf("unsupported input item type %q", itemType) + } +} + +func normalizeResponsesContent(value any, role string) []any { + textType := "input_text" + if role == "assistant" { + textType = "output_text" + } + + switch typed := value.(type) { + case nil: + return nil + case string: + text := strings.TrimSpace(typed) + if text == "" { + return nil + } + return []any{map[string]any{"type": textType, "text": text}} + case []any: + parts := make([]any, 0, len(typed)) + for _, rawPart := range typed { + part := mapValue(rawPart) + if part == nil { + text := strings.TrimSpace(fmt.Sprint(rawPart)) + if text == "" { + continue + } + parts = append(parts, map[string]any{"type": textType, "text": text}) + continue + } + + partType := strings.ToLower(strings.TrimSpace(stringValue(part["type"]))) + switch partType { + case "", "text", "input_text", "output_text": + text := stringValue(part["text"]) + if strings.TrimSpace(text) == "" { + continue + } + parts = append(parts, map[string]any{"type": textType, "text": text}) + case "input_image", "image": + if imageURL := firstResponseImageURL(part); imageURL != "" && role != "assistant" { + parts = append(parts, map[string]any{"type": "input_image", "image_url": imageURL}) + } + } + } + return parts + default: + text := strings.TrimSpace(fmt.Sprint(value)) + if text == "" { + return nil + } + return []any{map[string]any{"type": textType, "text": text}} + } +} + +func firstResponseImageURL(part map[string]any) string { + if url := strings.TrimSpace(stringValue(part["image_url"])); url != "" { + return url + } + imageURL := mapValue(part["image_url"]) + if imageURL != nil { + if url := strings.TrimSpace(stringValue(imageURL["url"])); url != "" { + return url + } + } + if source := mapValue(part["source"]); source != nil { + if strings.EqualFold(stringValue(source["type"]), "url") { + return strings.TrimSpace(stringValue(source["url"])) + } + if strings.EqualFold(stringValue(source["type"]), "base64") { + mediaType := strings.TrimSpace(stringValue(source["media_type"])) + data := strings.TrimSpace(stringValue(source["data"])) + if mediaType != "" && data != "" { + return "data:" + mediaType + ";base64," + data + } + } + } + return "" +} + +func appendResponsesItemToOpenAIMessages(messages []any, item map[string]any) ([]any, error) { + switch strings.TrimSpace(stringValue(item["type"])) { + case "message": + role := strings.ToLower(strings.TrimSpace(stringValue(item["role"]))) + if role == "developer" { + role = "system" + } + contentParts := responsesContentPartsToOpenAI(item["content"], role) + if len(contentParts) == 0 { + return messages, nil + } + messages = append(messages, map[string]any{ + "role": role, + "content": normalizeOpenAIContent(contentParts), + }) + return messages, nil + + case "function_call": + name := strings.TrimSpace(stringValue(item["name"])) + if name == "" { + return messages, fmt.Errorf("function_call name is required") + } + callID := strings.TrimSpace(stringValue(item["call_id"])) + if callID == "" { + callID = responseCallIDFromValue(name) + } + messages = append(messages, map[string]any{ + "role": "assistant", + "content": "", + "tool_calls": []any{ + map[string]any{ + "id": callID, + "type": "function", + "function": map[string]any{ + "name": name, + "arguments": normalizeJSONString(item["arguments"]), + }, + }, + }, + }) + return messages, nil + + case "function_call_output": + callID := strings.TrimSpace(stringValue(item["call_id"])) + if callID == "" { + return messages, fmt.Errorf("function_call_output call_id is required") + } + messages = append(messages, map[string]any{ + "role": "tool", + "tool_call_id": callID, + "content": normalizeResponseOutputValue(item["output"]), + }) + return messages, nil + } + + return messages, nil +} + +func responsesContentPartsToOpenAI(value any, role string) []any { + rawParts := sliceValue(value) + if len(rawParts) == 0 { + return nil + } + + contentParts := make([]any, 0, len(rawParts)) + for _, rawPart := range rawParts { + part := mapValue(rawPart) + if part == nil { + continue + } + + switch strings.ToLower(strings.TrimSpace(stringValue(part["type"]))) { + case "input_text", "output_text", "text": + text := stringValue(part["text"]) + if strings.TrimSpace(text) == "" { + continue + } + contentParts = append(contentParts, map[string]any{ + "type": "text", + "text": text, + }) + case "input_image": + if role == "assistant" { + continue + } + imageURL := strings.TrimSpace(stringValue(part["image_url"])) + if imageURL == "" { + continue + } + contentParts = append(contentParts, map[string]any{ + "type": "image_url", + "image_url": map[string]any{ + "url": imageURL, + }, + }) + } + } + return contentParts +} + +func convertResponsesToolsToOpenAI(value any) []any { + rawTools := sliceValue(value) + if len(rawTools) == 0 { + return nil + } + + tools := make([]any, 0, len(rawTools)) + for _, rawTool := range rawTools { + tool := mapValue(rawTool) + if tool == nil { + continue + } + + toolType := strings.ToLower(strings.TrimSpace(stringValue(tool["type"]))) + name := strings.TrimSpace(stringValue(tool["name"])) + if name == "" && toolType == "function" { + name = strings.TrimSpace(stringValue(tool["name"])) + } + if name == "" { + name = toolType + } + if name == "" { + continue + } + + parameters := mapValue(tool["parameters"]) + if parameters == nil { + parameters = mapValue(tool["input_schema"]) + } + if parameters == nil { + parameters = map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + + tools = append(tools, map[string]any{ + "type": "function", + "function": map[string]any{ + "name": name, + "description": strings.TrimSpace(stringValue(tool["description"])), + "parameters": parameters, + }, + }) + } + return tools +} + +func convertResponsesToolChoiceToOpenAI(value any) (any, bool) { + switch typed := value.(type) { + case string: + switch strings.ToLower(strings.TrimSpace(typed)) { + case "", "auto": + return "auto", true + case "required": + return "required", true + case "none": + return "none", true + default: + return typed, true + } + case map[string]any: + choiceType := strings.ToLower(strings.TrimSpace(stringValue(typed["type"]))) + switch choiceType { + case "", "auto": + return "auto", true + case "required": + return "required", true + case "none": + return "none", true + default: + name := strings.TrimSpace(stringValue(typed["name"])) + if name == "" { + name = choiceType + } + if name == "" { + return nil, false + } + return map[string]any{ + "type": "function", + "function": map[string]any{ + "name": name, + }, + }, true + } + default: + return nil, false + } +} + +func writeResponsesSuccessResponse(w http.ResponseWriter, resp *http.Response, requestedModel string, stream bool, conversation []map[string]any, store *responseStore) error { + if stream { + return writeResponsesStream(w, resp, requestedModel, conversation, store) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read upstream response: %w", err) + } + + converted, responseID, storedConversation, err := convertOpenAINonStreamResponseToResponses(body, requestedModel, conversation) + if err != nil { + return err + } + if responseID != "" { + store.Put(responseID, storedConversation) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(resp.StatusCode) + _, err = w.Write(converted) + return err +} + +func convertOpenAINonStreamResponseToResponses(body []byte, requestedModel string, conversation []map[string]any) ([]byte, string, []map[string]any, error) { + var response openAIChatCompletion + if err := json.Unmarshal(body, &response); err != nil { + return nil, "", nil, fmt.Errorf("decode upstream response: %w", err) + } + + responseObject, conversationItems := buildResponsesObjectFromOpenAI(response, requestedModel) + storedConversation := append(cloneResponseItems(conversation), conversationItems...) + encoded, err := json.Marshal(responseObject) + if err != nil { + return nil, "", nil, fmt.Errorf("encode responses response: %w", err) + } + return encoded, strings.TrimSpace(stringValue(responseObject["id"])), storedConversation, nil +} + +func buildResponsesObjectFromOpenAI(response openAIChatCompletion, requestedModel string) (map[string]any, []map[string]any) { + modelName := strings.TrimSpace(response.Model) + if modelName == "" { + modelName = requestedModel + } + + responseID := normalizeResponseID(response.ID) + outputItems := make([]any, 0, 4) + storedItems := make([]map[string]any, 0, 4) + outputTexts := make([]string, 0, 2) + finishReason := "" + var usage map[string]any + + if len(response.Choices) > 0 { + choice := response.Choices[0] + finishReason = strings.TrimSpace(choice.FinishReason) + + messageParts := make([]any, 0, 4) + for _, block := range convertOpenAIContentToClaudeBlocks(choice.Message.Content) { + blockMap := mapValue(block) + if blockMap == nil || !strings.EqualFold(stringValue(blockMap["type"]), "text") { + continue + } + text := stringValue(blockMap["text"]) + if strings.TrimSpace(text) == "" { + continue + } + messageParts = append(messageParts, map[string]any{ + "type": "output_text", + "text": text, + }) + outputTexts = append(outputTexts, text) + } + + if len(messageParts) > 0 || len(choice.Message.ToolCalls) == 0 { + messageItem := map[string]any{ + "id": normalizeResponseMessageID(responseID), + "type": "message", + "status": "completed", + "role": "assistant", + "content": messageParts, + } + outputItems = append(outputItems, messageItem) + storedItems = append(storedItems, map[string]any{ + "type": "message", + "role": "assistant", + "content": cloneGenericSlice(messageParts), + }) + } + + for _, toolCall := range choice.Message.ToolCalls { + callID := responseCallIDFromValue(toolCall.ID) + functionID := responseFunctionIDFromCallID(callID) + toolItem := map[string]any{ + "id": functionID, + "type": "function_call", + "status": "completed", + "call_id": callID, + "name": toolCall.Function.Name, + "arguments": normalizeJSONString(toolCall.Function.Arguments), + } + outputItems = append(outputItems, toolItem) + storedItems = append(storedItems, map[string]any{ + "type": "function_call", + "id": functionID, + "call_id": callID, + "name": toolCall.Function.Name, + "arguments": normalizeJSONString(toolCall.Function.Arguments), + }) + } + } + + if response.Usage != nil { + usage = mapOpenAIUsageToResponses(response.Usage) + } else { + usage = map[string]any{ + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + } + + responseObject := map[string]any{ + "id": responseID, + "object": "response", + "status": "completed", + "created_at": time.Now().Unix(), + "model": modelName, + "output": outputItems, + "output_text": strings.Join(outputTexts, ""), + "usage": usage, + "finish_reason": finishReason, + } + + return responseObject, storedItems +} + +func mapOpenAIUsageToResponses(usage *openAIUsage) map[string]any { + inputTokens, outputTokens, cachedTokens := extractOpenAIUsage(usage) + totalTokens := inputTokens + outputTokens + cachedTokens + if usage != nil && usage.TotalTokens > 0 { + totalTokens = usage.TotalTokens + } + out := map[string]any{ + "input_tokens": inputTokens, + "output_tokens": outputTokens, + "total_tokens": totalTokens, + } + if cachedTokens > 0 { + out["input_tokens_details"] = map[string]any{ + "cached_tokens": cachedTokens, + } + } + return out +} + +type responseStreamState struct { + responseID string + model string + message *responseOutputMessageState + toolCalls map[int]*responseOutputToolState + outputOrder []responseOutputEntry + usage *openAIUsage + sequence int +} + +type responseOutputEntry struct { + kind string + index int +} + +type responseOutputMessageState struct { + ID string + OutputIndex int + Text strings.Builder + Started bool +} + +type responseOutputToolState struct { + Index int + ID string + CallID string + Name string + Arguments strings.Builder + OutputIndex int + Started bool +} + +func writeResponsesStream(w http.ResponseWriter, resp *http.Response, requestedModel string, conversation []map[string]any, store *responseStore) error { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(resp.StatusCode) + + flusher, _ := w.(http.Flusher) + state := &responseStreamState{ + responseID: normalizeResponseID(""), + model: requestedModel, + toolCalls: make(map[int]*responseOutputToolState), + outputOrder: make([]responseOutputEntry, 0, 4), + } + + if err := writeResponsesSSEEvent(w, "response.created", map[string]any{ + "type": "response.created", + "response": map[string]any{ + "id": state.responseID, + "object": "response", + "status": "in_progress", + "created_at": time.Now().Unix(), + "model": state.model, + "output": []any{}, + }, + "sequence_number": state.nextSequence(), + }); err != nil { + return err + } + if err := writeResponsesSSEEvent(w, "response.in_progress", map[string]any{ + "type": "response.in_progress", + "response": map[string]any{ + "id": state.responseID, + "object": "response", + "status": "in_progress", + "created_at": time.Now().Unix(), + "model": state.model, + "output": []any{}, + }, + "sequence_number": state.nextSequence(), + }); err != nil { + return err + } + if flusher != nil { + flusher.Flush() + } + + reader := bufio.NewReader(resp.Body) + for { + payload, err := readNextSSEDataBlock(reader) + if err != nil { + if err == io.EOF { + break + } + return err + } + + trimmed := bytes.TrimSpace(payload) + if len(trimmed) == 0 { + continue + } + if bytes.Equal(trimmed, []byte("[DONE]")) { + break + } + + var chunk openAIChatCompletion + if err := json.Unmarshal(trimmed, &chunk); err != nil { + return fmt.Errorf("decode upstream stream chunk: %w", err) + } + + if strings.TrimSpace(chunk.Model) != "" { + state.model = chunk.Model + } + if chunk.Usage != nil { + state.usage = chunk.Usage + } + + for _, choice := range chunk.Choices { + if choice.Delta.Content != "" { + if err := state.emitTextDelta(w, choice.Delta.Content); err != nil { + return err + } + } + for _, toolCall := range choice.Delta.ToolCalls { + if err := state.emitToolDelta(w, toolCall); err != nil { + return err + } + } + } + + if flusher != nil { + flusher.Flush() + } + } + + outputItems, storedItems, outputText, err := state.finishStream(w) + if err != nil { + return err + } + + responseObject := map[string]any{ + "id": state.responseID, + "object": "response", + "status": "completed", + "created_at": time.Now().Unix(), + "model": state.model, + "output": outputItems, + "output_text": outputText, + "usage": mapOpenAIUsageToResponses(state.usage), + } + if err := writeResponsesSSEEvent(w, "response.completed", map[string]any{ + "type": "response.completed", + "response": responseObject, + "sequence_number": state.nextSequence(), + }); err != nil { + return err + } + if flusher != nil { + flusher.Flush() + } + + store.Put(state.responseID, append(cloneResponseItems(conversation), storedItems...)) + return nil +} + +func (s *responseStreamState) emitTextDelta(w http.ResponseWriter, delta string) error { + if s.message == nil { + s.message = &responseOutputMessageState{ + ID: normalizeResponseMessageID(s.responseID), + OutputIndex: len(s.outputOrder), + Started: true, + } + s.outputOrder = append(s.outputOrder, responseOutputEntry{kind: "message", index: -1}) + if err := writeResponsesSSEEvent(w, "response.output_item.added", map[string]any{ + "type": "response.output_item.added", + "item": map[string]any{ + "id": s.message.ID, + "type": "message", + "status": "in_progress", + "content": []any{}, + "role": "assistant", + }, + "output_index": s.message.OutputIndex, + "sequence_number": s.nextSequence(), + }); err != nil { + return err + } + if err := writeResponsesSSEEvent(w, "response.content_part.added", map[string]any{ + "type": "response.content_part.added", + "part": map[string]any{ + "type": "output_text", + "text": "", + }, + "item_id": s.message.ID, + "output_index": s.message.OutputIndex, + "content_index": 0, + "sequence_number": s.nextSequence(), + }); err != nil { + return err + } + } + + s.message.Text.WriteString(delta) + return writeResponsesSSEEvent(w, "response.output_text.delta", map[string]any{ + "type": "response.output_text.delta", + "delta": delta, + "item_id": s.message.ID, + "output_index": s.message.OutputIndex, + "content_index": 0, + "sequence_number": s.nextSequence(), + }) +} + +func (s *responseStreamState) emitToolDelta(w http.ResponseWriter, toolCall openAIStreamToolCall) error { + state, ok := s.toolCalls[toolCall.Index] + if !ok { + callID := responseCallIDFromValue(toolCall.ID) + state = &responseOutputToolState{ + Index: toolCall.Index, + ID: responseFunctionIDFromCallID(callID), + CallID: callID, + OutputIndex: len(s.outputOrder), + Started: true, + } + s.toolCalls[toolCall.Index] = state + s.outputOrder = append(s.outputOrder, responseOutputEntry{kind: "tool", index: toolCall.Index}) + } + + if strings.TrimSpace(toolCall.ID) != "" { + state.CallID = responseCallIDFromValue(toolCall.ID) + state.ID = responseFunctionIDFromCallID(state.CallID) + } + if strings.TrimSpace(toolCall.Function.Name) != "" { + state.Name = toolCall.Function.Name + } + + if state.Started { + state.Started = false + if err := writeResponsesSSEEvent(w, "response.output_item.added", map[string]any{ + "type": "response.output_item.added", + "item": map[string]any{ + "id": state.ID, + "type": "function_call", + "status": "in_progress", + "call_id": state.CallID, + "name": state.Name, + "arguments": "", + }, + "output_index": state.OutputIndex, + "sequence_number": s.nextSequence(), + }); err != nil { + return err + } + } + + if toolCall.Function.Arguments == "" { + return nil + } + state.Arguments.WriteString(toolCall.Function.Arguments) + return writeResponsesSSEEvent(w, "response.function_call_arguments.delta", map[string]any{ + "type": "response.function_call_arguments.delta", + "delta": toolCall.Function.Arguments, + "item_id": state.ID, + "output_index": state.OutputIndex, + "sequence_number": s.nextSequence(), + }) +} + +func (s *responseStreamState) finishStream(w http.ResponseWriter) ([]any, []map[string]any, string, error) { + outputItems := make([]any, 0, len(s.outputOrder)) + storedItems := make([]map[string]any, 0, len(s.outputOrder)) + outputTexts := make([]string, 0, 2) + + for _, entry := range s.outputOrder { + switch entry.kind { + case "message": + if s.message == nil { + continue + } + text := s.message.Text.String() + item := map[string]any{ + "id": s.message.ID, + "type": "message", + "status": "completed", + "role": "assistant", + "content": []any{ + map[string]any{ + "type": "output_text", + "text": text, + }, + }, + } + if strings.TrimSpace(text) == "" { + item["content"] = []any{} + } else { + outputTexts = append(outputTexts, text) + } + if err := writeResponsesSSEEvent(w, "response.output_item.done", map[string]any{ + "type": "response.output_item.done", + "item": item, + "output_index": s.message.OutputIndex, + "sequence_number": s.nextSequence(), + }); err != nil { + return nil, nil, "", err + } + outputItems = append(outputItems, item) + storedItems = append(storedItems, map[string]any{ + "type": "message", + "role": "assistant", + "content": cloneGenericSlice(item["content"].([]any)), + }) + + case "tool": + toolState := s.toolCalls[entry.index] + if toolState == nil { + continue + } + item := map[string]any{ + "id": toolState.ID, + "type": "function_call", + "status": "completed", + "call_id": toolState.CallID, + "name": toolState.Name, + "arguments": toolState.Arguments.String(), + } + if err := writeResponsesSSEEvent(w, "response.output_item.done", map[string]any{ + "type": "response.output_item.done", + "item": item, + "output_index": toolState.OutputIndex, + "sequence_number": s.nextSequence(), + }); err != nil { + return nil, nil, "", err + } + outputItems = append(outputItems, item) + storedItems = append(storedItems, map[string]any{ + "type": "function_call", + "id": toolState.ID, + "call_id": toolState.CallID, + "name": toolState.Name, + "arguments": toolState.Arguments.String(), + }) + } + } + + return outputItems, storedItems, strings.Join(outputTexts, ""), nil +} + +func (s *responseStreamState) nextSequence() int { + current := s.sequence + s.sequence++ + return current +} + +func writeResponsesSSEEvent(w http.ResponseWriter, eventName string, payload any) error { + encoded, err := json.Marshal(payload) + if err != nil { + return err + } + if _, err := io.WriteString(w, "event: "+eventName+"\n"); err != nil { + return err + } + if _, err := io.WriteString(w, "data: "); err != nil { + return err + } + if _, err := w.Write(encoded); err != nil { + return err + } + _, err = io.WriteString(w, "\n\n") + return err +} + +func readNextSSEDataBlock(reader *bufio.Reader) ([]byte, error) { + var data bytes.Buffer + for { + line, err := reader.ReadBytes('\n') + if err != nil && len(line) == 0 { + if data.Len() > 0 { + return data.Bytes(), nil + } + return nil, err + } + + trimmedLine := bytes.TrimRight(line, "\r\n") + if len(trimmedLine) == 0 { + if data.Len() > 0 { + return data.Bytes(), nil + } + if err != nil { + return nil, err + } + continue + } + + if bytes.HasPrefix(trimmedLine, []byte("data:")) { + payload := bytes.TrimSpace(bytes.TrimPrefix(trimmedLine, []byte("data:"))) + if data.Len() > 0 { + data.WriteByte('\n') + } + data.Write(payload) + } + + if err != nil { + if data.Len() > 0 { + return data.Bytes(), nil + } + return nil, err + } + } +} + +func normalizeResponseOutputValue(value any) string { + switch typed := value.(type) { + case nil: + return "" + case string: + return typed + default: + encoded, err := json.Marshal(value) + if err != nil { + return fmt.Sprint(value) + } + return string(encoded) + } +} + +func normalizeJSONString(value any) string { + switch typed := value.(type) { + case nil: + return "{}" + case string: + trimmed := strings.TrimSpace(typed) + if trimmed == "" { + return "{}" + } + if json.Valid([]byte(trimmed)) { + return trimmed + } + encoded, _ := json.Marshal(trimmed) + return string(encoded) + default: + encoded, err := json.Marshal(value) + if err != nil { + return "{}" + } + return string(encoded) + } +} + +func normalizeResponseID(value string) string { + value = strings.TrimSpace(value) + switch { + case value == "": + return fmt.Sprintf("resp_%d", time.Now().UnixNano()) + case strings.HasPrefix(value, "resp_"): + return value + default: + return "resp_" + value + } +} + +func normalizeResponseMessageID(responseID string) string { + base := strings.TrimPrefix(strings.TrimSpace(responseID), "resp_") + if base == "" { + base = fmt.Sprintf("%d", time.Now().UnixNano()) + } + return "msg_" + base +} + +func responseCallIDFromValue(value string) string { + value = strings.TrimSpace(value) + switch { + case value == "": + return fmt.Sprintf("call_%d", time.Now().UnixNano()) + case strings.HasPrefix(value, "call_"): + return value + case strings.HasPrefix(value, "fc_"): + return "call_" + strings.TrimPrefix(value, "fc_") + default: + return "call_" + value + } +} + +func responseFunctionIDFromCallID(callID string) string { + callID = strings.TrimSpace(callID) + switch { + case callID == "": + return fmt.Sprintf("fc_%d", time.Now().UnixNano()) + case strings.HasPrefix(callID, "fc_"): + return callID + case strings.HasPrefix(callID, "call_"): + return "fc_" + strings.TrimPrefix(callID, "call_") + default: + return "fc_" + callID + } +} + +func cloneResponseItems(items []map[string]any) []map[string]any { + if len(items) == 0 { + return nil + } + out := make([]map[string]any, 0, len(items)) + for _, item := range items { + out = append(out, cloneMap(item)) + } + return out +} + +func cloneGenericSlice(values []any) []any { + if len(values) == 0 { + return nil + } + out := make([]any, len(values)) + for i, value := range values { + switch typed := value.(type) { + case map[string]any: + out[i] = cloneMap(typed) + case []any: + out[i] = cloneGenericSlice(typed) + default: + out[i] = typed + } + } + return out +} diff --git a/responses_test.go b/responses_test.go new file mode 100644 index 0000000..4f34392 --- /dev/null +++ b/responses_test.go @@ -0,0 +1,138 @@ +package main + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestConvertResponsesCreateRequestToOpenAI(t *testing.T) { + store := newResponseStore() + store.Put("resp_prev", []map[string]any{ + { + "type": "message", + "role": "assistant", + "content": []any{ + map[string]any{"type": "output_text", "text": "Previous answer"}, + }, + }, + }) + + body := []byte(`{ + "model": "z-ai/glm-5.1", + "previous_response_id": "resp_prev", + "instructions": "Be concise", + "input": [ + {"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}, + {"type":"function_call_output","call_id":"call_123","output":"done"} + ], + "tools": [{"type":"function","name":"shell_command","description":"Run shell","parameters":{"type":"object","properties":{"command":{"type":"string"}}}}], + "tool_choice": "auto", + "stream": true, + "max_output_tokens": 123 + }`) + + payload, model, stream, conversation, err := convertResponsesCreateRequestToOpenAI(body, store) + if err != nil { + t.Fatalf("convertResponsesCreateRequestToOpenAI returned error: %v", err) + } + + if model != "z-ai/glm-5.1" { + t.Fatalf("unexpected model: %s", model) + } + if !stream { + t.Fatalf("expected stream=true") + } + if payload["max_tokens"] != 123 { + t.Fatalf("expected max_tokens=123, got %#v", payload["max_tokens"]) + } + + messages := payload["messages"].([]any) + if len(messages) != 4 { + t.Fatalf("expected 4 messages, got %d", len(messages)) + } + + systemMessage := messages[1].(map[string]any) + if systemMessage["role"] != "system" { + t.Fatalf("expected system message at index 1, got %#v", systemMessage["role"]) + } + + if len(conversation) != 4 { + t.Fatalf("expected 4 conversation items, got %d", len(conversation)) + } + + tools := payload["tools"].([]any) + tool := tools[0].(map[string]any) + function := tool["function"].(map[string]any) + if function["name"] != "shell_command" { + t.Fatalf("expected tool name shell_command, got %#v", function["name"]) + } +} + +func TestConvertResponsesCreateRequestToOpenAISupportsDeveloperRole(t *testing.T) { + store := newResponseStore() + body := []byte(`{ + "model": "z-ai/glm-5.1", + "input": [ + {"type":"message","role":"developer","content":[{"type":"input_text","text":"Always answer with OK."}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]} + ] + }`) + + payload, _, _, conversation, err := convertResponsesCreateRequestToOpenAI(body, store) + if err != nil { + t.Fatalf("convertResponsesCreateRequestToOpenAI returned error: %v", err) + } + + if got := conversation[0]["role"]; got != "developer" { + t.Fatalf("expected stored role developer, got %#v", got) + } + + messages := payload["messages"].([]any) + developerMessage := messages[0].(map[string]any) + if got := developerMessage["role"]; got != "system" { + t.Fatalf("expected developer role to map to system upstream, got %#v", got) + } +} + +func TestWriteResponsesStream(t *testing.T) { + upstreamStream := strings.Join([]string{ + "data: {\"id\":\"chatcmpl_1\",\"model\":\"z-ai/glm-5.1\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"Hello\"}}]}", + "", + "data: {\"id\":\"chatcmpl_1\",\"model\":\"z-ai/glm-5.1\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"shell_command\",\"arguments\":\"{\\\"command\\\":\\\"pwd\\\"}\"}}]}}]}", + "", + "data: [DONE]", + "", + }, "\n") + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(upstreamStream)), + } + + recorder := httptest.NewRecorder() + store := newResponseStore() + if err := writeResponsesStream(recorder, resp, "z-ai/glm-5.1", nil, store); err != nil { + t.Fatalf("writeResponsesStream returned error: %v", err) + } + + body := recorder.Body.String() + expectedFragments := []string{ + "event: response.created", + "event: response.output_item.added", + "event: response.content_part.added", + "event: response.output_text.delta", + "event: response.function_call_arguments.delta", + "event: response.output_item.done", + "event: response.completed", + "\"output_text\":\"Hello\"", + "\"name\":\"shell_command\"", + } + for _, fragment := range expectedFragments { + if !strings.Contains(body, fragment) { + t.Fatalf("expected stream body to contain %q, body=%s", fragment, body) + } + } +} diff --git a/run_manager.go b/run_manager.go index 97361fc..2217f29 100644 --- a/run_manager.go +++ b/run_manager.go @@ -5,17 +5,30 @@ import ( "errors" "fmt" "log" + "sort" "strings" "sync" "sync/atomic" "time" ) +const ( + warmPoolRecentWindow = 45 * time.Minute + warmPoolMaxModels = 2 + warmPoolMinSwitchAge = 10 * time.Minute +) + type RunManager struct { - cfg Config logger *log.Logger - pools []*tokenPool - next atomic.Uint64 + client *UpstreamClient + + mu sync.RWMutex + cfg Config + pools []*tokenPool + recentModelDemand map[string]modelDemand + + next atomic.Uint64 + warm atomic.Bool stopCh chan struct{} wg sync.WaitGroup @@ -29,17 +42,20 @@ type tokenPool struct { logger *log.Logger mu sync.Mutex - runs map[string]*managedRun // agentID -> current run + runs map[string]*managedRun draining []*managedRun session *cachedSession sessionRefreshCh chan struct{} lastError string cooldownUntil time.Time + disabled bool + lastModelSwitch time.Time } type managedRun struct { id string agentID string + model string startedAt time.Time inflight int requestCount int @@ -55,6 +71,7 @@ type tokenSnapshot struct { Name string `json:"name"` Runs []runSnapshot `json:"runs"` DrainingRuns int `json:"draining_runs"` + SessionModel string `json:"session_model,omitempty"` SessionStatus string `json:"session_status,omitempty"` SessionInstanceID string `json:"session_instance_id,omitempty"` SessionExpiresAt time.Time `json:"session_expires_at,omitempty"` @@ -63,10 +80,13 @@ type tokenSnapshot struct { SessionPollAt time.Time `json:"session_poll_at,omitempty"` CooldownUntil time.Time `json:"cooldown_until,omitempty"` LastError string `json:"last_error,omitempty"` + Disabled bool `json:"disabled,omitempty"` + State string `json:"state"` } type runSnapshot struct { AgentID string `json:"agent_id"` + Model string `json:"model,omitempty"` RunID string `json:"run_id"` StartedAt time.Time `json:"started_at"` Inflight int `json:"inflight"` @@ -80,6 +100,11 @@ type waitingRoomError struct { RetryAfter time.Duration } +type modelDemand struct { + Count int + LastRequested time.Time +} + func (e *waitingRoomError) Error() string { if e == nil { return "freebuff waiting room queued" @@ -103,30 +128,18 @@ func (e *waitingRoomError) Error() string { } func NewRunManager(cfg Config, client *UpstreamClient, logger *log.Logger) *RunManager { - pools := make([]*tokenPool, 0, len(cfg.AuthTokens)) - for index, token := range cfg.AuthTokens { - pools = append(pools, &tokenPool{ - name: fmt.Sprintf("token-%d", index+1), - token: token, - cfg: cfg, - client: client, - runs: make(map[string]*managedRun), - logger: logger, - }) - } - - return &RunManager{ - cfg: cfg, - logger: logger, - pools: pools, - stopCh: make(chan struct{}), - } + manager := &RunManager{ + cfg: cfg, + logger: logger, + client: client, + stopCh: make(chan struct{}), + recentModelDemand: make(map[string]modelDemand), + } + manager.pools = manager.buildPools(cfg, nil) + return manager } func (m *RunManager) Start(ctx context.Context, agentIDs []string) { - // Pre-warm runs for all free agents in background. - // The server is already listening; if a request arrives before - // pre-warming finishes, acquire() will lazily create the run. go m.prewarm(agentIDs) m.wg.Add(1) @@ -138,12 +151,19 @@ func (m *RunManager) Start(ctx context.Context, agentIDs []string) { for { select { case <-ticker.C: - maintainCtx, cancel := context.WithTimeout(context.Background(), m.cfg.RequestTimeout) - for _, pool := range m.pools { + cfg := m.currentConfig() + maintainCtx, cancel := context.WithTimeout(context.Background(), cfg.RequestTimeout) + for _, pool := range m.snapshotPools() { if err := pool.maintain(maintainCtx); err != nil { m.logger.Printf("%s: maintenance failed: %v", pool.name, err) } } + if m.warm.CompareAndSwap(false, true) { + if err := m.maintainWarmPool(maintainCtx); err != nil { + m.logger.Printf("warm pool maintenance failed: %v", err) + } + m.warm.Store(false) + } cancel() case <-m.stopCh: return @@ -152,16 +172,66 @@ func (m *RunManager) Start(ctx context.Context, agentIDs []string) { }() } +func (m *RunManager) ApplyConfig(cfg Config) { + m.mu.Lock() + existing := make(map[string]*tokenPool, len(m.pools)) + for _, pool := range m.pools { + pool.cfg = cfg + existing[pool.token] = pool + } + + m.cfg = cfg + m.pools = m.buildPools(cfg, existing) + removed := make([]*tokenPool, 0, len(existing)) + for _, pool := range existing { + removed = append(removed, pool) + } + m.mu.Unlock() + + for _, pool := range removed { + go func(pool *tokenPool) { + ctx, cancel := context.WithTimeout(context.Background(), cfg.RequestTimeout) + defer cancel() + if err := pool.shutdown(ctx); err != nil { + m.logger.Printf("%s: shutdown removed token failed: %v", pool.name, err) + } + }(pool) + } +} + +func (m *RunManager) buildPools(cfg Config, existing map[string]*tokenPool) []*tokenPool { + pools := make([]*tokenPool, 0, len(cfg.AuthTokens)) + for index, token := range cfg.AuthTokens { + if pool := existing[token]; pool != nil { + pool.name = fmt.Sprintf("token-%d", index+1) + pool.cfg = cfg + pools = append(pools, pool) + delete(existing, token) + continue + } + pools = append(pools, &tokenPool{ + name: fmt.Sprintf("token-%d", index+1), + token: token, + cfg: cfg, + client: m.client, + runs: make(map[string]*managedRun), + logger: m.logger, + }) + } + return pools +} + func (m *RunManager) prewarm(agentIDs []string) { - ctx, cancel := context.WithTimeout(context.Background(), m.cfg.RequestTimeout) + cfg := m.currentConfig() + ctx, cancel := context.WithTimeout(context.Background(), cfg.RequestTimeout) defer cancel() - for _, pool := range m.pools { - if _, err := pool.ensureSession(ctx); err != nil { - m.logger.Printf("%s: free session prewarm failed: %v", pool.name, err) + for _, pool := range m.snapshotPools() { + if pool.isDisabled() { + continue } for _, agentID := range agentIDs { - if err := pool.rotateAgent(ctx, agentID); err != nil { + if err := pool.rotateAgent(ctx, agentID, ""); err != nil { m.logger.Printf("%s: prewarm %s failed: %v", pool.name, agentID, err) } else { m.logger.Printf("%s: prewarmed %s", pool.name, agentID) @@ -173,24 +243,29 @@ func (m *RunManager) prewarm(agentIDs []string) { func (m *RunManager) Close(ctx context.Context) { close(m.stopCh) m.wg.Wait() - for _, pool := range m.pools { + for _, pool := range m.snapshotPools() { if err := pool.shutdown(ctx); err != nil { m.logger.Printf("%s: shutdown failed: %v", pool.name, err) } } } -func (m *RunManager) Acquire(ctx context.Context, agentID string) (*runLease, error) { - if len(m.pools) == 0 { +func (m *RunManager) Acquire(ctx context.Context, agentID, model string) (*runLease, error) { + m.noteModelRequest(model) + m.kickWarmPool() + + pools := m.snapshotPools() + if len(pools) == 0 { return nil, errors.New("no auth tokens configured") } - startIndex := int(m.next.Add(1)-1) % len(m.pools) + startIndex := int(m.next.Add(1)-1) % len(pools) var errs []string var waiting []*waitingRoomError - for offset := 0; offset < len(m.pools); offset++ { - pool := m.pools[(startIndex+offset)%len(m.pools)] - lease, err := pool.acquire(ctx, agentID) + var switching []*modelSwitchError + for offset := 0; offset < len(pools); offset++ { + pool := pools[(startIndex+offset)%len(pools)] + lease, err := pool.acquire(ctx, agentID, model) if err == nil { return lease, nil } @@ -198,10 +273,14 @@ func (m *RunManager) Acquire(ctx context.Context, agentID string) (*runLease, er if errors.As(err, &waitingErr) { waiting = append(waiting, waitingErr) } + var switchErr *modelSwitchError + if errors.As(err, &switchErr) { + switching = append(switching, switchErr) + } errs = append(errs, fmt.Sprintf("%s: %v", pool.name, err)) } - if len(waiting) == len(m.pools) && len(waiting) > 0 { + if len(waiting) == len(pools) && len(waiting) > 0 { best := waiting[0] for _, candidate := range waiting[1:] { if candidate != nil && (best == nil || (candidate.Position > 0 && candidate.Position < best.Position)) { @@ -213,6 +292,18 @@ func (m *RunManager) Acquire(ctx context.Context, agentID string) (*runLease, er } } + if len(switching) == len(pools) && len(switching) > 0 { + best := switching[0] + for _, candidate := range switching[1:] { + if candidate != nil && candidate.RetryAfter > 0 && (best == nil || best.RetryAfter <= 0 || candidate.RetryAfter < best.RetryAfter) { + best = candidate + } + } + if best != nil { + return nil, best + } + } + return nil, fmt.Errorf("unable to acquire run from any token (%s)", strings.Join(errs, "; ")) } @@ -238,14 +329,229 @@ func (m *RunManager) Cooldown(lease *runLease, duration time.Duration, reason st } func (m *RunManager) Snapshots() []tokenSnapshot { - snapshots := make([]tokenSnapshot, 0, len(m.pools)) - for _, pool := range m.pools { + pools := m.snapshotPools() + snapshots := make([]tokenSnapshot, 0, len(pools)) + for _, pool := range pools { snapshots = append(snapshots, pool.snapshot()) } return snapshots } -func (p *tokenPool) acquire(ctx context.Context, agentID string) (*runLease, error) { +func (m *RunManager) currentConfig() Config { + m.mu.RLock() + defer m.mu.RUnlock() + return m.cfg +} + +func (m *RunManager) snapshotPools() []*tokenPool { + m.mu.RLock() + defer m.mu.RUnlock() + pools := make([]*tokenPool, len(m.pools)) + copy(pools, m.pools) + return pools +} + +func (m *RunManager) noteModelRequest(model string) { + model = strings.TrimSpace(model) + if model == "" { + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + now := time.Now() + for existingModel, demand := range m.recentModelDemand { + if now.Sub(demand.LastRequested) > warmPoolRecentWindow { + delete(m.recentModelDemand, existingModel) + } + } + + demand := m.recentModelDemand[model] + demand.Count++ + demand.LastRequested = now + m.recentModelDemand[model] = demand +} + +func (m *RunManager) hotModels(limit int) []string { + if limit <= 0 { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + now := time.Now() + type scoredModel struct { + Name string + Count int + LastRequested time.Time + } + scored := make([]scoredModel, 0, len(m.recentModelDemand)) + for model, demand := range m.recentModelDemand { + if now.Sub(demand.LastRequested) > warmPoolRecentWindow { + delete(m.recentModelDemand, model) + continue + } + scored = append(scored, scoredModel{ + Name: model, + Count: demand.Count, + LastRequested: demand.LastRequested, + }) + } + + sort.Slice(scored, func(i, j int) bool { + if scored[i].Count != scored[j].Count { + return scored[i].Count > scored[j].Count + } + return scored[i].LastRequested.After(scored[j].LastRequested) + }) + + if len(scored) > limit { + scored = scored[:limit] + } + + models := make([]string, 0, len(scored)) + for _, item := range scored { + models = append(models, item.Name) + } + return models +} + +func (m *RunManager) kickWarmPool() { + if !m.warm.CompareAndSwap(false, true) { + return + } + + go func() { + defer m.warm.Store(false) + cfg := m.currentConfig() + ctx, cancel := context.WithTimeout(context.Background(), cfg.RequestTimeout) + defer cancel() + if err := m.maintainWarmPool(ctx); err != nil { + m.logger.Printf("warm pool maintenance failed: %v", err) + } + }() +} + +func (m *RunManager) maintainWarmPool(ctx context.Context) error { + hotModels := m.hotModels(warmPoolMaxModels) + if len(hotModels) == 0 { + return nil + } + + pools := m.snapshotPools() + if len(pools) == 0 { + return nil + } + + eligiblePools := make([]*tokenPool, 0, len(pools)) + for _, pool := range pools { + if !pool.isDisabled() { + eligiblePools = append(eligiblePools, pool) + } + } + if len(eligiblePools) == 0 { + return nil + } + + desired := desiredWarmCounts(len(eligiblePools), hotModels) + currentCounts := make(map[string]int, len(desired)) + currentModel := make(map[*tokenPool]string, len(eligiblePools)) + excessPools := make([]*tokenPool, 0, len(eligiblePools)) + + for _, pool := range eligiblePools { + model := strings.TrimSpace(pool.currentSessionModel()) + currentModel[pool] = model + if _, tracked := desired[model]; tracked { + currentCounts[model]++ + continue + } + excessPools = append(excessPools, pool) + } + + for _, model := range hotModels { + for currentCounts[model] > desired[model] { + for _, pool := range eligiblePools { + if strings.TrimSpace(currentModel[pool]) != model { + continue + } + excessPools = append(excessPools, pool) + currentModel[pool] = "" + currentCounts[model]-- + break + } + } + } + + used := make(map[*tokenPool]struct{}, len(excessPools)) + for _, model := range hotModels { + for currentCounts[model] < desired[model] { + selected := (*tokenPool)(nil) + for _, candidate := range excessPools { + if _, ok := used[candidate]; ok { + continue + } + selected = candidate + break + } + if selected == nil { + return nil + } + + if err := selected.warmModel(ctx, model); err != nil { + m.logger.Printf("%s: warm %s failed: %v", selected.name, model, err) + used[selected] = struct{}{} + continue + } + used[selected] = struct{}{} + currentCounts[model]++ + } + } + + return nil +} + +func desiredWarmCounts(totalPools int, models []string) map[string]int { + desired := make(map[string]int, len(models)) + if totalPools <= 0 || len(models) == 0 { + return desired + } + + activeModels := models + if len(activeModels) > totalPools { + activeModels = activeModels[:totalPools] + } + + for _, model := range activeModels { + desired[model] = 1 + } + + remaining := totalPools - len(activeModels) + for i := 0; i < remaining; i++ { + model := activeModels[i%len(activeModels)] + desired[model]++ + } + + return desired +} + +func (p *tokenPool) acquire(ctx context.Context, agentID, model string) (*runLease, error) { + p.mu.Lock() + if p.disabled { + lastError := p.lastError + p.mu.Unlock() + if lastError == "" { + lastError = "token disabled" + } + return nil, errors.New(lastError) + } + p.mu.Unlock() + + if err := p.prepareModel(ctx, model); err != nil { + return nil, err + } + p.mu.Lock() if now := time.Now(); now.Before(p.cooldownUntil) { cooldownUntil := p.cooldownUntil @@ -253,16 +559,16 @@ func (p *tokenPool) acquire(ctx context.Context, agentID string) (*runLease, err return nil, fmt.Errorf("token cooling down until %s", cooldownUntil.Format(time.RFC3339)) } run := p.runs[agentID] - needsRotate := run == nil || time.Since(run.startedAt) >= p.cfg.RotationInterval + needsRotate := run == nil || run.model != model || time.Since(run.startedAt) >= p.cfg.RotationInterval p.mu.Unlock() if needsRotate { - if err := p.rotateAgent(ctx, agentID); err != nil { + if err := p.rotateAgent(ctx, agentID, model); err != nil { return nil, err } } - if _, err := p.ensureSession(ctx); err != nil { + if _, err := p.ensureSession(ctx, model); err != nil { return nil, err } @@ -277,9 +583,54 @@ func (p *tokenPool) acquire(ctx context.Context, agentID string) (*runLease, err return &runLease{pool: p, run: run}, nil } +func (p *tokenPool) warmModel(ctx context.Context, model string) error { + model = strings.TrimSpace(model) + if model == "" || p.isDisabled() { + return nil + } + currentModel := strings.TrimSpace(p.currentSessionModel()) + if p.shouldDeferWarmSwitch(model) { + return nil + } + + if currentModel == "" || currentModel == model { + if _, err := p.ensureSession(ctx, model); err == nil { + return nil + } else { + var waitingErr *waitingRoomError + if errors.As(err, &waitingErr) && p.currentSessionModel() == model { + return nil + } + } + } + + if err := p.prepareModel(ctx, model); err != nil { + return err + } + + _, err := p.ensureSession(ctx, model) + if err == nil { + return nil + } + + var waitingErr *waitingRoomError + if errors.As(err, &waitingErr) { + return nil + } + return err +} + func (p *tokenPool) maintain(ctx context.Context) error { - if _, err := p.ensureSession(ctx); err != nil { - p.logger.Printf("%s: refresh free session failed: %v", p.name, err) + if p.isDisabled() { + return nil + } + if model := p.currentSessionModel(); model != "" { + if _, err := p.ensureSession(ctx, model); err != nil { + p.logger.Printf("%s: refresh free session failed: %v", p.name, err) + } + } + if p.isDisabled() { + return nil } p.mu.Lock() @@ -293,7 +644,13 @@ func (p *tokenPool) maintain(ctx context.Context) error { p.mu.Unlock() for _, agentID := range toRotate { - if err := p.rotateAgent(ctx, agentID); err != nil { + model := "" + p.mu.Lock() + if run := p.runs[agentID]; run != nil { + model = run.model + } + p.mu.Unlock() + if err := p.rotateAgent(ctx, agentID, model); err != nil { p.logger.Printf("%s: rotate agent %s failed: %v", p.name, agentID, err) } } @@ -332,8 +689,16 @@ func (p *tokenPool) shutdown(ctx context.Context) error { return nil } -func (p *tokenPool) rotateAgent(ctx context.Context, agentID string) error { +func (p *tokenPool) rotateAgent(ctx context.Context, agentID, model string) error { p.mu.Lock() + if p.disabled { + lastError := p.lastError + p.mu.Unlock() + if lastError == "" { + lastError = "token disabled" + } + return errors.New(lastError) + } if now := time.Now(); now.Before(p.cooldownUntil) { cooldownUntil := p.cooldownUntil p.mu.Unlock() @@ -343,6 +708,10 @@ func (p *tokenPool) rotateAgent(ctx context.Context, agentID string) error { runID, err := p.client.StartRun(ctx, p.token, agentID) if err != nil { + if isBannedErrorMessage(err.Error()) { + p.disable("upstream token banned") + return err + } p.mu.Lock() p.lastError = err.Error() p.mu.Unlock() @@ -354,6 +723,7 @@ func (p *tokenPool) rotateAgent(ctx context.Context, agentID string) error { p.runs[agentID] = &managedRun{ id: runID, agentID: agentID, + model: model, startedAt: time.Now(), } p.lastError = "" @@ -394,15 +764,15 @@ func (p *tokenPool) finishIfReady(run *managedRun) error { p.mu.Unlock() return nil } - // Only finish if this run is no longer the current run for its agent if current, ok := p.runs[run.agentID]; ok && current == run { p.mu.Unlock() return nil } run.finishing = true + timeout := p.cfg.RequestTimeout p.mu.Unlock() - ctx, cancel := context.WithTimeout(context.Background(), p.cfg.RequestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() if err := p.client.FinishRun(ctx, p.token, run.id, run.requestCount); err != nil { @@ -429,7 +799,6 @@ func (p *tokenPool) invalidate(run *managedRun, reason string) { p.mu.Lock() defer p.mu.Unlock() - // Remove from current runs if it matches if current, ok := p.runs[run.agentID]; ok && current == run { delete(p.runs, run.agentID) } @@ -458,6 +827,44 @@ func (p *tokenPool) markCooldown(duration time.Duration, reason string) { } } +func (p *tokenPool) shouldDeferWarmSwitch(targetModel string) bool { + targetModel = strings.TrimSpace(targetModel) + if targetModel == "" { + return false + } + + p.mu.Lock() + defer p.mu.Unlock() + + if p.disabled { + return true + } + if now := time.Now(); now.Before(p.cooldownUntil) { + return true + } + + currentModel := "" + if p.session != nil { + currentModel = strings.TrimSpace(p.session.model) + } + if currentModel == "" { + for _, run := range p.runs { + if strings.TrimSpace(run.model) != "" { + currentModel = strings.TrimSpace(run.model) + break + } + } + } + if currentModel == "" || currentModel == targetModel { + return false + } + + if !p.lastModelSwitch.IsZero() && time.Since(p.lastModelSwitch) < warmPoolMinSwitchAge { + return true + } + return false +} + func (p *tokenPool) snapshot() tokenSnapshot { p.mu.Lock() defer p.mu.Unlock() @@ -467,8 +874,10 @@ func (p *tokenPool) snapshot() tokenSnapshot { DrainingRuns: len(p.draining), CooldownUntil: p.cooldownUntil, LastError: p.lastError, + Disabled: p.disabled, } if p.session != nil { + snapshot.SessionModel = p.session.model snapshot.SessionStatus = string(p.session.status) snapshot.SessionInstanceID = p.session.instanceID snapshot.SessionExpiresAt = p.session.expiresAt @@ -479,11 +888,55 @@ func (p *tokenPool) snapshot() tokenSnapshot { for agentID, run := range p.runs { snapshot.Runs = append(snapshot.Runs, runSnapshot{ AgentID: agentID, + Model: run.model, RunID: run.id, StartedAt: run.startedAt, Inflight: run.inflight, RequestCount: run.requestCount, }) } + snapshot.State = classifyTokenState(snapshot) return snapshot } + +func (p *tokenPool) disable(reason string) { + p.mu.Lock() + defer p.mu.Unlock() + p.disabled = true + p.session = nil + p.cooldownUntil = time.Time{} + if reason != "" { + p.lastError = reason + } +} + +func (p *tokenPool) isDisabled() bool { + p.mu.Lock() + defer p.mu.Unlock() + return p.disabled +} + +func classifyTokenState(snapshot tokenSnapshot) string { + now := time.Now() + switch { + case snapshot.Disabled && strings.Contains(strings.ToLower(snapshot.LastError), "banned"): + return "banned" + case snapshot.Disabled: + return "disabled" + case !snapshot.CooldownUntil.IsZero() && now.Before(snapshot.CooldownUntil): + return "cooling_down" + case snapshot.SessionStatus == string(sessionStatusQueued): + return "queued" + case snapshot.SessionStatus == string(sessionStatusActive): + return "active" + case snapshot.SessionStatus == string(sessionStatusDisabled): + return "disabled" + default: + return "idle" + } +} + +func isBannedErrorMessage(message string) bool { + message = strings.ToLower(strings.TrimSpace(message)) + return strings.Contains(message, `"status":"banned"`) || strings.Contains(message, `"status": "banned"`) || strings.Contains(message, "status\":\"banned") +} diff --git a/server.go b/server.go index 9fad1fc..baf347d 100644 --- a/server.go +++ b/server.go @@ -14,33 +14,56 @@ import ( ) type Server struct { - cfg Config - logger *log.Logger - client *UpstreamClient - runs *RunManager - registry *ModelRegistry - started time.Time + cfgStore *ConfigStore + logger *log.Logger + client *UpstreamClient + runs *RunManager + registry *ModelRegistry + responses *responseStore + started time.Time +} + +type proxyErrorResponse struct { + StatusCode int + Message string + ErrorType string + Code string + RetryAfter time.Duration } func NewServer(cfg Config, logger *log.Logger, registry *ModelRegistry) *Server { - client := NewUpstreamClient(cfg) + cfgStore := NewConfigStore(cfg) + client := NewUpstreamClient(cfgStore) runManager := NewRunManager(cfg, client, logger) return &Server{ - cfg: cfg, - logger: logger, - client: client, - runs: runManager, - registry: registry, - started: time.Now(), + cfgStore: cfgStore, + logger: logger, + client: client, + runs: runManager, + registry: registry, + responses: newResponseStore(), + started: time.Now(), } } +func (s *Server) ApplyConfig(cfg Config) { + current := s.cfgStore.Current() + if current.ListenAddr != "" && cfg.ListenAddr != current.ListenAddr { + s.logger.Printf("LISTEN_ADDR changed from %s to %s but requires restart; keeping current listener", current.ListenAddr, cfg.ListenAddr) + cfg.ListenAddr = current.ListenAddr + } + s.cfgStore.Update(cfg) + s.runs.ApplyConfig(cfg) +} + func (s *Server) Handler() http.Handler { mux := http.NewServeMux() mux.HandleFunc("/healthz", s.handleHealthz) + mux.HandleFunc("/status", s.handleStatus) mux.HandleFunc("/v1/models", s.handleModels) mux.HandleFunc("/v1/chat/completions", s.handleChatCompletions) + mux.HandleFunc("/v1/responses", s.handleResponses) mux.HandleFunc("/v1/messages", s.handleClaudeMessages) mux.HandleFunc("/v1/messages/count_tokens", s.handleClaudeCountTokens) return s.withMiddleware(mux) @@ -56,11 +79,12 @@ func (s *Server) Shutdown(ctx context.Context) { func (s *Server) withMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if len(s.cfg.APIKeys) > 0 && !s.authorized(r) { + cfg := s.cfgStore.Current() + if len(cfg.APIKeys) > 0 && !s.authorized(r, cfg.APIKeys) { if isClaudeRequestPath(r.URL.Path) { - writeClaudeError(w, http.StatusUnauthorized, "invalid proxy api key", "authentication_error") + writeClaudeErrorDetailed(w, http.StatusUnauthorized, "invalid proxy api key", "authentication_error", "invalid_api_key") } else { - writeOpenAIError(w, http.StatusUnauthorized, "invalid proxy api key", "authentication_error", "") + writeOpenAIError(w, http.StatusUnauthorized, "invalid proxy api key", "authentication_error", "invalid_api_key") } return } @@ -68,9 +92,9 @@ func (s *Server) withMiddleware(next http.Handler) http.Handler { }) } -func (s *Server) authorized(r *http.Request) bool { +func (s *Server) authorized(r *http.Request, apiKeys []string) bool { if apiKey := strings.TrimSpace(r.Header.Get("x-api-key")); apiKey != "" { - if containsString(s.cfg.APIKeys, apiKey) { + if containsString(apiKeys, apiKey) { return true } } @@ -84,7 +108,7 @@ func (s *Server) authorized(r *http.Request) bool { return false } apiKey := strings.TrimSpace(strings.TrimPrefix(authorization, prefix)) - return containsString(s.cfg.APIKeys, apiKey) + return containsString(apiKeys, apiKey) } func isClaudeRequestPath(path string) bool { @@ -98,14 +122,45 @@ func (s *Server) handleHealthz(w http.ResponseWriter, r *http.Request) { } response := map[string]any{ - "ok": true, - "started_at": s.started.UTC(), - "uptime_sec": int(time.Since(s.started).Seconds()), - "token_state": s.runs.Snapshots(), + "ok": true, + "started_at": s.started.UTC(), + "uptime_sec": int(time.Since(s.started).Seconds()), + "summary": summarizeTokenSnapshots(s.runs.Snapshots()), } writeJSON(w, http.StatusOK, response) } +func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeOpenAIError(w, http.StatusMethodNotAllowed, "method not allowed", "invalid_request_error", "") + return + } + + cfg := s.cfgStore.Current() + snapshots := s.runs.Snapshots() + writeJSON(w, http.StatusOK, map[string]any{ + "ok": true, + "started_at": s.started.UTC(), + "uptime_sec": int(time.Since(s.started).Seconds()), + "summary": summarizeTokenSnapshots(snapshots), + "available_models": s.registry.Models(), + "token_state": snapshots, + "config": map[string]any{ + "listen_addr": cfg.ListenAddr, + "upstream_base_url": cfg.UpstreamBaseURL, + "rotation_interval": cfg.RotationInterval.String(), + "request_timeout": cfg.RequestTimeout.String(), + "auth_token_count": len(cfg.AuthTokens), + "api_key_count": len(cfg.APIKeys), + "config_path": cfg.ConfigPath, + "config_format": cfg.ConfigFormat, + "auth_token_dir": cfg.AuthTokenDir, + "loaded_at": cfg.LoadedAt, + "hot_reload_enabled": cfg.ConfigPath != "" || cfg.AuthTokenDir != "", + }, + }) +} + func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { writeOpenAIError(w, http.StatusMethodNotAllowed, "method not allowed", "invalid_request_error", "") @@ -170,6 +225,44 @@ func (s *Server) handleChatCompletions(w http.ResponseWriter, r *http.Request) { ) } +func (s *Server) handleResponses(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeOpenAIError(w, http.StatusMethodNotAllowed, "method not allowed", "invalid_request_error", "") + return + } + + requestBody, err := io.ReadAll(r.Body) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, "failed to read request body", "invalid_request_error", "") + return + } + + payload, requestedModel, stream, conversation, err := convertResponsesCreateRequestToOpenAI(requestBody, s.responses) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", "") + return + } + + if !s.registry.HasModel(requestedModel) { + writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("unsupported model %q", requestedModel), "invalid_request_error", "model_not_found") + return + } + + s.proxyChatRequest( + w, + r, + payload, + requestedModel, + "invalid_request_error", + "server_error", + writeOpenAIError, + writePassthroughError, + func(w http.ResponseWriter, resp *http.Response) error { + return writeResponsesSuccessResponse(w, resp, requestedModel, stream, conversation, s.responses) + }, + ) +} + func (s *Server) handleClaudeMessages(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { writeClaudeError(w, http.StatusMethodNotAllowed, "method not allowed", "invalid_request_error") @@ -201,7 +294,7 @@ func (s *Server) handleClaudeMessages(w http.ResponseWriter, r *http.Request) { "invalid_request_error", "api_error", func(w http.ResponseWriter, statusCode int, message, errorType, _ string) { - writeClaudeError(w, statusCode, message, errorType) + writeClaudeErrorDetailed(w, statusCode, message, errorType, "") }, writeClaudePassthroughError, func(w http.ResponseWriter, resp *http.Response) error { @@ -257,6 +350,7 @@ func (s *Server) proxyChatRequest( writeSuccess func(http.ResponseWriter, *http.Response) error, ) { startTime := time.Now() + isClaude := isClaudeRequestPath(r.URL.Path) agentID, ok := s.registry.AgentForModel(requestedModel) if !ok { @@ -264,35 +358,25 @@ func (s *Server) proxyChatRequest( return } - for attempt := 0; attempt < 2; attempt++ { - lease, err := s.runs.Acquire(r.Context(), agentID) + cfg := s.cfgStore.Current() + maxAttempts := len(cfg.AuthTokens) + 1 + if maxAttempts < 2 { + maxAttempts = 2 + } + + for attempt := 0; attempt < maxAttempts; attempt++ { + lease, err := s.runs.Acquire(r.Context(), agentID, requestedModel) if err != nil { - var waitingErr *waitingRoomError - if errors.As(err, &waitingErr) { - if waitingErr.RetryAfter > 0 { - w.Header().Set("Retry-After", fmt.Sprintf("%.0f", waitingErr.RetryAfter.Seconds())) - } - writeError(w, http.StatusServiceUnavailable, waitingErr.Error(), serverErrorType, "waiting_room_queued") - return - } - writeError(w, http.StatusBadGateway, "no healthy upstream auth token available", serverErrorType, "") + s.writeProxyError(w, isClaude, mapAcquireError(err, serverErrorType)) return } s.logger.Printf("[%s] Routing request (model: %s) via run: %s", lease.pool.name, requestedModel, lease.run.id) - sessionInstanceID, err := lease.pool.ensureSession(r.Context()) + sessionInstanceID, err := lease.pool.ensureSession(r.Context(), requestedModel) if err != nil { s.runs.Release(lease) - var waitingErr *waitingRoomError - if errors.As(err, &waitingErr) { - if waitingErr.RetryAfter > 0 { - w.Header().Set("Retry-After", fmt.Sprintf("%.0f", waitingErr.RetryAfter.Seconds())) - } - writeError(w, http.StatusServiceUnavailable, waitingErr.Error(), serverErrorType, "waiting_room_queued") - return - } - writeError(w, http.StatusBadGateway, "failed to acquire upstream free session", serverErrorType, "") + s.writeProxyError(w, isClaude, mapSessionAcquireError(err, serverErrorType)) return } @@ -306,7 +390,13 @@ func (s *Server) proxyChatRequest( resp, errorBody, err := s.client.ChatCompletions(r.Context(), lease.pool.token, upstreamBody) if err != nil { s.runs.Release(lease) - writeError(w, http.StatusBadGateway, err.Error(), serverErrorType, "") + s.logger.Printf("[%s] upstream request failed: %v", lease.pool.name, err) + s.writeProxyError(w, isClaude, proxyErrorResponse{ + StatusCode: http.StatusBadGateway, + Message: "upstream request failed", + ErrorType: serverErrorType, + Code: "upstream_request_failed", + }) return } @@ -320,6 +410,21 @@ func (s *Server) proxyChatRequest( return } + message, _, code := extractUpstreamError(errorBody) + if isBannedErrorMessage(string(errorBody)) { + s.logger.Printf("%s: upstream token banned, disabling token", lease.pool.name) + lease.pool.disable("upstream token banned") + s.runs.Release(lease) + continue + } + if strings.TrimSpace(code) == "session_model_mismatch" { + s.logger.Printf("%s: session model mismatch on run %s, rotating run and refreshing session", lease.pool.name, lease.run.id) + lease.pool.invalidateSession(strings.TrimSpace(message)) + s.runs.Invalidate(lease, strings.TrimSpace(message)) + s.runs.Release(lease) + continue + } + if isSessionInvalid(resp.StatusCode, errorBody) { s.logger.Printf("%s: free session invalid, refreshing and retrying", lease.pool.name) lease.pool.invalidateSession(strings.TrimSpace(string(errorBody))) @@ -341,11 +446,180 @@ func (s *Server) proxyChatRequest( s.runs.Release(lease) s.logger.Printf("[%s] upstream error response: %s", lease.pool.name, string(errorBody)) - writeUpstreamError(w, resp.StatusCode, errorBody) + _ = writeUpstreamError + s.writeProxyError(w, isClaude, mapUpstreamProxyError(resp, errorBody, serverErrorType)) + return + } + + _ = writeError + s.writeProxyError(w, isClaude, proxyErrorResponse{ + StatusCode: http.StatusServiceUnavailable, + Message: "upstream session is still switching models", + ErrorType: serverErrorType, + Code: "session_switch_in_progress", + RetryAfter: 3 * time.Second, + }) +} + +func (s *Server) writeProxyError(w http.ResponseWriter, isClaude bool, response proxyErrorResponse) { + if response.StatusCode == 0 { + response.StatusCode = http.StatusBadGateway + } + if response.ErrorType == "" { + if isClaude { + response.ErrorType = normalizeClaudeErrorType(response.StatusCode, "") + } else { + response.ErrorType = "upstream_error" + } + } + if response.RetryAfter > 0 { + w.Header().Set("Retry-After", fmt.Sprintf("%.0f", maxDuration(response.RetryAfter, time.Second).Seconds())) + } + if isClaude { + writeClaudeErrorDetailed(w, response.StatusCode, response.Message, response.ErrorType, response.Code) return } + writeOpenAIError(w, response.StatusCode, response.Message, response.ErrorType, response.Code) +} + +func mapAcquireError(err error, serverErrorType string) proxyErrorResponse { + var waitingErr *waitingRoomError + if errors.As(err, &waitingErr) { + return proxyErrorResponse{ + StatusCode: http.StatusServiceUnavailable, + Message: "all free sessions are queued in the Freebuff waiting room", + ErrorType: serverErrorType, + Code: "waiting_room_queued", + RetryAfter: maxDuration(waitingErr.RetryAfter, time.Second), + } + } + var switchErr *modelSwitchError + if errors.As(err, &switchErr) { + return proxyErrorResponse{ + StatusCode: http.StatusServiceUnavailable, + Message: "upstream session is still switching models", + ErrorType: serverErrorType, + Code: "session_switch_in_progress", + RetryAfter: maxDuration(switchErr.RetryAfter, time.Second), + } + } + return proxyErrorResponse{ + StatusCode: http.StatusServiceUnavailable, + Message: "no healthy upstream auth token available", + ErrorType: serverErrorType, + Code: "token_pool_unavailable", + RetryAfter: 5 * time.Second, + } +} + +func mapSessionAcquireError(err error, serverErrorType string) proxyErrorResponse { + var waitingErr *waitingRoomError + if errors.As(err, &waitingErr) { + return proxyErrorResponse{ + StatusCode: http.StatusServiceUnavailable, + Message: "all free sessions are queued in the Freebuff waiting room", + ErrorType: serverErrorType, + Code: "waiting_room_queued", + RetryAfter: maxDuration(waitingErr.RetryAfter, time.Second), + } + } + var switchErr *modelSwitchError + if errors.As(err, &switchErr) { + return proxyErrorResponse{ + StatusCode: http.StatusServiceUnavailable, + Message: "upstream session is still switching models", + ErrorType: serverErrorType, + Code: "session_switch_in_progress", + RetryAfter: maxDuration(switchErr.RetryAfter, time.Second), + } + } + return proxyErrorResponse{ + StatusCode: http.StatusBadGateway, + Message: "failed to acquire upstream free session", + ErrorType: serverErrorType, + Code: "token_pool_unavailable", + } +} + +func mapUpstreamProxyError(resp *http.Response, errorBody []byte, serverErrorType string) proxyErrorResponse { + message, _, code := extractUpstreamError(errorBody) + switch { + case resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden: + return proxyErrorResponse{ + StatusCode: http.StatusServiceUnavailable, + Message: "upstream auth token rejected the request", + ErrorType: serverErrorType, + Code: "upstream_auth_rejected", + RetryAfter: 30 * time.Minute, + } + case resp.StatusCode == http.StatusTooManyRequests: + return proxyErrorResponse{ + StatusCode: http.StatusServiceUnavailable, + Message: "upstream rate limit reached", + ErrorType: serverErrorType, + Code: "upstream_rate_limited", + RetryAfter: maxDuration(retryAfterDuration(resp.Header.Get("Retry-After")), 30*time.Second), + } + case strings.TrimSpace(code) == "waiting_room_queued": + return proxyErrorResponse{ + StatusCode: http.StatusServiceUnavailable, + Message: "all free sessions are queued in the Freebuff waiting room", + ErrorType: serverErrorType, + Code: "waiting_room_queued", + RetryAfter: maxDuration(retryAfterDuration(resp.Header.Get("Retry-After")), 5*time.Second), + } + case strings.TrimSpace(code) == "session_model_mismatch": + return proxyErrorResponse{ + StatusCode: http.StatusServiceUnavailable, + Message: "upstream session is still switching models", + ErrorType: serverErrorType, + Code: "session_switch_in_progress", + RetryAfter: 3 * time.Second, + } + default: + if strings.TrimSpace(message) == "" { + message = "upstream request failed" + } else { + message = "upstream request failed" + } + return proxyErrorResponse{ + StatusCode: http.StatusBadGateway, + Message: message, + ErrorType: serverErrorType, + Code: "upstream_request_failed", + } + } +} - writeError(w, http.StatusBadGateway, "upstream run expired twice in a row", serverErrorType, "") +func summarizeTokenSnapshots(snapshots []tokenSnapshot) map[string]any { + summary := map[string]any{ + "total_tokens": len(snapshots), + "active": 0, + "queued": 0, + "disabled": 0, + "banned": 0, + "cooling_down": 0, + "idle": 0, + "healthy": 0, + "service_ready": false, + } + + for _, snapshot := range snapshots { + state := snapshot.State + if state == "" { + state = classifyTokenState(snapshot) + } + if _, ok := summary[state]; ok { + summary[state] = summary[state].(int) + 1 + } + if state == "active" || state == "idle" { + summary["healthy"] = summary["healthy"].(int) + 1 + summary["service_ready"] = true + } + } + + summary["message"] = fmt.Sprintf("%d healthy, %d queued, %d disabled", summary["healthy"], summary["queued"], summary["disabled"].(int)+summary["banned"].(int)) + return summary } func writeOpenAISuccessResponse(w http.ResponseWriter, resp *http.Response) error { @@ -388,14 +662,9 @@ func isSessionInvalid(statusCode int, errorBody []byte) bool { if statusCode < 400 { return false } - var payload struct { - Error string `json:"error"` - } - if err := json.Unmarshal(errorBody, &payload); err != nil { - return false - } - switch strings.TrimSpace(payload.Error) { - case "freebuff_update_required", "waiting_room_required", "waiting_room_queued", "session_superseded", "session_expired": + _, _, code := extractUpstreamError(errorBody) + switch strings.TrimSpace(code) { + case "freebuff_update_required", "waiting_room_required", "waiting_room_queued", "session_superseded", "session_expired", "session_model_mismatch": return true default: return false diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..8b1be07 --- /dev/null +++ b/server_test.go @@ -0,0 +1,60 @@ +package main + +import ( + "errors" + "net/http" + "testing" + "time" +) + +func TestMapAcquireErrorWaitingRoom(t *testing.T) { + response := mapAcquireError(&waitingRoomError{RetryAfter: 7 * time.Second}, "server_error") + if response.Code != "waiting_room_queued" { + t.Fatalf("expected waiting_room_queued, got %q", response.Code) + } + if response.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("expected 503, got %d", response.StatusCode) + } + if response.RetryAfter < 7*time.Second { + t.Fatalf("expected retry-after >= 7s, got %s", response.RetryAfter) + } +} + +func TestMapAcquireErrorSwitchInProgress(t *testing.T) { + response := mapAcquireError(&modelSwitchError{CurrentModel: "z-ai/glm-5.1", TargetModel: "minimax/minimax-m2.7", RetryAfter: 2 * time.Second}, "server_error") + if response.Code != "session_switch_in_progress" { + t.Fatalf("expected session_switch_in_progress, got %q", response.Code) + } + if response.RetryAfter < time.Second { + t.Fatalf("expected retry-after >= 1s, got %s", response.RetryAfter) + } +} + +func TestMapAcquireErrorFallsBackToTokenPoolUnavailable(t *testing.T) { + response := mapAcquireError(errors.New("boom"), "server_error") + if response.Code != "token_pool_unavailable" { + t.Fatalf("expected token_pool_unavailable, got %q", response.Code) + } +} + +func TestSummarizeTokenSnapshots(t *testing.T) { + summary := summarizeTokenSnapshots([]tokenSnapshot{ + {State: "active"}, + {State: "queued"}, + {State: "banned"}, + {State: "cooling_down"}, + }) + + if ready, ok := summary["service_ready"].(bool); !ok || !ready { + t.Fatalf("expected service_ready=true, got %#v", summary["service_ready"]) + } + if got := summary["healthy"].(int); got != 1 { + t.Fatalf("expected healthy=1, got %d", got) + } + if got := summary["queued"].(int); got != 1 { + t.Fatalf("expected queued=1, got %d", got) + } + if got := summary["banned"].(int); got != 1 { + t.Fatalf("expected banned=1, got %d", got) + } +} diff --git a/upstream.go b/upstream.go index 37838bc..0ff8c24 100644 --- a/upstream.go +++ b/upstream.go @@ -14,26 +14,38 @@ import ( ) type UpstreamClient struct { - baseURL string + cfgStore *ConfigStore httpClient *http.Client - userAgent string } -func NewUpstreamClient(cfg Config) *UpstreamClient { +type cancelOnCloseReadCloser struct { + io.ReadCloser + cancel context.CancelFunc +} + +func (c *cancelOnCloseReadCloser) Close() error { + err := c.ReadCloser.Close() + if c.cancel != nil { + c.cancel() + } + return err +} + +func NewUpstreamClient(cfgStore *ConfigStore) *UpstreamClient { transport := http.DefaultTransport.(*http.Transport).Clone() - if cfg.HTTPProxy != "" { - if proxyURL, err := url.Parse(cfg.HTTPProxy); err == nil { - transport.Proxy = http.ProxyURL(proxyURL) + transport.Proxy = func(req *http.Request) (*url.URL, error) { + cfg := cfgStore.Current() + if strings.TrimSpace(cfg.HTTPProxy) == "" { + return nil, nil } + return url.Parse(cfg.HTTPProxy) } return &UpstreamClient{ - baseURL: cfg.UpstreamBaseURL, + cfgStore: cfgStore, httpClient: &http.Client{ - Timeout: cfg.RequestTimeout, Transport: transport, }, - userAgent: cfg.UserAgent, } } @@ -123,27 +135,41 @@ func (c *UpstreamClient) ChatCompletions(ctx context.Context, authToken string, } func (c *UpstreamClient) doJSON(ctx context.Context, authToken, path string, body []byte) (*http.Response, error) { - requestURL, err := url.JoinPath(c.baseURL, path) + cfg := c.cfgStore.Current() + requestURL, err := url.JoinPath(cfg.UpstreamBaseURL, path) if err != nil { return nil, fmt.Errorf("build upstream url: %w", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(body)) + requestCtx, cancel := c.requestContext(ctx) + + req, err := http.NewRequestWithContext(requestCtx, http.MethodPost, requestURL, bytes.NewReader(body)) if err != nil { + cancel() return nil, fmt.Errorf("create request: %w", err) } req.Header.Set("Authorization", "Bearer "+authToken) req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") - req.Header.Set("User-Agent", c.userAgent) + req.Header.Set("User-Agent", cfg.UserAgent) resp, err := c.httpClient.Do(req) if err != nil { + cancel() return nil, fmt.Errorf("send upstream request: %w", err) } + resp.Body = &cancelOnCloseReadCloser{ReadCloser: resp.Body, cancel: cancel} return resp, nil } +func (c *UpstreamClient) requestContext(ctx context.Context) (context.Context, context.CancelFunc) { + cfg := c.cfgStore.Current() + if cfg.RequestTimeout <= 0 { + return ctx, func() {} + } + return context.WithTimeout(ctx, cfg.RequestTimeout) +} + func retryAfterDuration(headerValue string) time.Duration { headerValue = strings.TrimSpace(headerValue) if headerValue == "" { diff --git a/warm_pool_test.go b/warm_pool_test.go new file mode 100644 index 0000000..3225e1d --- /dev/null +++ b/warm_pool_test.go @@ -0,0 +1,57 @@ +package main + +import ( + "log" + "testing" + "time" +) + +func TestDesiredWarmCounts(t *testing.T) { + got := desiredWarmCounts(4, []string{"z-ai/glm-5.1", "minimax/minimax-m2.7"}) + if got["z-ai/glm-5.1"] != 2 || got["minimax/minimax-m2.7"] != 2 { + t.Fatalf("unexpected desired counts: %#v", got) + } + + got = desiredWarmCounts(3, []string{"a", "b"}) + if got["a"] != 2 || got["b"] != 1 { + t.Fatalf("unexpected desired counts for uneven pool: %#v", got) + } +} + +func TestHotModelsDropsStaleDemand(t *testing.T) { + manager := &RunManager{ + logger: log.New(ioDiscard{}, "", 0), + recentModelDemand: make(map[string]modelDemand), + } + + now := time.Now() + manager.recentModelDemand["stale-model"] = modelDemand{ + Count: 99, + LastRequested: now.Add(-warmPoolRecentWindow - time.Minute), + } + manager.recentModelDemand["recent-low"] = modelDemand{ + Count: 1, + LastRequested: now.Add(-time.Minute), + } + manager.recentModelDemand["recent-high"] = modelDemand{ + Count: 3, + LastRequested: now.Add(-2 * time.Minute), + } + + hot := manager.hotModels(2) + if len(hot) != 2 { + t.Fatalf("expected 2 hot models, got %d (%#v)", len(hot), hot) + } + if hot[0] != "recent-high" || hot[1] != "recent-low" { + t.Fatalf("unexpected hot model ordering: %#v", hot) + } + if _, ok := manager.recentModelDemand["stale-model"]; ok { + t.Fatalf("expected stale demand to be pruned") + } +} + +type ioDiscard struct{} + +func (ioDiscard) Write(p []byte) (int, error) { + return len(p), nil +}