diff --git a/.serena/.gitignore b/.serena/.gitignore new file mode 100644 index 00000000..2e510aff --- /dev/null +++ b/.serena/.gitignore @@ -0,0 +1,2 @@ +/cache +/project.local.yml diff --git a/.serena/project.yml b/.serena/project.yml new file mode 100644 index 00000000..9023f351 --- /dev/null +++ b/.serena/project.yml @@ -0,0 +1,152 @@ +# the name by which the project can be referenced within Serena +project_name: "Mini-Agent" + + +# list of languages for which language servers are started; choose from: +# al bash clojure cpp csharp +# csharp_omnisharp dart elixir elm erlang +# fortran fsharp go groovy haskell +# java julia kotlin lua markdown +# matlab nix pascal perl php +# php_phpactor powershell python python_jedi r +# rego ruby ruby_solargraph rust scala +# swift terraform toml typescript typescript_vts +# vue yaml zig +# (This list may be outdated. For the current list, see values of Language enum here: +# https://github.com/oraios/serena/blob/main/src/solidlsp/ls_config.py +# For some languages, there are alternative language servers, e.g. csharp_omnisharp, ruby_solargraph.) +# Note: +# - For C, use cpp +# - For JavaScript, use typescript +# - For Free Pascal/Lazarus, use pascal +# Special requirements: +# Some languages require additional setup/installations. +# See here for details: https://oraios.github.io/serena/01-about/020_programming-languages.html#language-servers +# When using multiple languages, the first language server that supports a given file will be used for that file. +# The first language is the default language and the respective language server will be used as a fallback. +# Note that when using the JetBrains backend, language servers are not used and this list is correspondingly ignored. 
+languages: +- python + +# the encoding used by text files in the project +# For a list of possible encodings, see https://docs.python.org/3.11/library/codecs.html#standard-encodings +encoding: "utf-8" + +# line ending convention to use when writing source files. +# Possible values: unset (use global setting), "lf", "crlf", or "native" (platform default) +# This does not affect Serena's own files (e.g. memories and configuration files), which always use native line endings. +line_ending: + +# The language backend to use for this project. +# If not set, the global setting from serena_config.yml is used. +# Valid values: LSP, JetBrains +# Note: the backend is fixed at startup. If a project with a different backend +# is activated post-init, an error will be returned. +language_backend: + +# whether to use project's .gitignore files to ignore files +ignore_all_files_in_gitignore: true + +# advanced configuration option allowing to configure language server-specific options. +# Maps the language key to the options. +# Have a look at the docstring of the constructors of the LS implementations within solidlsp (e.g., for C# or PHP) to see which options are available. +# No documentation on options means no options are available. +ls_specific_settings: {} + +# list of additional paths to ignore in this project. +# Same syntax as gitignore, so you can use * and **. +# Note: global ignored_paths from serena_config.yml are also applied additively. +ignored_paths: [] + +# whether the project is in read-only mode +# If set to true, all editing tools will be disabled and attempts to use them will result in an error +# Added on 2025-04-18 +read_only: false + +# list of tool names to exclude. +# This extends the existing exclusions (e.g. from the global configuration) +# +# Below is the complete list of tools for convenience. +# To make sure you have the latest list of tools, and to view their descriptions, +# execute `uv run scripts/print_tool_overview.py`. 
+# +# * `activate_project`: Activates a project by name. +# * `check_onboarding_performed`: Checks whether project onboarding was already performed. +# * `create_text_file`: Creates/overwrites a file in the project directory. +# * `delete_lines`: Deletes a range of lines within a file. +# * `delete_memory`: Deletes a memory from Serena's project-specific memory store. +# * `execute_shell_command`: Executes a shell command. +# * `find_referencing_code_snippets`: Finds code snippets in which the symbol at the given location is referenced. +# * `find_referencing_symbols`: Finds symbols that reference the symbol at the given location (optionally filtered by type). +# * `find_symbol`: Performs a global (or local) search for symbols with/containing a given name/substring (optionally filtered by type). +# * `get_current_config`: Prints the current configuration of the agent, including the active and available projects, tools, contexts, and modes. +# * `get_symbols_overview`: Gets an overview of the top-level symbols defined in a given file. +# * `initial_instructions`: Gets the initial instructions for the current project. +# Should only be used in settings where the system prompt cannot be set, +# e.g. in clients you have no control over, like Claude Desktop. +# * `insert_after_symbol`: Inserts content after the end of the definition of a given symbol. +# * `insert_at_line`: Inserts content at a given line in a file. +# * `insert_before_symbol`: Inserts content before the beginning of the definition of a given symbol. +# * `list_dir`: Lists files and directories in the given directory (optionally with recursion). +# * `list_memories`: Lists memories in Serena's project-specific memory store. +# * `onboarding`: Performs onboarding (identifying the project structure and essential tasks, e.g. for testing or building). +# * `prepare_for_new_conversation`: Provides instructions for preparing for a new conversation (in order to continue with the necessary context). 
+# * `read_file`: Reads a file within the project directory. +# * `read_memory`: Reads the memory with the given name from Serena's project-specific memory store. +# * `remove_project`: Removes a project from the Serena configuration. +# * `replace_lines`: Replaces a range of lines within a file with new content. +# * `replace_symbol_body`: Replaces the full definition of a symbol. +# * `restart_language_server`: Restarts the language server, may be necessary when edits not through Serena happen. +# * `search_for_pattern`: Performs a search for a pattern in the project. +# * `summarize_changes`: Provides instructions for summarizing the changes made to the codebase. +# * `switch_modes`: Activates modes by providing a list of their names +# * `think_about_collected_information`: Thinking tool for pondering the completeness of collected information. +# * `think_about_task_adherence`: Thinking tool for determining whether the agent is still on track with the current task. +# * `think_about_whether_you_are_done`: Thinking tool for determining whether the task is truly completed. +# * `write_memory`: Writes a named memory (for future reference) to Serena's project-specific memory store. +excluded_tools: [] + +# list of tools to include that would otherwise be disabled (particularly optional tools that are disabled by default). +# This extends the existing inclusions (e.g. from the global configuration). +included_optional_tools: [] + +# fixed set of tools to use as the base tool set (if non-empty), replacing Serena's default set of tools. +# This cannot be combined with non-empty excluded_tools or included_optional_tools. +fixed_tools: [] + +# list of mode names that are always to be included in the set of active modes +# The full set of modes to be activated is base_modes + default_modes. +# If the setting is undefined, the base_modes from the global configuration (serena_config.yml) apply. +# Otherwise, this setting overrides the global configuration. 
+# Set this to [] to disable base modes for this project. +# Set this to a list of mode names to always include the respective modes for this project. +base_modes: + +# list of mode names that are to be activated by default. +# The full set of modes to be activated is base_modes + default_modes. +# If the setting is undefined, the default_modes from the global configuration (serena_config.yml) apply. +# Otherwise, this overrides the setting from the global configuration (serena_config.yml). +# This setting can, in turn, be overridden by CLI parameters (--mode). +default_modes: + +# initial prompt for the project. It will always be given to the LLM upon activating the project +# (contrary to the memories, which are loaded on demand). +initial_prompt: "" + +# time budget (seconds) per tool call for the retrieval of additional symbol information +# such as docstrings or parameter information. +# This overrides the corresponding setting in the global configuration; see the documentation there. +# If null or missing, use the setting from the global configuration. +symbol_info_budget: + +# list of regex patterns which, when matched, mark a memory entry as read-only. +# Extends the list from the global configuration, merging the two lists. +read_only_memory_patterns: [] + +# list of regex patterns for memories to completely ignore. +# Matching memories will not appear in list_memories or activate_project output +# and cannot be accessed via read_memory or write_memory. +# To access ignored memory files, use the read_file tool on the raw file path. +# Extends the list from the global configuration, merging the two lists. 
+# Example: ["_archive/.*", "_episodes/.*"] +ignored_memory_patterns: [] diff --git a/mini_agent/agent.py b/mini_agent/agent.py index b7d7feab..aa32a188 100644 --- a/mini_agent/agent.py +++ b/mini_agent/agent.py @@ -2,6 +2,7 @@ import asyncio import json +import sys from pathlib import Path from time import perf_counter from typing import Optional @@ -53,11 +54,13 @@ def __init__( max_steps: int = 50, workspace_dir: str = "./workspace", token_limit: int = 80000, # Summary triggered when tokens exceed this value + stream: bool = True, # Enable streaming by default ): self.llm = llm_client self.tools = {tool.name: tool for tool in tools} self.max_steps = max_steps self.token_limit = token_limit + self.stream = stream self.workspace_dir = Path(workspace_dir) # Cancellation event for interrupting agent execution (set externally, e.g., by Esc key) self.cancel_event: Optional[asyncio.Event] = None @@ -318,6 +321,325 @@ async def _create_summary(self, messages: list[Message], round_num: int) -> str: # Use simple text summary on failure return summary_content + async def _execute_tool_calls( + self, + tool_calls: list, + assistant_msg: Message, + ) -> str | None: + """Execute a list of tool calls and add results to message history. 
+ + Args: + tool_calls: List of ToolCall objects to execute + assistant_msg: The assistant message containing the tool calls + + Returns: + None if all tool calls executed successfully or error message if cancelled + """ + for tool_call in tool_calls: + tool_call_id = tool_call.id + function_name = tool_call.function.name + arguments = tool_call.function.arguments + + # Tool call header + print(f"\n{Colors.BRIGHT_YELLOW}πŸ”§ Tool Call:{Colors.RESET} {Colors.BOLD}{Colors.CYAN}{function_name}{Colors.RESET}") + + # Arguments (formatted display) + print(f"{Colors.DIM} Arguments:{Colors.RESET}") + # Truncate each argument value to avoid overly long output + truncated_args = {} + for key, value in arguments.items(): + value_str = str(value) + if len(value_str) > 200: + truncated_args[key] = value_str[:200] + "..." + else: + truncated_args[key] = value + args_json = json.dumps(truncated_args, indent=2, ensure_ascii=False) + for line in args_json.split("\n"): + print(f" {Colors.DIM}{line}{Colors.RESET}") + + # Execute tool + if function_name not in self.tools: + result = ToolResult( + success=False, + content="", + error=f"Unknown tool: {function_name}", + ) + else: + try: + tool = self.tools[function_name] + result = await tool.execute(**arguments) + except Exception as e: + # Catch all exceptions during tool execution, convert to failed ToolResult + import traceback + + error_detail = f"{type(e).__name__}: {str(e)}" + error_trace = traceback.format_exc() + result = ToolResult( + success=False, + content="", + error=f"Tool execution failed: {error_detail}\n\nTraceback:\n{error_trace}", + ) + + # Log tool execution result + self.logger.log_tool_result( + tool_name=function_name, + arguments=arguments, + result_success=result.success, + result_content=result.content if result.success else None, + result_error=result.error if not result.success else None, + ) + + # Print result + if result.success: + result_text = result.content + if len(result_text) > 300: + result_text = 
result_text[:300] + f"{Colors.DIM}...{Colors.RESET}" + print(f"{Colors.BRIGHT_GREEN}βœ“ Result:{Colors.RESET} {result_text}") + else: + print(f"{Colors.BRIGHT_RED}βœ— Error:{Colors.RESET} {Colors.RED}{result.error}{Colors.RESET}") + + # Add tool result message + tool_msg = Message( + role="tool", + content=result.content if result.success else f"Error: {result.error}", + tool_call_id=tool_call_id, + name=function_name, + ) + self.messages.append(tool_msg) + + # Check for cancellation after each tool execution + if self._check_cancelled(): + self._cleanup_incomplete_messages() + cancel_msg = "Task cancelled by user." + print(f"\n{Colors.BRIGHT_YELLOW}⚠️ {cancel_msg}{Colors.RESET}") + return cancel_msg + + return None + + + async def _run_step_stream( + self, + tool_list: list, + step_start_time: float, + run_start_time: float, + ) -> str | None: + """Run a single agent step using streaming. + + Streams LLM output in real-time, buffers tool calls until complete, + then executes them once the full response is received. 
+ + Args: + tool_list: List of available tools + step_start_time: Start time of this step + run_start_time: Start time of the entire run + + Returns: + None to continue to next step, or a string (content/cancel/error) to return + """ + from .retry import RetryExhaustedError + from .schema import FunctionCall, ToolCall + + # Buffers for accumulating response + thinking_content = "" + text_content = "" + tool_calls_buffer: dict[str, dict] = {} # id -> {name, arguments} + + finish_reason = "stop" + total_usage = None + + # Print header for streaming output + sys.stdout.write(f"\n{Colors.BOLD}{Colors.MAGENTA}🧠 Thinking:{Colors.RESET}\n") + sys.stdout.flush() + + try: + stream = await self.llm.generate_stream(messages=self.messages, tools=tool_list) + async for chunk in stream: + if chunk.type == "thinking": + # Print thinking as it arrives (write+flush for immediate display) + sys.stdout.write(chunk.text or "") + sys.stdout.flush() + thinking_content += chunk.text or "" + elif chunk.type == "content": + # Content arrives after thinking ends + sys.stdout.write(f"\n{Colors.BOLD}{Colors.BRIGHT_BLUE}πŸ€– Assistant:{Colors.RESET}\n") + sys.stdout.flush() + sys.stdout.write(chunk.text or "") + sys.stdout.flush() + text_content += chunk.text or "" + elif chunk.type == "tool_call_delta": + # Partial tool call - accumulate arguments + tid = chunk.tool_call_id + if tid not in tool_calls_buffer: + tool_calls_buffer[tid] = {"name": "", "arguments": ""} + tool_calls_buffer[tid]["arguments"] += chunk.arguments or "" + elif chunk.type == "tool_call_complete": + # Complete tool call received + tc = chunk.tool_call + tool_calls_buffer[tc.id] = { + "name": tc.function.name, + "arguments": tc.function.arguments, + } + elif chunk.type == "done": + finish_reason = chunk.finish_reason or "stop" + total_usage = chunk.usage + + except Exception as e: + if isinstance(e, RetryExhaustedError): + error_msg = f"LLM call failed after {e.attempts} retries\nLast error: {str(e.last_exception)}" + 
print(f"\n{Colors.BRIGHT_RED}❌ Retry failed:{Colors.RESET} {error_msg}") + else: + error_msg = f"LLM call failed: {str(e)}" + print(f"\n{Colors.BRIGHT_RED}❌ Error:{Colors.RESET} {error_msg}") + return error_msg + + # Accumulate token usage + if total_usage: + self.api_total_tokens = total_usage.total_tokens + + # Build final tool calls list + final_tool_calls = None + if tool_calls_buffer: + final_tool_calls = [ + ToolCall( + id=tid, + type="function", + function=FunctionCall( + name=data["name"], + arguments=data["arguments"], + ), + ) + for tid, data in tool_calls_buffer.items() + ] + + # Log LLM response + self.logger.log_response( + content=text_content, + thinking=thinking_content if thinking_content else None, + tool_calls=final_tool_calls, + finish_reason=finish_reason, + ) + + # Add assistant message to history + assistant_msg = Message( + role="assistant", + content=text_content, + thinking=thinking_content if thinking_content else None, + tool_calls=final_tool_calls, + ) + self.messages.append(assistant_msg) + + # Content was already streamed live - just add trailing newline + print() # End the streaming output line + + # Check if task is complete (no tool calls) + if not final_tool_calls: + step_elapsed = perf_counter() - step_start_time + total_elapsed = perf_counter() - run_start_time + print(f"\n{Colors.DIM}⏱️ Step completed in {step_elapsed:.2f}s (total: {total_elapsed:.2f}s){Colors.RESET}") + return text_content + + # Check for cancellation before executing tools + if self._check_cancelled(): + self._cleanup_incomplete_messages() + print(f"\n{Colors.BRIGHT_YELLOW}⚠️ Task cancelled by user.{Colors.RESET}") + return "Task cancelled by user." 
+ + # Execute tool calls + cancel_result = await self._execute_tool_calls(final_tool_calls, assistant_msg) + if cancel_result: + return cancel_result + + step_elapsed = perf_counter() - step_start_time + total_elapsed = perf_counter() - run_start_time + print(f"\n{Colors.DIM}⏱️ Step completed in {step_elapsed:.2f}s (total: {total_elapsed:.2f}s){Colors.RESET}") + + # Increment step counter and continue loop + return None + + async def _run_step_nonstream( + self, + tool_list: list, + step_start_time: float, + run_start_time: float, + ) -> str | None: + """Run a single agent step using non-streaming generate(). + + Args: + tool_list: List of available tools + step_start_time: Start time of this step + run_start_time: Start time of the entire run + + Returns: + None to continue to next step, or a string (content/cancel/error) to return + """ + from .retry import RetryExhaustedError + + try: + response = await self.llm.generate(messages=self.messages, tools=tool_list) + except Exception as e: + if isinstance(e, RetryExhaustedError): + error_msg = f"LLM call failed after {e.attempts} retries\nLast error: {str(e.last_exception)}" + print(f"\n{Colors.BRIGHT_RED}❌ Retry failed:{Colors.RESET} {error_msg}") + else: + error_msg = f"LLM call failed: {str(e)}" + print(f"\n{Colors.BRIGHT_RED}❌ Error:{Colors.RESET} {error_msg}") + return error_msg + + # Accumulate API reported token usage + if response.usage: + self.api_total_tokens = response.usage.total_tokens + + # Log LLM response + self.logger.log_response( + content=response.content, + thinking=response.thinking, + tool_calls=response.tool_calls, + finish_reason=response.finish_reason, + ) + + # Add assistant message + assistant_msg = Message( + role="assistant", + content=response.content, + thinking=response.thinking, + tool_calls=response.tool_calls, + ) + self.messages.append(assistant_msg) + + # Print thinking if present + if response.thinking: + print(f"\n{Colors.BOLD}{Colors.MAGENTA}🧠 Thinking:{Colors.RESET}") + 
print(f"{Colors.DIM}{response.thinking}{Colors.RESET}") + + # Print assistant response + if response.content: + print(f"\n{Colors.BOLD}{Colors.BRIGHT_BLUE}πŸ€– Assistant:{Colors.RESET}") + print(f"{response.content}") + + # Check if task is complete (no tool calls) + if not response.tool_calls: + step_elapsed = perf_counter() - step_start_time + total_elapsed = perf_counter() - run_start_time + print(f"\n{Colors.DIM}⏱️ Step completed in {step_elapsed:.2f}s (total: {total_elapsed:.2f}s){Colors.RESET}") + return response.content + + # Check for cancellation before executing tools + if self._check_cancelled(): + self._cleanup_incomplete_messages() + print(f"\n{Colors.BRIGHT_YELLOW}⚠️ Task cancelled by user.{Colors.RESET}") + return "Task cancelled by user." + + # Execute tool calls + cancel_result = await self._execute_tool_calls(response.tool_calls, assistant_msg) + if cancel_result: + return cancel_result + + step_elapsed = perf_counter() - step_start_time + total_elapsed = perf_counter() - run_start_time + print(f"\n{Colors.DIM}⏱️ Step completed in {step_elapsed:.2f}s (total: {total_elapsed:.2f}s){Colors.RESET}") + + return None + async def run(self, cancel_event: Optional[asyncio.Event] = None) -> str: """Execute agent loop until task is complete or max steps reached. 
@@ -368,150 +690,27 @@ async def run(self, cancel_event: Optional[asyncio.Event] = None) -> str: # Log LLM request and call LLM with Tool objects directly self.logger.log_request(messages=self.messages, tools=tool_list) - try: - response = await self.llm.generate(messages=self.messages, tools=tool_list) - except Exception as e: - # Check if it's a retry exhausted error - from .retry import RetryExhaustedError - - if isinstance(e, RetryExhaustedError): - error_msg = f"LLM call failed after {e.attempts} retries\nLast error: {str(e.last_exception)}" - print(f"\n{Colors.BRIGHT_RED}❌ Retry failed:{Colors.RESET} {error_msg}") - else: - error_msg = f"LLM call failed: {str(e)}" - print(f"\n{Colors.BRIGHT_RED}❌ Error:{Colors.RESET} {error_msg}") - return error_msg - - # Accumulate API reported token usage - if response.usage: - self.api_total_tokens = response.usage.total_tokens - - # Log LLM response - self.logger.log_response( - content=response.content, - thinking=response.thinking, - tool_calls=response.tool_calls, - finish_reason=response.finish_reason, - ) - - # Add assistant message - assistant_msg = Message( - role="assistant", - content=response.content, - thinking=response.thinking, - tool_calls=response.tool_calls, - ) - self.messages.append(assistant_msg) - - # Print thinking if present - if response.thinking: - print(f"\n{Colors.BOLD}{Colors.MAGENTA}🧠 Thinking:{Colors.RESET}") - print(f"{Colors.DIM}{response.thinking}{Colors.RESET}") - - # Print assistant response - if response.content: - print(f"\n{Colors.BOLD}{Colors.BRIGHT_BLUE}πŸ€– Assistant:{Colors.RESET}") - print(f"{response.content}") - - # Check if task is complete (no tool calls) - if not response.tool_calls: - step_elapsed = perf_counter() - step_start_time - total_elapsed = perf_counter() - run_start_time - print(f"\n{Colors.DIM}⏱️ Step {step + 1} completed in {step_elapsed:.2f}s (total: {total_elapsed:.2f}s){Colors.RESET}") - return response.content - - # Check for cancellation before executing 
tools - if self._check_cancelled(): - self._cleanup_incomplete_messages() - cancel_msg = "Task cancelled by user." - print(f"\n{Colors.BRIGHT_YELLOW}⚠️ {cancel_msg}{Colors.RESET}") - return cancel_msg - - # Execute tool calls - for tool_call in response.tool_calls: - tool_call_id = tool_call.id - function_name = tool_call.function.name - arguments = tool_call.function.arguments - - # Tool call header - print(f"\n{Colors.BRIGHT_YELLOW}πŸ”§ Tool Call:{Colors.RESET} {Colors.BOLD}{Colors.CYAN}{function_name}{Colors.RESET}") - - # Arguments (formatted display) - print(f"{Colors.DIM} Arguments:{Colors.RESET}") - # Truncate each argument value to avoid overly long output - truncated_args = {} - for key, value in arguments.items(): - value_str = str(value) - if len(value_str) > 200: - truncated_args[key] = value_str[:200] + "..." - else: - truncated_args[key] = value - args_json = json.dumps(truncated_args, indent=2, ensure_ascii=False) - for line in args_json.split("\n"): - print(f" {Colors.DIM}{line}{Colors.RESET}") - - # Execute tool - if function_name not in self.tools: - result = ToolResult( - success=False, - content="", - error=f"Unknown tool: {function_name}", - ) - else: - try: - tool = self.tools[function_name] - result = await tool.execute(**arguments) - except Exception as e: - # Catch all exceptions during tool execution, convert to failed ToolResult - import traceback - - error_detail = f"{type(e).__name__}: {str(e)}" - error_trace = traceback.format_exc() - result = ToolResult( - success=False, - content="", - error=f"Tool execution failed: {error_detail}\n\nTraceback:\n{error_trace}", - ) - - # Log tool execution result - self.logger.log_tool_result( - tool_name=function_name, - arguments=arguments, - result_success=result.success, - result_content=result.content if result.success else None, - result_error=result.error if not result.success else None, + if self.stream: + # Use streaming for real-time token output + result = await self._run_step_stream( + 
tool_list=tool_list, + step_start_time=step_start_time, + run_start_time=run_start_time, ) - - # Print result - if result.success: - result_text = result.content - if len(result_text) > 300: - result_text = result_text[:300] + f"{Colors.DIM}...{Colors.RESET}" - print(f"{Colors.BRIGHT_GREEN}βœ“ Result:{Colors.RESET} {result_text}") - else: - print(f"{Colors.BRIGHT_RED}βœ— Error:{Colors.RESET} {Colors.RED}{result.error}{Colors.RESET}") - - # Add tool result message - tool_msg = Message( - role="tool", - content=result.content if result.success else f"Error: {result.error}", - tool_call_id=tool_call_id, - name=function_name, + # _run_step_stream handles step increment internally + # Return if task complete, cancelled, or errored + if result is not None: + return result + else: + # Use non-streaming generate() call + result = await self._run_step_nonstream( + tool_list=tool_list, + step_start_time=step_start_time, + run_start_time=run_start_time, ) - self.messages.append(tool_msg) - - # Check for cancellation after each tool execution - if self._check_cancelled(): - self._cleanup_incomplete_messages() - cancel_msg = "Task cancelled by user." - print(f"\n{Colors.BRIGHT_YELLOW}⚠️ {cancel_msg}{Colors.RESET}") - return cancel_msg - - step_elapsed = perf_counter() - step_start_time - total_elapsed = perf_counter() - run_start_time - print(f"\n{Colors.DIM}⏱️ Step {step + 1} completed in {step_elapsed:.2f}s (total: {total_elapsed:.2f}s){Colors.RESET}") - - step += 1 + if result is not None: + return result + step += 1 # Max steps reached error_msg = f"Task couldn't be completed after {self.max_steps} steps." 
diff --git a/mini_agent/cli.py b/mini_agent/cli.py index f060c9c2..b8fc084d 100644 --- a/mini_agent/cli.py +++ b/mini_agent/cli.py @@ -37,8 +37,14 @@ from mini_agent.tools.mcp_loader import cleanup_mcp_connections, load_mcp_tools_async, set_mcp_timeout_config from mini_agent.tools.note_tool import SessionNoteTool from mini_agent.tools.skill_tool import create_skill_tools +from mini_agent.tools.subagent_tool import SubAgentTool from mini_agent.utils import calculate_display_width +# Force unbuffered stdout for real-time streaming output +# Must be done before any print statements +sys.stdout.reconfigure(line_buffering=False) +sys.stderr.reconfigure(line_buffering=False) + # ANSI color codes class Colors: @@ -313,6 +319,12 @@ def parse_args() -> argparse.Namespace: default=None, help="Execute a task non-interactively and exit", ) + parser.add_argument( + "--no-stream", + action="store_true", + default=False, + help="Disable streaming output (use non-streaming generate)", + ) parser.add_argument( "--version", "-v", @@ -483,7 +495,7 @@ async def _quiet_cleanup(): pass -async def run_agent(workspace_dir: Path, task: str = None): +async def run_agent(workspace_dir: Path, task: str = None, stream: bool = True): """Run Agent in interactive or non-interactive mode. Args: @@ -577,6 +589,12 @@ def on_retry(exception: Exception, attempt: int): # 4. Add workspace-dependent tools add_workspace_tools(tools, config, workspace_dir) + # 4.5. Create SubAgent tool and bind it to the agent after creation + subagent_tool = SubAgentTool( + llm_client=llm_client, + workspace_dir=str(workspace_dir), + ) + # 5. Load System Prompt (with priority search) system_prompt_path = Config.find_config_file(config.agent.system_prompt_path) if system_prompt_path and system_prompt_path.exists(): @@ -601,13 +619,17 @@ def on_retry(exception: Exception, attempt: int): system_prompt = system_prompt.replace("{SKILLS_METADATA}", "") # 7. 
Create Agent + tools.append(subagent_tool) agent = Agent( llm_client=llm_client, system_prompt=system_prompt, tools=tools, max_steps=config.agent.max_steps, workspace_dir=str(workspace_dir), + stream=stream, ) + # Bind subagent tool to the agent so it can access the parent's tool set + subagent_tool.bind_agent(agent) # 8. Display welcome information if not task: @@ -866,7 +888,7 @@ def main(): workspace_dir.mkdir(parents=True, exist_ok=True) # Run the agent (config always loaded from package directory) - asyncio.run(run_agent(workspace_dir, task=args.task)) + asyncio.run(run_agent(workspace_dir, task=args.task, stream=not args.no_stream)) if __name__ == "__main__": diff --git a/mini_agent/config/system_prompt.md b/mini_agent/config/system_prompt.md index 97d3843d..4a8522e1 100644 --- a/mini_agent/config/system_prompt.md +++ b/mini_agent/config/system_prompt.md @@ -5,6 +5,7 @@ You are Mini-Agent, a versatile AI assistant powered by MiniMax, capable of exec ### 1. **Basic Tools** - **File Operations**: Read, write, edit files with full path support - **Bash Execution**: Run commands, manage git, packages, and system operations +- **SubAgent**: Use the `subagent` tool to spawn child agents for independent subtasks. When a task can be broken into parallel independent parts, call `subagent` for each one rather than doing them sequentially. Pass specific tool names via `tool_names` to delegate tools to the child agent. - **MCP Tools**: Access additional tools from configured MCP servers ### 2. 
**Specialized Skills** diff --git a/mini_agent/llm/anthropic_client.py b/mini_agent/llm/anthropic_client.py index 6baf9940..0506f3f7 100644 --- a/mini_agent/llm/anthropic_client.py +++ b/mini_agent/llm/anthropic_client.py @@ -1,12 +1,13 @@ """Anthropic LLM client implementation.""" +import json import logging -from typing import Any +from typing import Any, AsyncIterator import anthropic from ..retry import RetryConfig, async_retry -from ..schema import FunctionCall, LLMResponse, Message, TokenUsage, ToolCall +from ..schema import FunctionCall, LLMResponse, Message, StreamChunk, TokenUsage, ToolCall from .base import LLMClientBase logger = logging.getLogger(__name__) @@ -291,3 +292,116 @@ async def generate( # Parse and return response return self._parse_response(response) + + async def generate_stream( + self, + messages: list[Message], + tools: list[Any] | None = None, + ) -> AsyncIterator[StreamChunk]: + """Stream LLM response as async iterator of chunks. + + Args: + messages: List of conversation messages + tools: Optional list of available tools + + Yields: + StreamChunk objects representing partial response + """ + system_message, api_messages = self._convert_messages(messages) + params: dict[str, Any] = { + "model": self.model, + "max_tokens": 16384, + "messages": api_messages, + } + + if system_message: + params["system"] = system_message + if tools: + params["tools"] = self._convert_tools(tools) + + # Buffer for partial tool calls (indexed by block index). 
+ # Note: MiniMax sends: + # - Top-level "text" / "thinking" events for content (complete, not delta) + # - "input_json" events for tool argument streaming + # - "content_block_stop" when a block (including tool_use) ends + tool_call_buffer: dict[int, dict[str, Any]] = {} + + async with self.client.messages.stream(**params) as stream: + async for event in stream: + # Top-level text/thinking β€” complete content, yield immediately + if event.type == "text": + yield StreamChunk(type="content", text=event.text) + continue + elif event.type == "thinking": + yield StreamChunk(type="thinking", text=event.thinking) + continue + elif event.type == "signature": + continue + + # Handle content_block_start β€” buffer setup for tool_use + if hasattr(event, "content_block") and event.type == "content_block_start": + block = event.content_block + idx = event.index + if block.type == "tool_use": + tool_call_buffer[idx] = { + "type": "tool_use", + "id": block.id, + "name": block.name, + "input": "", + } + continue + + # Handle input_json β€” streaming JSON arguments for tool_use blocks + # MiniMax API uses input_json events with partial_json field for args + if event.type == "input_json": + if not tool_call_buffer: + continue + idx = max(tool_call_buffer.keys()) + # partial_json contains the accumulated JSON string so far + partial = getattr(event, "partial_json", "") or "" + if partial: + tool_call_buffer[idx]["input"] = partial + try: + parsed = json.loads(partial) if partial else {} + yield StreamChunk( + type="tool_call_complete", + tool_call=ToolCall( + id=tool_call_buffer[idx]["id"], + type="function", + function=FunctionCall( + name=tool_call_buffer[idx]["name"], + arguments=parsed, + ), + ), + ) + del tool_call_buffer[idx] + except json.JSONDecodeError: + yield StreamChunk( + type="tool_call_delta", + tool_call_id=tool_call_buffer[idx]["id"], + arguments=partial, + ) + continue + + # Handle content_block_stop β€” clean up any remaining incomplete tool calls + if event.type 
== "content_block_stop": + # Clear buffer for blocks that weren't completed via input_json + # (they should have been completed already, but clean up just in case) + tool_call_buffer.clear() + continue + + # Handle message_delta (final) + elif event.type == "message_delta": + if hasattr(event, "usage") and event.usage: + output_tokens = event.usage.output_tokens or 0 + yield StreamChunk( + type="done", + finish_reason=getattr(event, "stop_reason", None) or "stop", + usage=TokenUsage( + prompt_tokens=0, + completion_tokens=output_tokens, + total_tokens=output_tokens, + ), + ) + elif hasattr(event, "stop_reason"): + yield StreamChunk(type="done", finish_reason=getattr(event, "stop_reason", None) or "stop") diff --git a/mini_agent/llm/base.py b/mini_agent/llm/base.py index 19892a83..7c961290 100644 --- a/mini_agent/llm/base.py +++ b/mini_agent/llm/base.py @@ -1,10 +1,10 @@ """Base class for LLM clients.""" from abc import ABC, abstractmethod -from typing import Any +from typing import Any, AsyncIterator from ..retry import RetryConfig -from ..schema import LLMResponse, Message +from ..schema import LLMResponse, Message, StreamChunk class LLMClientBase(ABC): @@ -54,6 +54,23 @@ async def generate( """ pass + @abstractmethod + async def generate_stream( + self, + messages: list[Message], + tools: list[Any] | None = None, + ) -> AsyncIterator[StreamChunk]: + """Stream LLM response as async iterator of chunks. 
+ + Args: + messages: List of conversation messages + tools: Optional list of available tools + + Yields: + StreamChunk objects representing partial response + """ + pass + @abstractmethod def _prepare_request( self, diff --git a/mini_agent/llm/llm_wrapper.py b/mini_agent/llm/llm_wrapper.py index 28d2c8b7..fea9f766 100644 --- a/mini_agent/llm/llm_wrapper.py +++ b/mini_agent/llm/llm_wrapper.py @@ -5,9 +5,10 @@ """ import logging +from typing import AsyncIterator from ..retry import RetryConfig -from ..schema import LLMProvider, LLMResponse, Message +from ..schema import LLMProvider, LLMResponse, Message, StreamChunk from .anthropic_client import AnthropicClient from .base import LLMClientBase from .openai_client import OpenAIClient @@ -125,3 +126,19 @@ async def generate( LLMResponse containing the generated content """ return await self._client.generate(messages, tools) + + async def generate_stream( + self, + messages: list[Message], + tools: list | None = None, + ) -> AsyncIterator[StreamChunk]: + """Stream LLM response as async iterator of chunks. 
+ + Args: + messages: List of conversation messages + tools: Optional list of Tool objects or dicts + + Yields: + StreamChunk objects representing partial response + """ + return self._client.generate_stream(messages, tools) diff --git a/mini_agent/llm/openai_client.py b/mini_agent/llm/openai_client.py index a30fc197..63792656 100644 --- a/mini_agent/llm/openai_client.py +++ b/mini_agent/llm/openai_client.py @@ -2,12 +2,12 @@ import json import logging -from typing import Any +from typing import Any, AsyncIterator from openai import AsyncOpenAI from ..retry import RetryConfig, async_retry -from ..schema import FunctionCall, LLMResponse, Message, TokenUsage, ToolCall +from ..schema import FunctionCall, LLMResponse, Message, StreamChunk, TokenUsage, ToolCall from .base import LLMClientBase logger = logging.getLogger(__name__) @@ -293,3 +293,122 @@ async def generate( # Parse and return response return self._parse_response(response) + + async def generate_stream( + self, + messages: list[Message], + tools: list[Any] | None = None, + ) -> AsyncIterator[StreamChunk]: + """Stream LLM response as async iterator of chunks. 
+ + Args: + messages: List of conversation messages + tools: Optional list of available tools + + Yields: + StreamChunk objects representing partial response + """ + _, api_messages = self._convert_messages(messages) + params: dict[str, Any] = { + "model": self.model, + "messages": api_messages, + "extra_body": {"reasoning_split": True}, + "stream": True, + } + + if tools: + params["tools"] = self._convert_tools(tools) + + # Buffer for partial tool calls indexed by tool_call index + tool_call_buffer: dict[int, dict[str, Any]] = {} + + # Accumulate usage across all chunks + total_usage = None + + stream = await self.client.chat.completions.create(**params) + async for event in stream: + chunk = event.choices[0] + + # Accumulate usage if present + if hasattr(event, "usage") and event.usage: + total_usage = TokenUsage( + prompt_tokens=event.usage.prompt_tokens or 0, + completion_tokens=event.usage.completion_tokens or 0, + total_tokens=event.usage.total_tokens or 0, + ) + + delta = chunk.delta + + # Check for content (text or thinking) + if hasattr(delta, "content") and delta.content: + # Check if this is thinking content + # MiniMax uses reasoning_details for thinking + if hasattr(delta, "reasoning_details") and delta.reasoning_details: + for rd in delta.reasoning_details: + if hasattr(rd, "text") and rd.text: + yield StreamChunk(type="thinking", text=rd.text) + # Regular content + yield StreamChunk(type="content", text=delta.content) + + # Check for reasoning/thinking content specifically + if hasattr(delta, "reasoning_details") and delta.reasoning_details: + for rd in delta.reasoning_details: + if hasattr(rd, "text") and rd.text: + yield StreamChunk(type="thinking", text=rd.text) + + # Check for tool calls + if hasattr(delta, "tool_calls") and delta.tool_calls: + for tc in delta.tool_calls: + idx = tc.index + if idx not in tool_call_buffer: + tool_call_buffer[idx] = { + "id": "", + "name": "", + "arguments": "", + } + + if tc.id: + tool_call_buffer[idx]["id"] = 
tc.id + if hasattr(tc, "function"): + if tc.function.name: + tool_call_buffer[idx]["name"] = tc.function.name + if tc.function.arguments: + tool_call_buffer[idx]["arguments"] += tc.function.arguments + + # Check if this tool call is complete (has id and name and empty arguments suffix) + # OpenAI streams arguments as partial JSON, we check for completion + # by seeing if arguments are complete (ends with appropriate JSON terminator) + args_str = tool_call_buffer[idx]["arguments"] + if tool_call_buffer[idx]["id"] and tool_call_buffer[idx]["name"] and args_str: + # Try to parse as JSON to see if complete + try: + json.loads(args_str) + # Successfully parsed β€” tool call is complete + yield StreamChunk( + type="tool_call_complete", + tool_call=ToolCall( + id=tool_call_buffer[idx]["id"], + type="function", + function=FunctionCall( + name=tool_call_buffer[idx]["name"], + arguments=json.loads(args_str), + ), + ), + ) + del tool_call_buffer[idx] + except json.JSONDecodeError: + # Incomplete JSON β€” emit delta + yield StreamChunk( + type="tool_call_delta", + tool_call_id=tool_call_buffer[idx]["id"], + arguments=args_str, + ) + + # Check for completion + finish = getattr(chunk, "finish_reason", None) + if finish: + yield StreamChunk( + type="done", + finish_reason=finish, + usage=total_usage, + ) diff --git a/mini_agent/schema/__init__.py b/mini_agent/schema/__init__.py index e4dc1f01..66ab4b51 100644 --- a/mini_agent/schema/__init__.py +++ b/mini_agent/schema/__init__.py @@ -5,6 +5,7 @@ LLMProvider, LLMResponse, Message, + StreamChunk, TokenUsage, ToolCall, ) @@ -14,6 +15,7 @@ "LLMProvider", "LLMResponse", "Message", + "StreamChunk", "TokenUsage", "ToolCall", ] diff --git a/mini_agent/schema/schema.py b/mini_agent/schema/schema.py index 4bffb442..1528f560 100644 --- a/mini_agent/schema/schema.py +++ b/mini_agent/schema/schema.py @@ -53,3 +53,16 @@ class LLMResponse(BaseModel): tool_calls: list[ToolCall] | None = None finish_reason: str usage: TokenUsage | None = None # 
Token usage from API response + + +class StreamChunk(BaseModel): + """A single chunk from a streaming LLM response.""" + + type: str # "thinking" | "content" | "tool_call_start" | "tool_call_delta" | "tool_call_complete" | "done" + text: str | None = None # For thinking / content chunks + tool_call_id: str | None = None # For tool call chunks + tool_name: str | None = None # For tool_call_start + arguments: str | None = None # For tool_call_delta (partial JSON string fragment) + tool_call: ToolCall | None = None # For tool_call_complete (full tool call) + finish_reason: str | None = None # For done + usage: TokenUsage | None = None # For done diff --git a/mini_agent/tools/__init__.py b/mini_agent/tools/__init__.py index 4db00500..11b537a8 100644 --- a/mini_agent/tools/__init__.py +++ b/mini_agent/tools/__init__.py @@ -4,6 +4,7 @@ from .bash_tool import BashTool from .file_tools import EditTool, ReadTool, WriteTool from .note_tool import RecallNoteTool, SessionNoteTool +from .subagent_tool import SubAgentTool __all__ = [ "Tool", @@ -14,4 +15,5 @@ "BashTool", "SessionNoteTool", "RecallNoteTool", + "SubAgentTool", ] diff --git a/mini_agent/tools/subagent_tool.py b/mini_agent/tools/subagent_tool.py new file mode 100644 index 00000000..e8ff5bb3 --- /dev/null +++ b/mini_agent/tools/subagent_tool.py @@ -0,0 +1,174 @@ +"""SubAgent Tool - Spawn a child agent to handle a subtask independently.""" + +import asyncio +from pathlib import Path +from typing import Any, Callable, Optional + +from .base import Tool, ToolResult + + +class SubAgentTool(Tool): + """Tool for spawning a child agent to handle a subtask. + + The child agent operates in an isolated message history and can use + a subset of the parent's tools. Results are returned as a ToolResult + once the child completes. + + Use this when a task can be broken into independent parts that benefit + from separate reasoning chains, or when you want to delegate a complex + subtask to a fresh agent context. 
+ + Example: + subagent( + task="Research the history of the Roman Empire", + tool_names=["bash", "read", "write"], + max_steps=30 + ) + """ + + def __init__( + self, + llm_client: "LLMClient", + system_prompt: str = "You are a helpful assistant focused on completing tasks thoroughly.", + default_max_steps: int = 20, + workspace_dir: str = "./workspace", + ): + """Initialize SubAgentTool. + + Args: + llm_client: LLMClient instance to use for the child agent. + system_prompt: Default system prompt for child agents. + default_max_steps: Default max steps for child agents if not specified. + workspace_dir: Workspace directory for child agents. + """ + self._llm = llm_client + self._system_prompt = system_prompt + self._default_max_steps = default_max_steps + self._workspace_dir = workspace_dir + # The agent reference is set after the tool is registered with an agent. + # This uses a MutableContainer so the reference can be injected post-creation. + self._agent_ref: Optional["Agent"] = None + + def bind_agent(self, agent: "Agent") -> None: + """Bind this tool to an agent instance. + + Must be called after the agent is created and before the tool is used. + Allows the tool to access the agent's current tool set at execute time. + """ + self._agent_ref = agent + + @property + def name(self) -> str: + return "subagent" + + @property + def description(self) -> str: + return ( + "Spawn a child agent to handle a subtask independently. " + "Use when a task can be broken into parallel or independent parts, " + "or when you want a fresh reasoning context for a complex subtask. " + "The child agent has its own isolated message history and optionally " + "a restricted set of tools. Returns the child's final response." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "task": { + "type": "string", + "description": ( + "The task, question, or instruction to give to the child agent. " + "Be specific about what you need." 
+ ), + }, + "system_prompt": { + "type": "string", + "description": ( + "Optional system prompt override for this child agent. " + "Use this to give the child agent specific instructions or context." + ), + }, + "tool_names": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "List of tool names to delegate to the child agent. " + "If not provided, the child agent has no external tools. " + "Available tools in the parent: bash, read, write, edit, " + "session_note, recall_notes, and any loaded MCP/skill tools." + ), + }, + "max_steps": { + "type": "integer", + "description": ( + f"Maximum steps for the child agent (default: {self._default_max_steps}). " + "Increase for complex tasks that need more reasoning steps." + ), + }, + }, + "required": ["task"], + } + + async def execute( + self, + task: str, + system_prompt: Optional[str] = None, + tool_names: Optional[list[str]] = None, + max_steps: Optional[int] = None, + ) -> ToolResult: + from mini_agent.agent import Agent + + if self._agent_ref is None: + return ToolResult( + success=False, + content="", + error="SubAgentTool: agent not bound. Call bind_agent(agent) first.", + ) + + # Resolve tools from the parent's current tool set + available = list(self._agent_ref.tools.values()) + available_names = {t.name for t in available} + + if tool_names is not None: + child_tool_list = [t for t in available if t.name in tool_names] + missing = set(tool_names) - available_names + if missing: + return ToolResult( + success=False, + content="", + error=f"Tool(s) not available in parent agent: {', '.join(sorted(missing))}. 
" + f"Available: {', '.join(sorted(available_names))}", + ) + else: + child_tool_list = [] + + # Build system prompt + prompt = system_prompt or self._system_prompt + if "Current Workspace" not in prompt: + workspace = Path(self._workspace_dir).resolve() + prompt = ( + f"{prompt}\n\n" + f"## Current Workspace\n" + f"You are working in: `{workspace}`\n" + f"All relative paths are resolved relative to this directory." + ) + + # Create child agent with isolated history + child = Agent( + llm_client=self._llm, + system_prompt=prompt, + tools=child_tool_list, + max_steps=max_steps or self._default_max_steps, + workspace_dir=str(Path(self._workspace_dir).resolve()), + stream=False, + ) + child.add_user_message(task) + + # Run child agent + try: + result = await child.run() + return ToolResult(success=True, content=result) + except Exception as exc: + return ToolResult(success=False, content="", error=f"SubAgent error: {exc}") diff --git a/minimax-skills b/minimax-skills new file mode 160000 index 00000000..77f30690 --- /dev/null +++ b/minimax-skills @@ -0,0 +1 @@ +Subproject commit 77f306906afe584a03751b959e477b5a125fb31f diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 00000000..ac4c0fca --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,400 @@ +"""Tests for LLM streaming functionality.""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from mini_agent.llm.anthropic_client import AnthropicClient +from mini_agent.llm.openai_client import OpenAIClient +from mini_agent.llm.llm_wrapper import LLMClient +from mini_agent.schema import ( + FunctionCall, + LLMProvider, + LLMResponse, + Message, + StreamChunk, + ToolCall, + TokenUsage, +) + + +class TestStreamChunk: + """Test StreamChunk schema.""" + + def test_stream_chunk_content(self): + """Test content chunk creation.""" + chunk = StreamChunk(type="content", text="Hello") + assert chunk.type == "content" + assert chunk.text 
== "Hello" + + def test_stream_chunk_thinking(self): + """Test thinking chunk creation.""" + chunk = StreamChunk(type="thinking", text="Let me think...") + assert chunk.type == "thinking" + assert chunk.text == "Let me think..." + + def test_stream_chunk_tool_call_delta(self): + """Test tool call delta chunk.""" + chunk = StreamChunk( + type="tool_call_delta", + tool_call_id="abc123", + arguments='{"name":', + ) + assert chunk.type == "tool_call_delta" + assert chunk.tool_call_id == "abc123" + assert chunk.arguments == '{"name":' + + def test_stream_chunk_tool_call_complete(self): + """Test complete tool call chunk.""" + tool_call = ToolCall( + id="abc123", + type="function", + function=FunctionCall(name="test_tool", arguments={"arg": "value"}), + ) + chunk = StreamChunk(type="tool_call_complete", tool_call=tool_call) + assert chunk.type == "tool_call_complete" + assert chunk.tool_call.id == "abc123" + assert chunk.tool_call.function.name == "test_tool" + + def test_stream_chunk_done(self): + """Test done chunk.""" + usage = TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150) + chunk = StreamChunk(type="done", finish_reason="stop", usage=usage) + assert chunk.type == "done" + assert chunk.finish_reason == "stop" + assert chunk.usage.total_tokens == 150 + + +class TestOpenAIStreaming: + """Test OpenAI client streaming.""" + + @pytest.fixture + def openai_client(self): + """Create OpenAI client for testing.""" + client = OpenAIClient( + api_key="test-key", + api_base="https://test.api.minimaxi.com/v1", + model="MiniMax-M2.5", + retry_config=None, + ) + return client + + @pytest.mark.asyncio + async def test_generate_stream_content_only(self, openai_client): + """Test streaming content without tool calls.""" + # Create a proper async iterator + class MockEvent: + """Single mock streaming event.""" + def __init__(self, content_text, finish=None): + class MockDelta: + content = content_text + reasoning_details = None + class MockChoice: + delta = 
MockDelta() + finish_reason = finish + self.choices = [MockChoice()] + self.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15) + + class AsyncEventStream: + """Mock async iterator for OpenAI streaming.""" + def __init__(self, events): + self.events = events + self.index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index >= len(self.events): + raise StopAsyncIteration + event = self.events[self.index] + self.index += 1 + return event + + events = [ + MockEvent("Hello"), + MockEvent(", world!"), + MockEvent("", finish="stop"), + ] + + async def mock_create(**kwargs): + return AsyncEventStream(events) + + openai_client.client = MagicMock() + openai_client.client.chat.completions.create = mock_create + + messages = [Message(role="user", content="Say hello")] + chunks = [] + async for chunk in openai_client.generate_stream(messages): + chunks.append(chunk) + + assert len(chunks) >= 2 + content_chunks = [c for c in chunks if c.type == "content"] + assert len(content_chunks) >= 2 + + @pytest.mark.asyncio + async def test_generate_stream_with_thinking(self, openai_client): + """Test streaming with thinking content.""" + class MockThinkingEvent: + def __init__(self, think_text): + class MockDelta: + content = None + reasoning_details = [MagicMock(text=think_text)] + class MockChoice: + delta = MockDelta() + finish_reason = None + self.choices = [MockChoice()] + self.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15) + + class MockContentEvent: + def __init__(self, content_text): + class MockDelta: + content = content_text + reasoning_details = None + class MockChoice: + delta = MockDelta() + finish_reason = None + self.choices = [MockChoice()] + self.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15) + + class AsyncEventStream: + def __init__(self, events): + self.events = events + self.index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + 
if self.index >= len(self.events): + raise StopAsyncIteration + event = self.events[self.index] + self.index += 1 + return event + + events = [ + MockThinkingEvent("Let me think..."), + MockContentEvent("Here's my answer"), + ] + + async def mock_create(**kwargs): + return AsyncEventStream(events) + + openai_client.client = MagicMock() + openai_client.client.chat.completions.create = mock_create + + messages = [Message(role="user", content="Think about something")] + chunks = [] + async for chunk in openai_client.generate_stream(messages): + chunks.append(chunk) + + thinking_chunks = [c for c in chunks if c.type == "thinking"] + assert len(thinking_chunks) >= 1 + + @pytest.mark.asyncio + async def test_generate_stream_tool_call_complete(self, openai_client): + """Test that tool call is emitted when JSON is complete.""" + class MockToolEvent: + """Mock event with tool call.""" + def __init__(self, args_str, has_id=False, has_name=False, finish=None): + class MockFunction: + name = "test_tool" if has_name else "" + arguments = args_str + + class MockToolCall: + index = 0 + id = "call_123" if has_id else "" + function = MockFunction() + + class MockDelta: + content = None + reasoning_details = None + tool_calls = [MockToolCall()] if args_str else None + + class MockChoice: + delta = MockDelta() + finish_reason = finish + + self.choices = [MockChoice()] + self.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15) + + class AsyncEventStream: + def __init__(self, events): + self.events = events + self.index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index >= len(self.events): + raise StopAsyncIteration + event = self.events[self.index] + self.index += 1 + return event + + # Tool call with complete arguments + events = [ + MockToolEvent('{"arg": "value"}', has_id=True, has_name=True, finish="tool_calls"), + ] + + async def mock_create(**kwargs): + return AsyncEventStream(events) + + openai_client.client = 
MagicMock() + openai_client.client.chat.completions.create = mock_create + + messages = [Message(role="user", content="Use test tool")] + chunks = [] + async for chunk in openai_client.generate_stream(messages): + chunks.append(chunk) + + # Should have a complete tool call + complete_chunks = [c for c in chunks if c.type == "tool_call_complete"] + assert len(complete_chunks) == 1 + assert complete_chunks[0].tool_call.function.name == "test_tool" + assert complete_chunks[0].tool_call.function.arguments == {"arg": "value"} + + +class TestAnthropicStreaming: + """Test Anthropic client streaming.""" + + @pytest.fixture + def anthropic_client(self): + """Create Anthropic client for testing.""" + client = AnthropicClient( + api_key="test-key", + api_base="https://test.api.minimaxi.com/anthropic", + model="MiniMax-M2.5", + retry_config=None, + ) + return client + + @pytest.mark.asyncio + async def test_generate_stream_content_only(self, anthropic_client): + """Test streaming content without tool calls.""" + # Real SDK sends type="text" as top-level event for text content + class MockEvent: + def __init__(self, text): + self.type = "text" + self.text = text + + class AsyncEventStream: + def __init__(self, events): + self.events = events + self.index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index >= len(self.events): + raise StopAsyncIteration + event = self.events[self.index] + self.index += 1 + return event + + events = [ + MockEvent("Hello"), + MockEvent(", world!"), + ] + + mock_client = AsyncMock() + mock_stream_ctx = MagicMock() + mock_stream_ctx.__aenter__ = AsyncMock(return_value=AsyncEventStream(events)) + mock_stream_ctx.__aexit__ = AsyncMock() + mock_client.messages.stream = MagicMock(return_value=mock_stream_ctx) + anthropic_client.client = mock_client + + messages = [Message(role="user", content="Say hello")] + chunks = [] + async for chunk in anthropic_client.generate_stream(messages): + chunks.append(chunk) + + 
content_chunks = [c for c in chunks if c.type == "content"] + assert len(content_chunks) >= 2 + + @pytest.mark.asyncio + async def test_generate_stream_thinking(self, anthropic_client): + """Test streaming with thinking content.""" + # Real SDK sends type="thinking" as top-level event + class MockEvent: + def __init__(self, think_text): + self.type = "thinking" + self.thinking = think_text + + class AsyncEventStream: + def __init__(self, events): + self.events = events + self.index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index >= len(self.events): + raise StopAsyncIteration + event = self.events[self.index] + self.index += 1 + return event + + events = [ + MockEvent("Let me think..."), + ] + + mock_client = AsyncMock() + mock_stream_ctx = MagicMock() + mock_stream_ctx.__aenter__ = AsyncMock(return_value=AsyncEventStream(events)) + mock_stream_ctx.__aexit__ = AsyncMock() + mock_client.messages.stream = MagicMock(return_value=mock_stream_ctx) + anthropic_client.client = mock_client + + messages = [Message(role="user", content="Think about something")] + chunks = [] + async for chunk in anthropic_client.generate_stream(messages): + chunks.append(chunk) + + thinking_chunks = [c for c in chunks if c.type == "thinking"] + assert len(thinking_chunks) >= 1 + + +class TestLLMClientWrapperStreaming: + """Test LLMClient wrapper streaming interface.""" + + @pytest.mark.asyncio + async def test_generate_stream_delegates_to_client(self): + """Test that generate_stream properly delegates to underlying client.""" + client = LLMClient( + api_key="test-key", + provider=LLMProvider.OPENAI, + model="MiniMax-M2.5", + ) + + # Verify generate_stream method exists and is accessible + assert hasattr(client, "generate_stream") + assert callable(client.generate_stream) + + +class TestMessageConversion: + """Test message conversion for streaming.""" + + def test_convert_thinking_in_assistant_message(self): + """Test that thinking is preserved in 
assistant messages for streaming.""" + msg = Message( + role="assistant", + content="Here's my response", + thinking="Let me think about this...", + tool_calls=None, + ) + + assert msg.thinking == "Let me think about this..." + assert msg.content == "Here's my response" + + +# Helper for async iterator mock +async def async_iter(items): + """Create async iterator from list of items.""" + for item in items: + yield item diff --git a/tests/test_subagent_tool.py b/tests/test_subagent_tool.py new file mode 100644 index 00000000..8e6afb08 --- /dev/null +++ b/tests/test_subagent_tool.py @@ -0,0 +1,88 @@ +"""Test cases for SubAgentTool.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from mini_agent.tools.subagent_tool import SubAgentTool + + +@pytest.mark.asyncio +async def test_subagent_tool_basic_properties(): + """Test SubAgentTool name, description, parameters.""" + tool = SubAgentTool( + llm_client=MagicMock(), + workspace_dir="/tmp", + ) + assert tool.name == "subagent" + assert "spawn" in tool.description.lower() + assert "task" in tool.parameters["properties"] + assert tool.parameters["required"] == ["task"] + + +@pytest.mark.asyncio +async def test_subagent_tool_unbound_agent_error(): + """Test that execute fails cleanly when agent is not bound.""" + tool = SubAgentTool( + llm_client=MagicMock(), + workspace_dir="/tmp", + ) + result = await tool.execute(task="hello") + assert not result.success + assert "not bound" in result.error + + +@pytest.mark.asyncio +async def test_subagent_tool_missing_tools_error(): + """Test error when requested tool is not in parent agent's tools.""" + mock_llm = MagicMock() + tool = SubAgentTool( + llm_client=mock_llm, + workspace_dir="/tmp", + ) + + # Mock agent with some tools + mock_tool = MagicMock() + mock_tool.name = "bash" + mock_agent = MagicMock() + mock_agent.tools = {"bash": mock_tool} + tool.bind_agent(mock_agent) + + # Try to delegate a non-existent tool + result = await 
tool.execute( + task="test task", + tool_names=["nonexistent_tool"], + ) + assert not result.success + assert "not available" in result.error + + +@pytest.mark.asyncio +async def test_subagent_tool_delegates_valid_tools(): + """Test that valid tool_names filter works correctly.""" + mock_llm = MagicMock() + tool = SubAgentTool( + llm_client=mock_llm, + workspace_dir="/tmp", + ) + + # Create mock tools + bash_tool = MagicMock() + bash_tool.name = "bash" + read_tool = MagicMock() + read_tool.name = "read" + + mock_agent = MagicMock() + mock_agent.tools = {"bash": bash_tool, "read": read_tool} + tool.bind_agent(mock_agent) + + # This would create child with only bash - we can't fully test execute without + # a real LLM, but we can verify the filter logic ran (child agent was created) + result = await tool.execute( + task="test", + tool_names=["bash"], # valid - filter passed + ) + # The mock LLM can't make real calls, so child.run() fails internally. + # But we can confirm: no "not available" error β†’ tool delegation worked + assert "not available" not in (result.error or result.content or "")