Skip to content
65 changes: 46 additions & 19 deletions src/__tests__/drift/providers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ async function fetchWithRetry(url: string, init: RequestInit, maxRetries = 3): P
try {
const res = await fetch(url, init);
if (RETRYABLE_STATUSES.has(res.status) && attempt < maxRetries - 1) {
console.warn(
`Retry ${attempt + 1}/${maxRetries} after ${res.status} for ${url.slice(0, 80)}`,
);
await res.text(); // consume body to free socket
const backoff = Math.pow(2, attempt) * 1000;
await new Promise((r) => setTimeout(r, backoff));
continue;
Expand All @@ -59,15 +63,21 @@ async function fetchWithRetry(url: string, init: RequestInit, maxRetries = 3): P
// Response parsing
// ---------------------------------------------------------------------------

/** Redact API keys from query parameters in URLs for safe error messages */
function redactUrl(url: string): string {
  return url.replace(/([?&])key=[^&]+/g, "$1key=REDACTED");
}

/**
 * Throw when a response carries an HTTP error status (>= 400).
 *
 * @param raw - Response body text; only the first 300 characters are included in the error.
 * @param status - HTTP status code of the response.
 * @param context - Short label prefixed to the error message (e.g. the provider name).
 * @param url - Optional request URL. When given it is appended to the message
 *   after its `key` query parameter has been redacted, so secrets never end up
 *   in error output.
 * @throws Error when `status` is 400 or greater.
 */
function assertOk(raw: string, status: number, context: string, url?: string): void {
  if (status >= 400) {
    // Only mention the URL when the caller supplied one; always pass it
    // through redactUrl first because it may embed an API key.
    const urlSuffix = url ? ` (${redactUrl(url)})` : "";
    throw new Error(`${context}: API returned ${status}${urlSuffix}: ${raw.slice(0, 300)}`);
  }
}

function parseJsonResponse(raw: string, status: number, context: string): unknown {
function parseJsonResponse(raw: string, status: number, context: string, url?: string): unknown {
if (!raw) throw new Error(`${context}: empty response (status ${status})`);
assertOk(raw, status, context);
assertOk(raw, status, context, url);
try {
return JSON.parse(raw);
} catch {
Expand All @@ -88,14 +98,18 @@ function normalizeLineEndings(text: string): string {
/**
 * Parse a data-only SSE body (frames with `data:` lines and no `event:` line)
 * into a list of decoded JSON payloads.
 *
 * Frames are separated by a blank line. A frame is kept when it starts with
 * `data: ` and is not exactly the `data: [DONE]` terminator; a payload that
 * merely contains the text `[DONE]` is still parsed.
 *
 * @param text - Raw response body. It is passed through `normalizeLineEndings`
 *   first (presumably CRLF -> LF; confirm against that helper).
 * @returns One `{ data }` entry per kept frame, in stream order.
 * @throws Error (with the original parse error as `cause`) when a frame's
 *   payload is not valid JSON.
 */
function parseDataOnlySSE(text: string): { data: unknown }[] {
  return normalizeLineEndings(text)
    .split("\n\n")
    .filter((block) => block.startsWith("data: ") && block.trim() !== "data: [DONE]")
    .map((block) => {
      // Rejoin continuation lines (data split across lines): strip the
      // "data: " prefix where present and concatenate the pieces.
      const json = block
        .split("\n")
        .map((line) => (line.startsWith("data: ") ? line.slice(6) : line))
        .join("");
      try {
        return { data: JSON.parse(json) };
      } catch (err) {
        // Surface which frame was malformed instead of a bare SyntaxError.
        throw new Error(`Malformed SSE JSON in frame: ${json.slice(0, 100)}`, { cause: err });
      }
    });
}

Expand All @@ -105,12 +119,27 @@ function parseTypedSSE(text: string): { type: string; data: unknown }[] {
.split("\n\n")
.filter((block) => block.includes("event: ") && block.includes("data: "))
.map((block) => {
const eventMatch = block.match(/^event: (.+)$/m);
const dataMatch = block.match(/^data: (.+)$/m);
return {
type: eventMatch![1],
data: JSON.parse(dataMatch![1]),
};
const eventMatch = block.match(/^event: (.*)$/m);
if (!eventMatch) {
throw new Error("Malformed SSE block: " + block.slice(0, 100));
}
// Handle multi-line data: collect all data lines and join them
const json = block
.split("\n")
.filter((line) => line.startsWith("data: "))
.map((line) => line.slice(6))
.join("");
if (!json) {
throw new Error("Malformed SSE block (no data): " + block.slice(0, 100));
}
try {
return {
type: eventMatch[1],
data: JSON.parse(json),
};
} catch (err) {
throw new Error(`Malformed SSE JSON in frame: ${json.slice(0, 100)}`, { cause: err });
}
});
}

Expand Down Expand Up @@ -339,7 +368,7 @@ export async function geminiNonStreaming(
});

const raw = await res.text();
return { status: res.status, body: parseJsonResponse(raw, res.status, "Gemini"), raw };
return { status: res.status, body: parseJsonResponse(raw, res.status, "Gemini", url), raw };
}

export async function geminiStreaming(
Expand All @@ -361,7 +390,7 @@ export async function geminiStreaming(
});

const raw = await res.text();
assertOk(raw, res.status, "Gemini streaming");
assertOk(raw, res.status, "Gemini streaming", url);
const parsed = parseDataOnlySSE(raw);
const rawEvents = parsed.map((p) => ({
type: "gemini.chunk",
Expand Down Expand Up @@ -590,13 +619,11 @@ export async function listAnthropicModels(apiKey: string): Promise<string[]> {
}

export async function listGeminiModels(apiKey: string): Promise<string[]> {
const res = await fetchWithRetry(
`https://generativelanguage.googleapis.com/v1beta/models?key=${apiKey}`,
{ method: "GET" },
);
const url = `https://generativelanguage.googleapis.com/v1beta/models?key=${apiKey}`;
const res = await fetchWithRetry(url, { method: "GET" });

const raw = await res.text();
const json = parseJsonResponse(raw, res.status, "Gemini model list") as {
const json = parseJsonResponse(raw, res.status, "Gemini model list", url) as {
models: { name: string }[];
};
// Gemini returns "models/gemini-2.5-flash" — strip prefix
Expand Down
23 changes: 17 additions & 6 deletions src/__tests__/ws-gemini-live.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ describe("WebSocket Gemini Live BidiGenerateContent", () => {
expect(msg.toolCall.functionCalls).toHaveLength(1);
expect(msg.toolCall.functionCalls[0].name).toBe("get_weather");
expect(msg.toolCall.functionCalls[0].args).toEqual({ city: "NYC" });
expect(msg.toolCall.functionCalls[0].id).toBe("call_gemini_get_weather_0");
expect(msg.toolCall.functionCalls[0].id).toMatch(/^call_/);

// Separate turnComplete message follows the toolCall
const turnCompleteMsg = JSON.parse(raw[2]);
Expand Down Expand Up @@ -907,7 +907,7 @@ describe("WebSocket Gemini Live BidiGenerateContent", () => {
it("handles user turn with functionResponse that has string response", async () => {
// Fixture that matches a tool call id
const toolResultFixtureStr: Fixture = {
match: { toolCallId: "call_gemini_search_0" },
match: { toolCallId: "call_search_1" },
response: { content: "Result processed" },
};
instance = await createServer([toolResultFixtureStr]);
Expand All @@ -917,13 +917,22 @@ describe("WebSocket Gemini Live BidiGenerateContent", () => {
await ws.waitForMessages(1); // setupComplete

// Send clientContent with functionResponse where response is a string
// Provide explicit id so it matches the fixture's toolCallId
ws.send(
JSON.stringify({
clientContent: {
turns: [
{
role: "user",
parts: [{ functionResponse: { name: "search", response: "string-result" } }],
parts: [
{
functionResponse: {
name: "search",
response: "string-result",
id: "call_search_1",
},
},
],
},
],
turnComplete: true,
Expand All @@ -941,7 +950,7 @@ describe("WebSocket Gemini Live BidiGenerateContent", () => {
it("handles toolResponse with fallback id and string response", async () => {
// Fixture matching on tool call id
const toolResultFixture3: Fixture = {
match: { toolCallId: "call_gemini_lookup_0" },
match: { toolCallId: "call_lookup_1" },
response: { content: "Lookup done" },
};
instance = await createServer([toolResultFixture3]);
Expand All @@ -950,11 +959,13 @@ describe("WebSocket Gemini Live BidiGenerateContent", () => {
ws.send(setupMsg());
await ws.waitForMessages(1); // setupComplete

// Send toolResponse without id (relies on fallback) and with string response
// Send toolResponse with explicit id and string response
ws.send(
JSON.stringify({
toolResponse: {
functionResponses: [{ name: "lookup", response: "string-response-value" }],
functionResponses: [
{ name: "lookup", response: "string-response-value", id: "call_lookup_1" },
],
},
}),
);
Expand Down
4 changes: 2 additions & 2 deletions src/elevenlabs-audio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ export async function handleElevenLabsTTS(
journal.add({
method,
path,
headers: {},
headers: flattenHeaders(req.headers),
body: syntheticReq,
response: {
status: 503,
Expand Down Expand Up @@ -375,7 +375,7 @@ export async function handleElevenLabsAudio(
journal.add({
method,
path,
headers: {},
headers: flattenHeaders(req.headers),
body: syntheticReq,
response: {
status: 503,
Expand Down
52 changes: 43 additions & 9 deletions src/fal-audio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ async function handleQueueSubmit(
journal.add({
method: req.method ?? "POST",
path: pathname,
headers: {},
headers: flattenHeaders(req.headers),
body: syntheticReq,
response: {
status: 503,
Expand Down Expand Up @@ -424,7 +424,7 @@ async function handleQueueSubmit(
journal.add({
method: req.method ?? "POST",
path: pathname,
headers: {},
headers: flattenHeaders(req.headers),
body: syntheticReq,
response: { status: 500, fixture },
});
Expand Down Expand Up @@ -537,7 +537,7 @@ async function tryRecordAudioQueueWalk(args: {
journal.add({
method: req.method ?? "POST",
path: pathname,
headers: {},
headers: flattenHeaders(req.headers),
body: syntheticReq,
response: { status: res.statusCode ?? 200, fixture: null, source: "proxy" },
});
Expand Down Expand Up @@ -822,6 +822,10 @@ async function handleSyncRun(
(typeof parsed.text === "string" ? parsed.text : null) ??
"";

// _endpointType is intentionally "fal-audio" — the same value used by
// handleQueueSubmit. Both the synchronous /fal/run/ and asynchronous
// /fal/queue/submit/ paths serve the same fal audio fixtures, so they
// share a single endpoint type for fixture matching purposes.
const syntheticReq: ChatCompletionRequest = {
model: modelId,
messages: [{ role: "user", content: prompt }],
Expand Down Expand Up @@ -861,7 +865,7 @@ async function handleSyncRun(
journal.add({
method: req.method ?? "POST",
path: pathname,
headers: {},
headers: flattenHeaders(req.headers),
body: syntheticReq,
response: {
status: 503,
Expand Down Expand Up @@ -946,7 +950,39 @@ async function handleSyncRun(
return;
}

if (!isAudioResponse(response)) {
// Two valid recorded shapes for fal audio sync runs:
// - AudioResponse: authored fixtures with raw base64 audio that we wrap into
// the fal `{ audio: { url, ... } }` envelope on demand.
// - RawJSONResponse: queue-walk recordings that stored the final fal envelope
// upstream returned (already in fal's `{ audio: { url, ... } }` shape).
let result: Record<string, unknown>;
let resultStatus = 200;
if (isAudioResponse(response)) {
result = audioToFalFile(response);
} else if (isJSONResponse(response)) {
resultStatus = (response as RawJSONResponse).status ?? 200;
const json = (response as RawJSONResponse).json;
if (!json || typeof json !== "object") {
journal.add({
method: req.method ?? "POST",
path: pathname,
headers: flattenHeaders(req.headers),
body: syntheticReq,
response: { status: 500, fixture },
});
res.writeHead(500, { "Content-Type": "application/json" });
res.end(
JSON.stringify({
error: {
message: "Recorded fal audio fixture has non-object json",
type: "server_error",
},
}),
);
return;
}
result = json as Record<string, unknown>;
} else {
journal.add({
method: req.method ?? "POST",
path: pathname,
Expand All @@ -963,15 +999,13 @@ async function handleSyncRun(
return;
}

const result = audioToFalFile(response);

journal.add({
method: req.method ?? "POST",
path: pathname,
headers: flattenHeaders(req.headers),
body: syntheticReq,
response: { status: 200, fixture },
response: { status: resultStatus, fixture },
});
res.writeHead(200, { "Content-Type": "application/json" });
res.writeHead(resultStatus, { "Content-Type": "application/json" });
res.end(JSON.stringify(result));
}
36 changes: 31 additions & 5 deletions src/gemini-interactions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,21 @@ function interactionsUsage(overrides?: ResponseOverrides): {
total_tokens: number;
} {
if (!overrides?.usage) return { total_input_tokens: 0, total_output_tokens: 0, total_tokens: 0 };
const input = overrides.usage.input_tokens ?? overrides.usage.prompt_tokens ?? 0;
const output = overrides.usage.output_tokens ?? overrides.usage.completion_tokens ?? 0;
const input =
overrides.usage.input_tokens ??
overrides.usage.prompt_tokens ??
overrides.usage.promptTokenCount ??
0;
const output =
overrides.usage.output_tokens ??
overrides.usage.completion_tokens ??
overrides.usage.candidatesTokenCount ??
0;
const total = overrides.usage.total_tokens ?? overrides.usage.totalTokenCount ?? input + output;
return {
total_input_tokens: input,
total_output_tokens: output,
total_tokens: input + output,
total_tokens: total,
};
}

Expand Down Expand Up @@ -700,9 +709,14 @@ export async function writeGeminiInteractionsSSEStream(
if (res.writableEnded) return true;
// Data-only SSE (no event: prefix, no [DONE])
res.write(`data: ${JSON.stringify(event)}\n\n`);
onChunkSent?.();
// Only count content deltas for truncateAfterChunks — framing events
// (interaction.start, content.start, content.stop, interaction.complete)
// should not consume chunk budget or trigger the chunk-sent callback.
if (event.event_type === "content.delta") {
onChunkSent?.();
chunkIndex++;
}
if (signal?.aborted) return false;
chunkIndex++;
}

if (!res.writableEnded) {
Expand Down Expand Up @@ -751,6 +765,13 @@ export async function handleGeminiInteractions(

// Convert to ChatCompletionRequest for fixture matching
const completionReq = geminiInteractionsToCompletionRequest(interactionsReq);
// Keep "chat" rather than "gemini-interactions" — the router's endpoint
// compatibility filter (router.ts) treats "chat" as a pass-through that
// matches any unendpointed fixture. Switching to "gemini-interactions"
// would make the request fall into the multimedia guard branch, preventing
// generic chat fixtures from matching and breaking existing users. The
// recorder would also start emitting `endpoint: "gemini-interactions"` in
// recorded fixtures, creating a one-way compatibility break.
completionReq._endpointType = "chat";
completionReq._context = getContext(req);

Expand Down Expand Up @@ -988,6 +1009,11 @@ export async function handleGeminiInteractions(

// Tool call response
if (isToolCallResponse(response)) {
if (response.webSearches?.length) {
logger.warn(
"webSearches in fixture response are not supported for Gemini Interactions API — ignoring",
);
}
const overrides = extractOverrides(response);
const journalEntry = journal.add({
method: req.method ?? "POST",
Expand Down
Loading
Loading