Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion docs/examples/agents/react/react_using_mellea.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from langchain_community.tools import DuckDuckGoSearchResults

from mellea.backends.tools import MelleaTool
from mellea.stdlib import functional as mfuncs
from mellea.stdlib.context import ChatContext
from mellea.stdlib.frameworks.react import react
from mellea.stdlib.session import start_session
Expand All @@ -28,14 +29,74 @@ class Email(pydantic.BaseModel):
body: str


class TrueOrFalse(pydantic.BaseModel):
    """Response indicating whether the ReACT agent has completed its task."""

    # NOTE: this model is passed as `format=` to the backend, so the Field
    # description below is part of the structured-output schema the LLM sees —
    # do not edit it casually.
    answer: bool = pydantic.Field(
        description="True if you have enough information to answer the user's question, False if you need more tool calls"
    )


async def last_loop_completion_check(
    goal, step, context, backend, model_options, turn_num, loop_budget
):
    """Ask the model, on the final iteration only, whether it can answer.

    Every iteration before the last (and every iteration under an unlimited
    budget of -1) short-circuits to False, so at most one extra LLM call is
    made per react() run.

    Note: step.value is guaranteed to exist when this is called.
    """
    # Anything other than the final budgeted turn: skip the LLM round-trip.
    on_last_turn = loop_budget != -1 and turn_num >= loop_budget
    if not on_last_turn:
        return False

    # One structured-output query; TrueOrFalse constrains the reply shape.
    raw = mfuncs.chat(
        content=f"Do you know the answer to the user's original query ({goal})? If so, respond with True. If you need to take more actions, then respond False.",
        context=context,
        backend=backend,
        format=TrueOrFalse,
    )[0].content
    return TrueOrFalse.model_validate_json(raw).answer


async def custom_completion_check(
    goal, step, context, backend, model_options, turn_num, loop_budget
):
    """Combined completion check: keyword early-exit plus last-loop LLM query.

    Runs every iteration:
    1. Terminates immediately if the step output contains "final answer".
    2. On the final budgeted iteration, defers to last_loop_completion_check,
       which asks the model whether it already has the answer.

    Note: step.value is guaranteed to exist when this is called.
    """
    # Cheap keyword scan first — allows early termination on any turn.
    if "final answer" in step.value.lower():
        return True

    # Before the last turn (or with an unlimited budget) there is no fallback.
    reached_budget = loop_budget != -1 and turn_num >= loop_budget
    if not reached_budget:
        return False

    # Last chance: let the model judge whether the goal is already answered.
    return await last_loop_completion_check(
        goal, step, context, backend, model_options, turn_num, loop_budget
    )


async def main():
"""Example."""
# Simple version that just searches for an answer.
# Version with custom answer check that terminates early
# when the model says "final answer" and queries the LLM
# if it reaches the loop_budget.
out, _ = await react(
goal="What is the Mellea python library?",
context=ChatContext(),
backend=m.backend,
tools=[search_tool],
answer_check=custom_completion_check,
)
print(out)

Expand All @@ -45,6 +106,7 @@ async def main():
# context=ChatContext(),
# backend=m.backend,
# tools=[search_tool],
# answer_check = custom_completion_check,
# format=Email
# )
# print(out)
Expand Down
44 changes: 44 additions & 0 deletions mellea/stdlib/frameworks/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
history tracking. Raises ``RuntimeError`` if the loop ends without a final answer.
"""

from collections.abc import Awaitable, Callable

import pydantic

# from PIL import Image as PILImage
from mellea.backends.model_options import ModelOption
from mellea.core.backend import Backend, BaseModelSubclass
Expand All @@ -24,6 +28,14 @@
from mellea.stdlib.context import ChatContext


class TrueOrFalse(pydantic.BaseModel):
    """Response indicating whether the ReACT agent has completed its task."""

    # NOTE: used as a structured-output `format=` schema; the Field description
    # is prompt-visible text, not just documentation.
    answer: bool = pydantic.Field(
        description="True if you have enough information to answer the user's question, False if you need more tool calls"
    )


async def react(
goal: str,
context: ChatContext,
Expand All @@ -36,6 +48,11 @@ async def react(
model_options: dict | None = None,
tools: list[AbstractMelleaTool] | None,
loop_budget: int = 10,
answer_check: Callable[
[str, ModelOutputThunk[str], ChatContext, Backend, dict | None, int, int],
Awaitable[bool],
]
| None = None,
) -> tuple[ModelOutputThunk[str], ChatContext]:
"""Asynchronous ReACT pattern (Think -> Act -> Observe -> Repeat Until Done); attempts to accomplish the provided goal given the provided tools.

Expand All @@ -47,6 +64,11 @@ async def react(
model_options: additional model options, which will upsert into the model/backend's defaults.
tools: the list of tools to use
loop_budget: the number of steps allowed; use -1 for unlimited
answer_check: optional callable to determine if the agent has completed its task.
Called every iteration when no tool calls are made and step.value exists (if provided).
Receives (goal, step, context, backend, model_options, turn_num, loop_budget).
Returns bool indicating if the task is complete.
If None, no answer check is performed (loop continues until finalizer or budget exhausted).

Returns:
A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`.
Expand Down Expand Up @@ -105,9 +127,31 @@ async def react(
if tool_res.name == MELLEA_FINALIZER_TOOL:
is_final = True

# Check if the agent has completed its task (runs every iteration if answer_check is provided and there's a value)
# The answer_check function can decide when to actually check based on turn_num and loop_budget
elif not is_final and answer_check and step.value:
have_answer = await answer_check(
goal, step, context, backend, model_options, turn_num, loop_budget
)

if have_answer:
# Create a synthetic finalizer tool response to be consistent with normal loop
finalizer_response = ToolMessage(
role="tool",
content=step.value or "",
tool_output=step.value or "",
name=MELLEA_FINALIZER_TOOL,
args={},
tool=None, # type: ignore
)
tool_responses = [finalizer_response]
context = context.add(finalizer_response)
is_final = True

if is_final:
assert len(tool_responses) == 1, "multiple tools were called with 'final'"

# Apply format if requested
if format is not None:
step, next_context = await mfuncs.aact(
action=ReactThought(),
Expand Down
108 changes: 108 additions & 0 deletions test/stdlib/test_react_direct_answer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Test ReACT framework handling of direct answers without tool calls."""

import pydantic
import pytest

from mellea.backends.tools import tool
from mellea.stdlib import functional as mfuncs
from mellea.stdlib.context import ChatContext
from mellea.stdlib.frameworks.react import react
from mellea.stdlib.session import start_session


class TrueOrFalse(pydantic.BaseModel):
    """Response indicating whether the ReACT agent has completed its task."""

    # Structured-output schema for the completion-check query; the Field
    # description is part of what the model is shown.
    answer: bool = pydantic.Field(
        description="True if you have enough information to answer the user's question, False if you need more tool calls"
    )


async def last_loop_completion_check(
    goal, step, context, backend, model_options, turn_num, loop_budget
):
    """Completion check that queries the model only on the last iteration.

    Note: step.value is guaranteed to exist when this is called.
    """
    # Skip entirely unless we are on the final budgeted turn; an unlimited
    # budget (-1) never triggers the check.
    if loop_budget == -1:
        return False
    if turn_num < loop_budget:
        return False

    # Single structured-output round-trip; TrueOrFalse pins the reply shape.
    reply = mfuncs.chat(
        content=f"Do you know the answer to the user's original query ({goal})? If so, respond with True. If you need to take more actions, then respond False.",
        context=context,
        backend=backend,
        format=TrueOrFalse,
    )
    return TrueOrFalse.model_validate_json(reply[0].content).answer


@pytest.mark.ollama
@pytest.mark.e2e
async def test_react_direct_answer_without_tools():
    """Test that ReACT handles direct answers when model doesn't call tools.

    This tests the case where the model provides a direct answer in step.value
    without making any tool calls. The fix ensures the loop terminates properly
    instead of continuing until loop_budget is exhausted.
    """
    m = start_session()

    # Ask a simple question that doesn't require tools.
    # The model should provide a direct answer without calling any tools.
    out, _ = await react(
        goal="What is 2 + 2?",
        context=ChatContext(),
        backend=m.backend,
        tools=[],  # No tools provided
        loop_budget=3,  # Should complete in 1 iteration, not exhaust budget
        answer_check=last_loop_completion_check,
    )

    # Verify we got a non-empty answer.
    assert out.value is not None
    assert len(out.value) > 0

    # The answer should contain "4" or "four".
    answer_lower = out.value.lower()
    assert "4" in answer_lower or "four" in answer_lower


@pytest.mark.ollama
@pytest.mark.e2e
async def test_react_direct_answer_with_unused_tools():
    """Test that ReACT handles direct answers even when tools are available.

    This tests the case where tools are provided but the model chooses to
    answer directly without using them.
    """
    m = start_session()

    # Create a dummy tool that won't be needed.
    @tool
    def search_web(query: str) -> str:
        """Search the web for information."""
        return "Search results"

    # Ask a question that doesn't need the tool.
    out, _ = await react(
        goal="What is the capital of France?",
        context=ChatContext(),
        backend=m.backend,
        tools=[search_web],
        loop_budget=3,
        answer_check=last_loop_completion_check,
    )

    # Verify we got a non-empty answer.
    assert out.value is not None
    assert len(out.value) > 0

    # The answer should mention Paris.
    answer_lower = out.value.lower()
    assert "paris" in answer_lower


# Made with Bob
Loading