Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/llm/example/tutorial.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ const streamText = LLM.stream(request).pipe(
Stream.tap((event) =>
Effect.sync(() => {
if (event.type === "text-delta") process.stdout.write(`\ntext: ${event.text}`)
if (event.type === "request-finish") process.stdout.write(`\nfinish: ${event.reason}\n`)
if (event.type === "finish") process.stdout.write(`\nfinish: ${event.reason}\n`)
}),
),
Stream.runDrain,
Expand Down Expand Up @@ -185,7 +185,7 @@ const FakeProtocol = Protocol.make<FakeBody, string, string, void>({
event: Schema.String,
initial: () => undefined,
step: (_, frame) => Effect.succeed([undefined, [{ type: "text-delta", id: "text-0", text: frame }]] as const),
onHalt: () => [{ type: "request-finish", reason: "stop" }],
onHalt: () => [{ type: "finish", reason: "stop" }],
},
})

Expand Down
1 change: 1 addition & 0 deletions packages/llm/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export type {
ExecutableTools,
Tool as ToolShape,
ToolExecute,
ToolExecuteContext,
Tools,
ToolSchema,
} from "./tool"
Expand Down
2 changes: 1 addition & 1 deletion packages/llm/src/protocols/openai-responses.ts
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ type StepResult = readonly [ParserState, ReadonlyArray<LLMEvent>]
const NO_EVENTS: StepResult["1"] = []

// `response.completed` / `response.incomplete` are clean finishes that emit a
// `request-finish` event; `response.failed` is a hard failure that emits a
// `finish` event; `response.failed` is a hard failure that emits a
// `provider-error`. All three end the stream — kept in one set so `step` and
// the protocol's `terminal` predicate stay in sync.
const TERMINAL_TYPES = new Set(["response.completed", "response.incomplete", "response.failed"])
Expand Down
2 changes: 1 addition & 1 deletion packages/llm/src/protocols/utils/lifecycle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ export const finish = (
usage: input.usage,
providerMetadata: input.providerMetadata,
}),
LLMEvent.requestFinish(input),
LLMEvent.finish(input),
)
return { ...stepped, stepStarted: false }
}
Expand Down
44 changes: 25 additions & 19 deletions packages/llm/src/schema/events.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Schema } from "effect"
import { ContentBlockID, FinishReason, ProtocolID, ProviderMetadata, ResponseID, RouteID, ToolCallID } from "./ids"
import { ContentBlockID, FinishReason, ProtocolID, ProviderMetadata, RouteID, ToolCallID } from "./ids"
import { ModelRef } from "./options"
import { ToolResultValue } from "./messages"

Expand Down Expand Up @@ -66,14 +66,13 @@ export class Usage extends Schema.Class<Usage>("LLM.Usage")({
/**
 * Output tokens left after subtracting reasoning tokens.
 * Missing counters are treated as zero and the result never goes below zero.
 */
get visibleOutputTokens() {
  const produced = this.outputTokens ?? 0
  const reasoning = this.reasoningTokens ?? 0
  return Math.max(0, produced - reasoning)
}

/**
 * Normalizes a usage input: an existing `Usage` instance is returned as-is,
 * anything else is passed to the `Usage` constructor.
 */
static from(input: UsageInput) {
  if (input instanceof Usage) return input
  return new Usage(input)
}
}

export const RequestStart = Schema.Struct({
type: Schema.tag("request-start"),
id: ResponseID,
model: ModelRef,
}).annotate({ identifier: "LLM.Event.RequestStart" })
export type RequestStart = Schema.Schema.Type<typeof RequestStart>
export type UsageInput = Usage | ConstructorParameters<typeof Usage>[0]

export const StepStart = Schema.Struct({
type: Schema.tag("step-start"),
Expand Down Expand Up @@ -185,13 +184,13 @@ export const StepFinish = Schema.Struct({
}).annotate({ identifier: "LLM.Event.StepFinish" })
export type StepFinish = Schema.Schema.Type<typeof StepFinish>

export const RequestFinish = Schema.Struct({
type: Schema.tag("request-finish"),
export const Finish = Schema.Struct({
type: Schema.tag("finish"),
reason: FinishReason,
usage: Schema.optional(Usage),
providerMetadata: Schema.optional(ProviderMetadata),
}).annotate({ identifier: "LLM.Event.RequestFinish" })
export type RequestFinish = Schema.Schema.Type<typeof RequestFinish>
}).annotate({ identifier: "LLM.Event.Finish" })
export type Finish = Schema.Schema.Type<typeof Finish>

export const ProviderErrorEvent = Schema.Struct({
type: Schema.tag("provider-error"),
Expand All @@ -202,7 +201,6 @@ export const ProviderErrorEvent = Schema.Struct({
export type ProviderErrorEvent = Schema.Schema.Type<typeof ProviderErrorEvent>

const llmEventTagged = Schema.Union([
RequestStart,
StepStart,
TextStart,
TextDelta,
Expand All @@ -217,13 +215,15 @@ const llmEventTagged = Schema.Union([
ToolResult,
ToolError,
StepFinish,
RequestFinish,
Finish,
ProviderErrorEvent,
]).pipe(Schema.toTaggedUnion("type"))

// Constructor-input shape for events that carry an `id`: drops the `type` tag
// and widens `id` so callers may pass either the branded ID or a plain string.
type WithID<Event extends { readonly id: unknown }, ID> = Omit<Event, "type" | "id"> & { readonly id: ID | string }
// Constructor-input shape for events that carry optional `usage`: drops the
// `type` tag and widens `usage` to `UsageInput` (a `Usage` instance or its
// constructor argument).
type WithUsage<Event extends { readonly usage?: Usage }> = Omit<Event, "type" | "usage"> & {
readonly usage?: UsageInput
}

const responseID = (value: ResponseID | string) => ResponseID.make(value)
const contentBlockID = (value: ContentBlockID | string) => ContentBlockID.make(value)
const toolCallID = (value: ToolCallID | string) => ToolCallID.make(value)

Expand All @@ -233,7 +233,6 @@ const toolCallID = (value: ToolCallID | string) => ToolCallID.make(value)
* `events.filter(LLMEvent.guards["tool-call"])`.
*/
export const LLMEvent = Object.assign(llmEventTagged, {
requestStart: (input: WithID<RequestStart, ResponseID>) => RequestStart.make({ ...input, id: responseID(input.id) }),
stepStart: StepStart.make,
textStart: (input: WithID<TextStart, ContentBlockID>) => TextStart.make({ ...input, id: contentBlockID(input.id) }),
textDelta: (input: WithID<TextDelta, ContentBlockID>) => TextDelta.make({ ...input, id: contentBlockID(input.id) }),
Expand All @@ -252,11 +251,18 @@ export const LLMEvent = Object.assign(llmEventTagged, {
toolCall: (input: WithID<ToolCall, ToolCallID>) => ToolCall.make({ ...input, id: toolCallID(input.id) }),
toolResult: (input: WithID<ToolResult, ToolCallID>) => ToolResult.make({ ...input, id: toolCallID(input.id) }),
toolError: (input: WithID<ToolError, ToolCallID>) => ToolError.make({ ...input, id: toolCallID(input.id) }),
stepFinish: StepFinish.make,
requestFinish: RequestFinish.make,
stepFinish: (input: WithUsage<StepFinish>) =>
StepFinish.make({
...input,
usage: input.usage === undefined ? undefined : Usage.from(input.usage),
}),
finish: (input: WithUsage<Finish>) =>
Finish.make({
...input,
usage: input.usage === undefined ? undefined : Usage.from(input.usage),
}),
providerError: ProviderErrorEvent.make,
is: {
requestStart: llmEventTagged.guards["request-start"],
stepStart: llmEventTagged.guards["step-start"],
textStart: llmEventTagged.guards["text-start"],
textDelta: llmEventTagged.guards["text-delta"],
Expand All @@ -271,7 +277,7 @@ export const LLMEvent = Object.assign(llmEventTagged, {
toolResult: llmEventTagged.guards["tool-result"],
toolError: llmEventTagged.guards["tool-error"],
stepFinish: llmEventTagged.guards["step-finish"],
requestFinish: llmEventTagged.guards["request-finish"],
finish: llmEventTagged.guards.finish,
providerError: llmEventTagged.guards["provider-error"],
},
})
Expand Down
96 changes: 83 additions & 13 deletions packages/llm/src/tool-runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
ToolFailure,
ToolResultPart,
type ToolResultValue,
Usage,
} from "./schema"
import { type AnyTool, type ExecutableTools, type Tools, toDefinitions } from "./tool"

Expand Down Expand Up @@ -72,19 +73,42 @@ export const stream = <T extends Tools>(options: StreamOptions<T>): Stream.Strea
tools: [...options.request.tools.filter((tool) => !runtimeToolNames.has(tool.name)), ...runtimeTools],
})

const loop = (request: LLMRequest, step: number): Stream.Stream<LLMEvent, LLMError> =>
const loop = (
request: LLMRequest,
step: number,
usage: Usage | undefined,
providerMetadata: ProviderMetadata | undefined,
): Stream.Stream<LLMEvent, LLMError> =>
Stream.unwrap(
Effect.gen(function* () {
const state: StepState = { assistantContent: [], toolCalls: [], finishReason: undefined }
const state: StepState = {
assistantContent: [],
toolCalls: [],
finishReason: undefined,
usage: undefined,
providerMetadata: undefined,
}

const modelStream = options
.stream(request)
.pipe(Stream.map((event) => indexStep(event, step)))
.pipe(Stream.tap((event) => Effect.sync(() => accumulate(state, event))))
.pipe(Stream.filter((event) => event.type !== "finish"))

const continuation = Stream.unwrap(
Effect.gen(function* () {
if (state.finishReason !== "tool-calls" || state.toolCalls.length === 0) return Stream.empty
if (options.toolExecution === "none") return Stream.empty
const totalUsage = addUsage(usage, state.usage)
const totalProviderMetadata = mergeProviderMetadata(providerMetadata, state.providerMetadata)
const finishStream = Stream.fromIterable([
LLMEvent.finish({
reason: state.finishReason ?? "unknown",
usage: totalUsage,
providerMetadata: totalProviderMetadata,
}),
])

if (state.finishReason !== "tool-calls" || state.toolCalls.length === 0) return finishStream
if (options.toolExecution === "none") return finishStream

const dispatched = yield* Effect.forEach(
state.toolCalls,
Expand All @@ -93,24 +117,36 @@ export const stream = <T extends Tools>(options: StreamOptions<T>): Stream.Strea
)
const resultStream = Stream.fromIterable(dispatched.flatMap(([call, result]) => emitEvents(call, result)))

if (!options.stopWhen) return resultStream
if (options.stopWhen({ step, request })) return resultStream
if (!options.stopWhen) return resultStream.pipe(Stream.concat(finishStream))
if (options.stopWhen({ step, request })) return resultStream.pipe(Stream.concat(finishStream))

return resultStream.pipe(Stream.concat(loop(followUpRequest(request, state, dispatched), step + 1)))
return resultStream.pipe(
Stream.concat(
loop(followUpRequest(request, state, dispatched), step + 1, totalUsage, totalProviderMetadata),
),
)
}),
)

return modelStream.pipe(Stream.concat(continuation))
}),
)

return loop(initialRequest, 0)
return loop(initialRequest, 0, undefined, undefined)
}

/**
 * Re-stamps step boundary events with the given step index.
 *
 * A `step-start` event is rebuilt carrying only `index`; a `step-finish` event
 * keeps its own fields with `index` overwritten. Every other event is returned
 * unchanged.
 */
const indexStep = (event: LLMEvent, index: number): LLMEvent => {
  switch (event.type) {
    case "step-start":
      return LLMEvent.stepStart({ index })
    case "step-finish":
      return LLMEvent.stepFinish({ ...event, index })
    default:
      return event
  }
}

/**
 * Mutable per-step accumulator that `accumulate` fills in while one model
 * stream is consumed; a fresh instance is created for every loop iteration.
 */
interface StepState {
// Assistant content parts gathered for this step (presumably replayed in the
// follow-up request — confirm against `followUpRequest`).
assistantContent: ContentPart[]
// Tool calls seen in this step; dispatched once the step ends with "tool-calls".
toolCalls: ToolCallPart[]
// Reason taken from `step-finish`, or from `finish` when no step reason was recorded.
finishReason: FinishReason | undefined
// Usage summed over `step-finish` events, falling back to the `finish` event's usage.
usage: Usage | undefined
// Provider metadata merged from `step-finish` and `finish` events.
providerMetadata: ProviderMetadata | undefined
}

const accumulate = (state: StepState, event: LLMEvent) => {
Expand Down Expand Up @@ -154,11 +190,45 @@ const accumulate = (state: StepState, event: LLMEvent) => {
)
return
}
if (event.type === "step-finish" || event.type === "request-finish") {
if (event.type === "step-finish") {
state.finishReason = event.reason === "stop" && state.toolCalls.length > 0 ? "tool-calls" : event.reason
state.usage = addUsage(state.usage, event.usage)
state.providerMetadata = mergeProviderMetadata(state.providerMetadata, event.providerMetadata)
return
}
if (event.type === "finish") {
state.finishReason ??= event.reason
state.usage ??= event.usage
state.providerMetadata = mergeProviderMetadata(state.providerMetadata, event.providerMetadata)
}
}

/**
 * Combines two optional usage records into one.
 *
 * When either side is missing, the other side is returned untouched. Otherwise
 * each token counter is added field by field: a counter absent on both sides
 * stays `undefined`, while a counter absent on one side counts as zero.
 * Provider metadata from both records is combined via `mergeProviderMetadata`.
 */
const addUsage = (left: Usage | undefined, right: Usage | undefined) => {
  if (!left || !right) return left || right

  type Counter =
    | "inputTokens"
    | "outputTokens"
    | "nonCachedInputTokens"
    | "cacheReadInputTokens"
    | "cacheWriteInputTokens"
    | "reasoningTokens"
    | "totalTokens"

  const total = (counter: Counter) => {
    const fromLeft = left[counter]
    const fromRight = right[counter]
    if (fromLeft === undefined && fromRight === undefined) return undefined
    return Number(fromLeft ?? 0) + Number(fromRight ?? 0)
  }

  const combined = {
    inputTokens: total("inputTokens"),
    outputTokens: total("outputTokens"),
    nonCachedInputTokens: total("nonCachedInputTokens"),
    cacheReadInputTokens: total("cacheReadInputTokens"),
    cacheWriteInputTokens: total("cacheWriteInputTokens"),
    reasoningTokens: total("reasoningTokens"),
    totalTokens: total("totalTokens"),
    providerMetadata: mergeProviderMetadata(left.providerMetadata, right.providerMetadata),
  }
  return new Usage(combined)
}

/**
 * `JSON.stringify` replacer that rebuilds every plain object with its keys in
 * sorted order, so the serialized form does not depend on insertion order.
 * Arrays and primitives pass through unchanged (array order stays significant).
 */
const sortObjectKeys = (_key: string, nested: unknown): unknown => {
  if (nested === null || typeof nested !== "object" || Array.isArray(nested)) return nested
  return Object.fromEntries(Object.entries(nested).sort(([a], [b]) => (a < b ? -1 : a > b ? 1 : 0)))
}

/**
 * Structural equality for optional provider metadata.
 *
 * Returns `true` when both sides are the same reference (including both
 * `undefined`) or serialize to the same JSON. Keys are sorted during
 * serialization, so two metadata objects with identical content but different
 * key insertion order compare equal; a plain `JSON.stringify` comparison would
 * report them as different.
 */
const sameProviderMetadata = (left: ProviderMetadata | undefined, right: ProviderMetadata | undefined) =>
  left === right || JSON.stringify(left, sortObjectKeys) === JSON.stringify(right, sortObjectKeys)

Expand Down Expand Up @@ -200,17 +270,17 @@ const dispatch = (tools: Tools, call: ToolCallPart): Effect.Effect<ToolResultVal
if (!tool.execute)
return Effect.succeed({ type: "error" as const, value: `Tool has no execute handler: ${call.name}` })

return decodeAndExecute(tool, call.input).pipe(
return decodeAndExecute(tool, call).pipe(
Effect.catchTag("LLM.ToolFailure", (failure) =>
Effect.succeed({ type: "error" as const, value: failure.message } satisfies ToolResultValue),
),
)
}

const decodeAndExecute = (tool: AnyTool, input: unknown): Effect.Effect<ToolResultValue, ToolFailure> =>
tool._decode(input).pipe(
const decodeAndExecute = (tool: AnyTool, call: ToolCallPart): Effect.Effect<ToolResultValue, ToolFailure> =>
tool._decode(call.input).pipe(
Effect.mapError((error) => new ToolFailure({ message: `Invalid tool input: ${error.message}` })),
Effect.flatMap((decoded) => tool.execute!(decoded)),
Effect.flatMap((decoded) => tool.execute!(decoded, { id: call.id, name: call.name })),
Effect.flatMap((value) =>
tool._encode(value).pipe(
Effect.mapError(
Expand Down
11 changes: 8 additions & 3 deletions packages/llm/src/tool.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Effect, JsonSchema, Schema } from "effect"
import type { ToolDefinition as ToolDefinitionClass } from "./schema"
import type { ToolCallPart, ToolDefinition as ToolDefinitionClass } from "./schema"
import { ToolDefinition, ToolFailure } from "./schema"

/**
Expand All @@ -8,9 +8,14 @@ import { ToolDefinition, ToolFailure } from "./schema"
* beyond pure data conversion belongs in the handler closure.
*/
export type ToolSchema<T> = Schema.Codec<T, any, never, never>
/**
 * Information about the tool call being executed, handed to a tool's
 * `execute` handler alongside the decoded parameters.
 */
export interface ToolExecuteContext {
// ID of the tool call being executed.
readonly id: ToolCallPart["id"]
// Name of the tool as it appears on the tool call.
readonly name: ToolCallPart["name"]
}

/**
 * Handler signature for an executable tool: receives the decoded parameters
 * and, optionally, the context of the originating tool call, and yields the
 * success value or fails with `ToolFailure`.
 */
export type ToolExecute<Parameters extends ToolSchema<any>, Success extends ToolSchema<any>> = (
params: Schema.Schema.Type<Parameters>,
// Optional so handlers that only take `params` remain assignable.
context?: ToolExecuteContext,
) => Effect.Effect<Schema.Schema.Type<Success>, ToolFailure>

/**
Expand Down Expand Up @@ -61,7 +66,7 @@ type TypedToolConfig = {
type DynamicToolConfig = {
readonly description: string
readonly jsonSchema: JsonSchema.JsonSchema
readonly execute?: (params: unknown) => Effect.Effect<unknown, ToolFailure>
readonly execute?: (params: unknown, context?: ToolExecuteContext) => Effect.Effect<unknown, ToolFailure>
}

/**
Expand Down Expand Up @@ -110,7 +115,7 @@ export function make<Parameters extends ToolSchema<any>, Success extends ToolSch
export function make(config: {
readonly description: string
readonly jsonSchema: JsonSchema.JsonSchema
readonly execute: (params: unknown) => Effect.Effect<unknown, ToolFailure>
readonly execute: (params: unknown, context?: ToolExecuteContext) => Effect.Effect<unknown, ToolFailure>
}): AnyExecutableTool
export function make(config: {
readonly description: string
Expand Down
6 changes: 3 additions & 3 deletions packages/llm/test/adapter.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ const request = LLM.request({

const raiseEvent = (event: FakeEvent): import("../src/schema").LLMEvent =>
event.type === "finish"
? { type: "request-finish", reason: event.reason }
? { type: "finish", reason: event.reason }
: { type: "text-delta", id: "text-0", text: event.text }

const fakeProtocol = Protocol.make<FakeBody, FakeEvent, FakeEvent, void>({
Expand Down Expand Up @@ -112,8 +112,8 @@ describe("llm route", () => {
const events = Array.from(yield* llm.stream(request).pipe(Stream.runCollect))
const response = yield* llm.generate(request)

expect(events.map((event) => event.type)).toEqual(["text-delta", "request-finish"])
expect(response.events.map((event) => event.type)).toEqual(["text-delta", "request-finish"])
expect(events.map((event) => event.type)).toEqual(["text-delta", "finish"])
expect(response.events.map((event) => event.type)).toEqual(["text-delta", "finish"])
}),
)

Expand Down
2 changes: 1 addition & 1 deletion packages/llm/test/llm.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ describe("llm constructors", () => {
LLMResponse.text({
events: [
{ type: "text-delta", id: "text-0", text: "hi" },
{ type: "request-finish", reason: "stop" },
{ type: "finish", reason: "stop" },
],
}),
).toBe("hi")
Expand Down
Loading
Loading