forked from langchain-ai/react-agent
-
Notifications
You must be signed in to change notification settings - Fork 63
Expand file tree
/
Copy pathgraph.py
More file actions
139 lines (104 loc) · 4.36 KB
/
graph.py
File metadata and controls
139 lines (104 loc) · 4.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""Define a custom Reasoning and Action agent.
Works with a chat model with tool calling support.
"""
from datetime import UTC, datetime
from typing import Dict, List, Literal, cast
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.runtime import Runtime
from common.context import Context
from common.tools import get_tools
from common.utils import load_chat_model
from react_agent.state import InputState, State
# Define the function that calls the model
async def call_model(
    state: State, runtime: Runtime[Context]
) -> Dict[str, List[AIMessage]]:
    """Query the chat model for the agent's next message.

    Fetches the tool list from ``get_tools()``, loads the chat model via
    ``load_chat_model(runtime.context.model)``, binds the tools to it, and
    sends the conversation to the model with a system message prepended.
    The system message is ``runtime.context.system_prompt`` formatted with
    ``system_time`` set to the current UTC time in ISO 8601 form.

    Args:
        state (State): Current graph state; ``messages`` and
            ``is_last_step`` are read.
        runtime (Runtime[Context]): Runtime whose ``context`` supplies
            ``model`` and ``system_prompt``.

    Returns:
        dict: ``{"messages": [message]}`` where ``message`` is the model's
        reply, or a fixed apology message (carrying the reply's id) when
        ``state.is_last_step`` is truthy and the reply still requests
        tool calls.
    """
    # Resolve the tools first, then build the tool-aware model.
    tools = await get_tools()
    context = runtime.context
    chat_model = load_chat_model(context.model)
    llm = chat_model.bind_tools(tools)

    # Render the system prompt with the current time.
    template = context.system_prompt
    prompt = template.format(system_time=datetime.now(tz=UTC).isoformat())

    # System message goes first, followed by the conversation so far.
    conversation = [{"role": "system", "content": prompt}]
    conversation.extend(state.messages)
    reply = cast(AIMessage, await llm.ainvoke(conversation))

    # Normal path: hand the reply back so it is appended to the messages.
    out_of_steps = state.is_last_step and reply.tool_calls
    if not out_of_steps:
        return {"messages": [reply]}

    # Last step, but the model still wants a tool: answer with an apology
    # instead of returning the reply with its tool calls.
    apology = AIMessage(
        id=reply.id,
        content="Sorry, I could not find an answer to your question in the specified number of steps.",
    )
    return {"messages": [apology]}
async def dynamic_tools_node(
    state: State, runtime: Runtime[Context]
) -> Dict[str, List[ToolMessage]]:
    """Run the tool calls requested in the current state.

    The tool set is fetched from ``get_tools()`` on every invocation and
    wrapped in a fresh ``ToolNode``, which is then invoked with ``state``.

    Args:
        state (State): Current graph state, passed unchanged to the
            ``ToolNode``.
        runtime (Runtime[Context]): Accepted as part of the node signature;
            not read by this function.

    Returns:
        dict: The ``ToolNode`` result, typed as a mapping of state keys to
        lists of ``ToolMessage``.
    """
    # Build the executor from whatever tools are available right now.
    executor = ToolNode(await get_tools())
    outcome = await executor.ainvoke(state)
    # cast() only informs the type checker; the value is returned as-is.
    return cast(Dict[str, List[ToolMessage]], outcome)
# Create the graph builder: `State` is the graph's state schema, `InputState`
# the input schema, and `Context` the schema for the run-scoped context.
builder = StateGraph(State, input_schema=InputState, context_schema=Context)
# Register the two nodes the agent cycles between.
# `call_model` is added without an explicit name; the edge below refers to
# it by the function's name, "call_model".
builder.add_node(call_model)
builder.add_node("tools", dynamic_tools_node)
# Make `call_model` the entrypoint: it is the first node run after "__start__".
builder.add_edge("__start__", "call_model")
def route_model_output(state: State) -> Literal["__end__", "tools"]:
    """Choose the node that runs after ``call_model``.

    Only the most recent entry in ``state.messages`` is inspected.

    Args:
        state (State): The current state of the conversation.

    Returns:
        str: ``"tools"`` if the last message carries tool calls,
        otherwise ``"__end__"``.

    Raises:
        ValueError: If the last message is not an ``AIMessage``.
    """
    last_message = state.messages[-1]
    if isinstance(last_message, AIMessage):
        # Pending tool calls mean another tool round; none means we are done.
        return "tools" if last_message.tool_calls else "__end__"
    # Anything other than a model reply here indicates a mis-wired graph.
    raise ValueError(
        f"Expected AIMessage in output edges, but got {type(last_message).__name__}"
    )
# After `call_model` runs, `route_model_output` picks the next step:
# "tools" when the model requested tool calls, otherwise "__end__".
builder.add_conditional_edges(
    "call_model",
    # Routing function; its return value names the next node to schedule.
    route_model_output,
)
# Unconditional edge from `tools` back to `call_model`.
# This closes the loop: after using tools, we always return to the model.
builder.add_edge("tools", "call_model")
# Compile the builder into an executable graph named "ReAct Agent".
graph = builder.compile(name="ReAct Agent")