diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/__init__.py b/src/bedrock_agentcore/gateway/integrations/strands/plugins/__init__.py new file mode 100644 index 00000000..9a6da0f7 --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/plugins/__init__.py @@ -0,0 +1,5 @@ +"""Gateway Strands plugins.""" + +from .agentcore_tool_search import AgentCoreToolSearchPlugin + +__all__ = ["AgentCoreToolSearchPlugin"] diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/README.md b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/README.md new file mode 100644 index 00000000..8cc5c26a --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/README.md @@ -0,0 +1,122 @@ +# Strands AgentCore Tool Search Plugin + +A semantic tool discovery plugin for [Strands Agents](https://github.com/strands-agents/sdk-python) that uses the [Amazon Bedrock AgentCore Gateway](https://docs.aws.amazon.com/bedrock-agentcore/latest/devguide/gateway-using-mcp-semantic-search.html) `x_amz_bedrock_agentcore_search` tool. This enables agents to dynamically load only the relevant tools for each invocation by deriving user intent from conversation history, even when hundreds of tools are registered on the gateway. + +## Features + +- **Semantic tool discovery** — uses AgentCore Gateway's built-in search to find relevant tools +- **Intent-based loading** — derives user intent via LLM before searching +- **No list_tools call** — tools are built directly from search results +- **Pluggable intent provider** — swap the default intent provider with your own +- **Agent model reuse** — by default, the intent classifier uses the same model as the parent agent + +## Installation + +```bash +pip install agentcore-tool-search-plugin +``` + +## Usage + +```python +from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client +from strands import Agent +from strands.tools.mcp import MCPClient +from agentcore_tool_search_plugin import AgentCoreToolSearchPlugin + +mcp_client = MCPClient(lambda: aws_iam_streamablehttp_client( + endpoint="https://.gateway.bedrock-agentcore..amazonaws.com/mcp", + aws_region="us-east-1", + aws_service="bedrock-agentcore", +)) + +mcp_client.start() + +agent = Agent(plugins=[AgentCoreToolSearchPlugin(mcp_client=mcp_client)]) + +agent("Find me afternoon flights to New York") +``` + +Or using a context manager: + +```python +with mcp_client: + agent = Agent(plugins=[AgentCoreToolSearchPlugin(mcp_client=mcp_client)]) + agent("Find me afternoon flights to New York") +``` + +## How It Works + +![Tool Search Flow](images/agentcore_tool_search_plugin.png) + +On each agent invocation: + +1. **User query** — The user sends a query to Strands agent. +2. **Hook** — The agent triggers the `AgentCoreToolSearchPlugin` before model invocation +3. **Derive intent** — The `IntentProvider` sends the last N messages from conversation history to the configured LLM to produce a concise intent string +4. **Search gateway** — The intent is passed to AgentCore Gateway's `x_amz_bedrock_agentcore_search` tool to obtain most relevant tools. +5. **Invoke LLM** — The agent invokes the LLM with the user query along with the matched tools from registered MCP targets (Lambda, API Gateway, MCP Server) + +Previously loaded tools are cleared before each search, so the agent always has the most relevant tools available. + +## Intent Provider + +An `IntentProvider` is responsible for analyzing conversation messages and producing a concise intent string that drives tool search. The plugin calls `derive_intent(messages, model)` before each invocation to determine what tools to load. + +### Default Intent Provider + +`DefaultIntentProvider` uses an LLM to classify the last few conversation messages into a concise intent string. By default it uses the agent's model. + +**Basic usage (uses the agent's model automatically):** + +```python +from bedrock_agentcore.gateway.integrations.strands.plugins import AgentCoreToolSearchPlugin + +agent = Agent(plugins=[ + AgentCoreToolSearchPlugin(mcp_client=mcp_client) +]) +``` + +**With a custom model for intent classification:** + +```python +from strands.models.bedrock import BedrockModel +from bedrock_agentcore.gateway.integrations.strands.plugins import AgentCoreToolSearchPlugin +from bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers import DefaultIntentProvider + +intent_model = BedrockModel(model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0") +agent = Agent(plugins=[ + AgentCoreToolSearchPlugin( + mcp_client=mcp_client, + intent_provider=DefaultIntentProvider(model=intent_model), + ) +]) +``` + +### Custom Intent Provider + +You can provide your own intent derivation strategy by subclassing `IntentProvider`: + +```python +from bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers import IntentProvider + +class MyIntentProvider(IntentProvider): + def derive_intent(self, messages: list[dict], model=None) -> str: + # custom logic to derive intent + return "intent string" + +agent = Agent(plugins=[ + AgentCoreToolSearchPlugin( + mcp_client=mcp_client, + intent_provider=MyIntentProvider(), + ) +]) +``` + +## Prerequisites + +- An AgentCore Gateway with **[semantic search](https://docs.aws.amazon.com/bedrock-agentcore/latest/devguide/gateway-using-mcp-semantic-search.html) enabled** +- Tools registered on the gateway with descriptions +- AWS credentials with access to the gateway + +For more details, see the [AgentCore Gateway Documentation](https://docs.aws.amazon.com/bedrock-agentcore/latest/devguide/gateway-building.html). diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/__init__.py b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/__init__.py new file mode 100644 index 00000000..f9d8692f --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/__init__.py @@ -0,0 +1,6 @@ +"""AgentCore Tool Search plugin for Strands Agents.""" + +from .intent_providers import DefaultIntentProvider, IntentProvider +from .plugin import AgentCoreToolSearchPlugin + +__all__ = ["AgentCoreToolSearchPlugin", "IntentProvider", "DefaultIntentProvider"] diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/images/agentcore_tool_search_plugin.png b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/images/agentcore_tool_search_plugin.png new file mode 100644 index 00000000..4e29dc45 Binary files /dev/null and b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/images/agentcore_tool_search_plugin.png differ diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/__init__.py b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/__init__.py new file mode 100644 index 00000000..9c849c0e --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/__init__.py @@ -0,0 +1,6 @@ +"""Intent provider interfaces and implementations.""" + +from .default_intent_provider import DefaultIntentProvider +from .intent_provider import IntentProvider + +__all__ = ["DefaultIntentProvider", "IntentProvider"] diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/default_intent_provider.py b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/default_intent_provider.py new file mode 100644 index 00000000..859b33ee --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/default_intent_provider.py @@ -0,0 +1,64 @@ +"""Default LLM-based intent provider implementation.""" + +import logging + +from strands import Agent + +from .intent_provider import IntentProvider + +logger = logging.getLogger(__name__) + +INTENT_SYSTEM_PROMPT = ( + "You are an intent classifier. Given the recent conversation messages, " + "produce a concise one-sentence description of what the user is trying to accomplish. " + "Focus on the type of task, not the specific details. " + "Reply with ONLY the intent description, nothing else." +) + + +class DefaultIntentProvider(IntentProvider): + """LLM-based intent provider that classifies the last N messages.""" + + def __init__(self, message_window: int = 5, model=None): + """Initialize DefaultIntentProvider. + + Args: + message_window: Number of recent messages to consider. + model: Optional explicit model for intent classification. + """ + self._message_window = message_window + self._explicit_model = model + + def derive_intent(self, messages: list[dict], model=None) -> str: + """Derive intent using an LLM. Falls back to agent's model if no explicit model set.""" + try: + recent_messages = messages[-self._message_window :] if messages else [] + if not recent_messages: + return "" + + kwargs = {"system_prompt": INTENT_SYSTEM_PROMPT, "tools": []} + # Priority: explicit model > agent's model > Strands default + resolved_model = self._explicit_model or model + if resolved_model: + kwargs["model"] = resolved_model + + intent_agent = Agent(**kwargs) + response = intent_agent(self._format_messages_for_prompt(recent_messages)) + return str(response).strip() + except Exception as e: + logger.error("Failed to derive intent: %s", e) + return "" + + def _format_messages_for_prompt(self, messages: list[dict]) -> str: + """Format messages into a text prompt for the intent LLM.""" + parts = [] + for msg in messages: + role = msg.get("role", "unknown") + content = msg.get("content", []) + text = "" + if isinstance(content, list): + text = " ".join( + block.get("text", "") for block in content if isinstance(block, dict) and "text" in block + ) + parts.append(f"{role}: {text}") + return "\n".join(parts) diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/intent_provider.py b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/intent_provider.py new file mode 100644 index 00000000..864ece94 --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/intent_provider.py @@ -0,0 +1,26 @@ +"""Intent provider abstract interface.""" + +from abc import ABC, abstractmethod + + +class IntentProvider(ABC): + """Abstract interface for deriving user intent from conversation messages. + + Subclasses must implement the `derive_intent` method to analyze conversation + messages and return a concise intent string. + """ + + @abstractmethod + def derive_intent(self, messages: list[dict], model=None) -> str: + """Analyze conversation messages and return a concise intent string. + + Args: + messages: List of conversation message dicts in Strands format. + model: Optional model instance from the parent agent. Implementations + can use this for LLM-based intent derivation. + + Returns: + A plain text string describing the user's intent. + Returns empty string if intent cannot be determined. + """ + ... diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/plugin.py b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/plugin.py new file mode 100644 index 00000000..d2a0bd00 --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/plugin.py @@ -0,0 +1,116 @@ +"""AgentCore tool search plugin for Strands Agents.""" + +import json +import logging + +from mcp.types import Tool as MCPTool +from strands.hooks import BeforeInvocationEvent +from strands.plugins import Plugin, hook +from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_agent_tool import MCPAgentTool + +from .intent_providers import DefaultIntentProvider, IntentProvider + +logger = logging.getLogger(__name__) + + +class AgentCoreToolSearchPlugin(Plugin): + """Plugin that dynamically loads tools from AgentCore Gateway based on semantic intent. + + Args: + mcp_client: MCPClient connected to an AgentCore Gateway. + intent_provider: Strategy for deriving intent. Defaults to DefaultIntentProvider. + """ + + name = "agentcore-tool-search-plugin" + + def __init__( + self, + mcp_client: MCPClient, + intent_provider: IntentProvider | None = None, + ): + """Initialize the plugin. + + Args: + mcp_client: MCPClient connected to an AgentCore Gateway. + intent_provider: Strategy for deriving intent. Defaults to DefaultIntentProvider. + """ + super().__init__() + self._intent_provider = intent_provider or DefaultIntentProvider() + self._mcp_client = mcp_client + self._loaded_tool_names: set[str] = set() + + @property + def tools(self): + """Return empty list; tools are loaded dynamically via the hook.""" + return [] + + @hook + def on_before_invocation(self, event: BeforeInvocationEvent) -> None: + """Derive intent, search gateway, and load matching tools.""" + messages = event.messages or [] + + # Pass the agent's model to the intent provider + intent = self._intent_provider.derive_intent(messages, model=event.agent.model) + logger.info("Derived intent: %s", intent) + + # Clear all previously loaded conditional tools + for name in list(self._loaded_tool_names): + event.agent.tool_registry.registry.pop(name, None) + self._loaded_tool_names.clear() + + if not intent: + return + + try: + result = self._mcp_client.call_tool_sync( + tool_use_id="intent-search", + name="x_amz_bedrock_agentcore_search", + arguments={"query": intent}, + ) + agent_tools = self._build_tools_from_search_result(result) + except Exception as e: + logger.error("AgentCore Gateway search failed: %s", e) + return + + for agent_tool in agent_tools: + try: + event.agent.tool_registry.register_tool(agent_tool) + self._loaded_tool_names.add(agent_tool.tool_name) + except Exception as e: + logger.error("Failed to register tool %s: %s", agent_tool.tool_name, e) + + logger.info("Loaded tools: %s", self._loaded_tool_names) + + def _build_tools_from_search_result(self, result) -> list[MCPAgentTool]: + """Build MCPAgentTool objects from the gateway search response.""" + tools = [] + if not result or not isinstance(result, dict): + return tools + + tool_defs = [] + structured = result.get("structuredContent") + if isinstance(structured, dict) and "tools" in structured: + tool_defs = structured["tools"] + else: + for block in result.get("content", []): + if isinstance(block, dict) and "text" in block: + try: + data = json.loads(block["text"]) + if isinstance(data, dict) and "tools" in data: + tool_defs = data["tools"] + break + except (json.JSONDecodeError, TypeError): + continue + + for tool_def in tool_defs: + if not isinstance(tool_def, dict) or "name" not in tool_def: + continue + mcp_tool = MCPTool( + name=tool_def["name"], + description=tool_def.get("description", ""), + inputSchema=tool_def.get("inputSchema", {"type": "object", "properties": {}}), + ) + tools.append(MCPAgentTool(mcp_tool=mcp_tool, mcp_client=self._mcp_client)) + + return tools diff --git a/tests/bedrock_agentcore/gateway/__init__.py b/tests/bedrock_agentcore/gateway/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/bedrock_agentcore/gateway/integrations/__init__.py b/tests/bedrock_agentcore/gateway/integrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/bedrock_agentcore/gateway/integrations/strands/__init__.py b/tests/bedrock_agentcore/gateway/integrations/strands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/bedrock_agentcore/gateway/integrations/strands/test_agentcore_tool_search_plugin.py b/tests/bedrock_agentcore/gateway/integrations/strands/test_agentcore_tool_search_plugin.py new file mode 100644 index 00000000..d91373d2 --- /dev/null +++ b/tests/bedrock_agentcore/gateway/integrations/strands/test_agentcore_tool_search_plugin.py @@ -0,0 +1,240 @@ +"""Tests for AgentCoreToolSearchPlugin.""" + +import json +from unittest.mock import Mock + +import pytest + +from bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers import ( + DefaultIntentProvider, + IntentProvider, +) +from bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.plugin import ( + AgentCoreToolSearchPlugin, +) + + +class FakeIntentProvider(IntentProvider): + """Test intent provider that returns a fixed intent string.""" + + def __init__(self, intent: str = "test intent"): + self._intent = intent + + def derive_intent(self, messages: list[dict], model=None) -> str: + return self._intent + + +@pytest.fixture +def mock_mcp_client(): + """Create a mock MCPClient.""" + client = Mock() + client.call_tool_sync.return_value = {"content": []} + return client + + +@pytest.fixture +def fixed_intent_provider(): + """Create a fixed intent provider.""" + return FakeIntentProvider("get weather") + + +@pytest.fixture +def plugin(mock_mcp_client, fixed_intent_provider): + """Create an AgentCoreToolSearchPlugin with mocked dependencies.""" + return AgentCoreToolSearchPlugin(mcp_client=mock_mcp_client, intent_provider=fixed_intent_provider) + + +@pytest.fixture +def mock_event(): + """Create a mock BeforeInvocationEvent.""" + event = Mock() + event.messages = [{"role": "user", "content": [{"text": "hello"}]}] + event.agent = Mock() + event.agent.model = None + event.agent.tool_registry = Mock() + event.agent.tool_registry.registry = {} + return event + + +class TestAgentCoreToolSearchPluginInit: + """Test AgentCoreToolSearchPlugin initialization.""" + + def test_init_with_custom_intent_provider(self, mock_mcp_client): + """Test initialization with a custom intent provider.""" + provider = FakeIntentProvider("custom") + plugin = AgentCoreToolSearchPlugin(mcp_client=mock_mcp_client, intent_provider=provider) + assert plugin._intent_provider is provider + assert plugin._mcp_client is mock_mcp_client + + def test_init_default_intent_provider(self, mock_mcp_client): + """Test initialization uses DefaultIntentProvider when none provided.""" + plugin = AgentCoreToolSearchPlugin(mcp_client=mock_mcp_client) + assert isinstance(plugin._intent_provider, DefaultIntentProvider) + + def test_plugin_name(self, plugin): + """Test plugin has correct name.""" + assert plugin.name == "agentcore-tool-search-plugin" + + def test_tools_property_returns_empty(self, plugin): + """Test tools property returns empty list.""" + assert plugin.tools == [] + + +class TestOnBeforeInvocation: + """Test on_before_invocation hook behavior.""" + + def test_empty_intent_skips_search(self, mock_mcp_client, mock_event): + """Test that empty intent does not call gateway search.""" + provider = FakeIntentProvider("") + plugin = AgentCoreToolSearchPlugin(mcp_client=mock_mcp_client, intent_provider=provider) + + plugin.on_before_invocation(mock_event) + + mock_mcp_client.call_tool_sync.assert_not_called() + + def test_calls_gateway_search_with_intent(self, plugin, mock_mcp_client, mock_event): + """Test that derived intent is passed to gateway search.""" + plugin.on_before_invocation(mock_event) + + mock_mcp_client.call_tool_sync.assert_called_once_with( + tool_use_id="intent-search", + name="x_amz_bedrock_agentcore_search", + arguments={"query": "get weather"}, + ) + + def test_passes_agent_model_to_intent_provider(self, mock_mcp_client, mock_event): + """Test that the agent's model is passed to derive_intent.""" + provider = Mock(spec=IntentProvider) + provider.derive_intent.return_value = "" + plugin = AgentCoreToolSearchPlugin(mcp_client=mock_mcp_client, intent_provider=provider) + mock_event.agent.model = Mock(name="test-model") + + plugin.on_before_invocation(mock_event) + + provider.derive_intent.assert_called_once_with(mock_event.messages, model=mock_event.agent.model) + + def test_registers_tools_from_structured_content(self, plugin, mock_mcp_client, mock_event): + """Test tools are registered from structuredContent response.""" + mock_mcp_client.call_tool_sync.return_value = { + "structuredContent": { + "tools": [ + { + "name": "weather_tool", + "description": "Get weather", + "inputSchema": {"type": "object", "properties": {"city": {"type": "string"}}}, + } + ] + } + } + + plugin.on_before_invocation(mock_event) + + mock_event.agent.tool_registry.register_tool.assert_called_once() + registered_tool = mock_event.agent.tool_registry.register_tool.call_args[0][0] + assert registered_tool.tool_name == "weather_tool" + assert "weather_tool" in plugin._loaded_tool_names + + def test_registers_tools_from_text_content(self, plugin, mock_mcp_client, mock_event): + """Test tools are registered from JSON text content response.""" + tools_json = json.dumps( + { + "tools": [ + { + "name": "calc_tool", + "description": "Calculator", + "inputSchema": {"type": "object", "properties": {}}, + }, + ] + } + ) + mock_mcp_client.call_tool_sync.return_value = {"content": [{"text": tools_json}]} + + plugin.on_before_invocation(mock_event) + + mock_event.agent.tool_registry.register_tool.assert_called_once() + registered_tool = mock_event.agent.tool_registry.register_tool.call_args[0][0] + assert registered_tool.tool_name == "calc_tool" + + def test_clears_previously_loaded_tools(self, plugin, mock_mcp_client, mock_event): + """Test previously loaded tools are removed from registry.""" + mock_mcp_client.call_tool_sync.return_value = {"content": []} + plugin._loaded_tool_names = {"old_tool_1", "old_tool_2"} + mock_event.agent.tool_registry.registry = { + "old_tool_1": Mock(), + "old_tool_2": Mock(), + "permanent_tool": Mock(), + } + + plugin.on_before_invocation(mock_event) + + assert "old_tool_1" not in mock_event.agent.tool_registry.registry + assert "old_tool_2" not in mock_event.agent.tool_registry.registry + assert "permanent_tool" in mock_event.agent.tool_registry.registry + assert len(plugin._loaded_tool_names) == 0 + + def test_gateway_search_failure_logs_and_returns(self, plugin, mock_mcp_client, mock_event): + """Test gateway search failure is handled gracefully.""" + mock_mcp_client.call_tool_sync.side_effect = RuntimeError("connection failed") + + plugin.on_before_invocation(mock_event) + + mock_event.agent.tool_registry.register_tool.assert_not_called() + + def test_skips_invalid_tool_defs(self, plugin, mock_mcp_client, mock_event): + """Test malformed tool definitions are skipped.""" + mock_mcp_client.call_tool_sync.return_value = { + "structuredContent": { + "tools": [ + {"description": "no name field"}, + "not a dict", + {"name": "valid_tool", "description": "ok", "inputSchema": {"type": "object", "properties": {}}}, + ] + } + } + + plugin.on_before_invocation(mock_event) + + mock_event.agent.tool_registry.register_tool.assert_called_once() + registered_tool = mock_event.agent.tool_registry.register_tool.call_args[0][0] + assert registered_tool.tool_name == "valid_tool" + + def test_register_tool_failure_continues(self, plugin, mock_mcp_client, mock_event): + """Test that failure to register one tool doesn't block others.""" + mock_mcp_client.call_tool_sync.return_value = { + "structuredContent": { + "tools": [ + {"name": "tool_a", "description": "A", "inputSchema": {"type": "object", "properties": {}}}, + {"name": "tool_b", "description": "B", "inputSchema": {"type": "object", "properties": {}}}, + ] + } + } + mock_event.agent.tool_registry.register_tool.side_effect = [RuntimeError("fail"), None] + + plugin.on_before_invocation(mock_event) + + assert mock_event.agent.tool_registry.register_tool.call_count == 2 + assert "tool_a" not in plugin._loaded_tool_names + assert "tool_b" in plugin._loaded_tool_names + + def test_none_result_loads_no_tools(self, plugin, mock_mcp_client, mock_event): + """Test None result from gateway loads no tools.""" + mock_mcp_client.call_tool_sync.return_value = None + + plugin.on_before_invocation(mock_event) + + mock_event.agent.tool_registry.register_tool.assert_not_called() + + def test_empty_messages_with_intent(self, mock_mcp_client): + """Test plugin works with empty messages list.""" + provider = FakeIntentProvider("") + plugin = AgentCoreToolSearchPlugin(mcp_client=mock_mcp_client, intent_provider=provider) + event = Mock() + event.messages = [] + event.agent = Mock() + event.agent.model = None + event.agent.tool_registry = Mock() + event.agent.tool_registry.registry = {} + + plugin.on_before_invocation(event) + + mock_mcp_client.call_tool_sync.assert_not_called() diff --git a/tests/bedrock_agentcore/gateway/integrations/strands/test_intent_providers.py b/tests/bedrock_agentcore/gateway/integrations/strands/test_intent_providers.py new file mode 100644 index 00000000..6b524245 --- /dev/null +++ b/tests/bedrock_agentcore/gateway/integrations/strands/test_intent_providers.py @@ -0,0 +1,211 @@ +"""Tests for IntentProvider and DefaultIntentProvider.""" + +from unittest.mock import Mock, patch + +import pytest + +from bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers import ( + DefaultIntentProvider, + IntentProvider, +) + + +class TestIntentProviderInterface: + """Test IntentProvider abstract interface.""" + + def test_cannot_instantiate_abstract_class(self): + """Test that IntentProvider cannot be instantiated directly.""" + with pytest.raises(TypeError): + IntentProvider() + + def test_subclass_must_implement_derive_intent(self): + """Test that subclass without derive_intent raises TypeError.""" + + class IncompleteProvider(IntentProvider): + pass + + with pytest.raises(TypeError): + IncompleteProvider() + + def test_subclass_with_derive_intent_works(self): + """Test that a proper subclass can be instantiated.""" + + class ValidProvider(IntentProvider): + def derive_intent(self, messages: list[dict], model=None) -> str: + return "test" + + provider = ValidProvider() + assert provider.derive_intent([]) == "test" + + +class TestDefaultIntentProvider: + """Test DefaultIntentProvider class.""" + + def test_init_default_message_window(self): + """Test default message window is 5.""" + provider = DefaultIntentProvider() + assert provider._message_window == 5 + + def test_init_custom_message_window(self): + """Test custom message window.""" + provider = DefaultIntentProvider(message_window=3) + assert provider._message_window == 3 + + def test_init_with_explicit_model(self): + """Test initialization with explicit model.""" + model = Mock() + provider = DefaultIntentProvider(model=model) + assert provider._explicit_model is model + + def test_empty_messages_returns_empty_string(self): + """Test empty messages returns empty string without calling LLM.""" + provider = DefaultIntentProvider() + result = provider.derive_intent([]) + assert result == "" + + @patch( + "bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers.default_intent_provider.Agent" + ) + def test_derive_intent_calls_agent(self, mock_agent_class): + """Test derive_intent creates an Agent and calls it.""" + mock_agent = Mock() + mock_agent.return_value = "user wants weather info" + mock_agent_class.return_value = mock_agent + + provider = DefaultIntentProvider(message_window=2) + messages = [ + {"role": "user", "content": [{"text": "What is the weather?"}]}, + ] + + result = provider.derive_intent(messages) + + assert result == "user wants weather info" + mock_agent_class.assert_called_once() + + @patch( + "bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers.default_intent_provider.Agent" + ) + def test_derive_intent_uses_explicit_model(self, mock_agent_class): + """Test derive_intent uses explicit model over agent model.""" + mock_agent = Mock() + mock_agent.return_value = "intent" + mock_agent_class.return_value = mock_agent + + explicit_model = Mock(name="explicit-model") + provider = DefaultIntentProvider(model=explicit_model) + messages = [{"role": "user", "content": [{"text": "hello"}]}] + + provider.derive_intent(messages, model=Mock(name="agent-model")) + + # Explicit model takes priority + call_kwargs = mock_agent_class.call_args[1] + assert call_kwargs["model"] is explicit_model + + @patch( + "bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers.default_intent_provider.Agent" + ) + def test_derive_intent_uses_agent_model_when_no_explicit(self, mock_agent_class): + """Test derive_intent falls back to agent model when no explicit model.""" + mock_agent = Mock() + mock_agent.return_value = "intent" + mock_agent_class.return_value = mock_agent + + agent_model = Mock(name="agent-model") + provider = DefaultIntentProvider() + messages = [{"role": "user", "content": [{"text": "hello"}]}] + + provider.derive_intent(messages, model=agent_model) + + call_kwargs = mock_agent_class.call_args[1] + assert call_kwargs["model"] is agent_model + + @patch( + "bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers.default_intent_provider.Agent" + ) + def test_derive_intent_no_model_kwarg_when_none(self, mock_agent_class): + """Test derive_intent omits model kwarg when no model available.""" + mock_agent = Mock() + mock_agent.return_value = "intent" + mock_agent_class.return_value = mock_agent + + provider = DefaultIntentProvider() + messages = [{"role": "user", "content": [{"text": "hello"}]}] + + provider.derive_intent(messages, model=None) + + call_kwargs = mock_agent_class.call_args[1] + assert "model" not in call_kwargs + + @patch( + "bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers.default_intent_provider.Agent" + ) + def test_derive_intent_respects_message_window(self, mock_agent_class): + """Test only last N messages are used.""" + mock_agent = Mock() + mock_agent.return_value = "intent" + mock_agent_class.return_value = mock_agent + + provider = DefaultIntentProvider(message_window=2) + messages = [ + {"role": "user", "content": [{"text": "first"}]}, + {"role": "assistant", "content": [{"text": "second"}]}, + {"role": "user", "content": [{"text": "third"}]}, + {"role": "assistant", "content": [{"text": "fourth"}]}, + {"role": "user", "content": [{"text": "fifth"}]}, + ] + + provider.derive_intent(messages) + + # The formatted prompt should only contain the last 2 messages + call_args = mock_agent.call_args[0][0] + assert "first" not in call_args + assert "fourth" in call_args + assert "fifth" in call_args + + @patch( + "bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers.default_intent_provider.Agent" + ) + def test_derive_intent_handles_exception(self, mock_agent_class): + """Test derive_intent returns empty string on exception.""" + mock_agent_class.side_effect = RuntimeError("LLM unavailable") + + provider = DefaultIntentProvider() + messages = [{"role": "user", "content": [{"text": "hello"}]}] + + result = provider.derive_intent(messages) + + assert result == "" + + def test_format_messages_for_prompt(self): + """Test message formatting for the LLM prompt.""" + provider = DefaultIntentProvider() + messages = [ + {"role": "user", "content": [{"text": "Hello"}, {"text": "world"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + result = provider._format_messages_for_prompt(messages) + + assert "user: Hello world" in result + assert "assistant: Hi there" in result + + def test_format_messages_handles_missing_role(self): + """Test formatting handles messages without role.""" + provider = DefaultIntentProvider() + messages = [{"content": [{"text": "no role"}]}] + + result = provider._format_messages_for_prompt(messages) + + assert "unknown: no role" in result + + def test_format_messages_handles_non_text_blocks(self): + """Test formatting skips non-text content blocks.""" + provider = DefaultIntentProvider() + messages = [ + {"role": "user", "content": [{"image": "data"}, {"text": "only this"}]}, + ] + + result = provider._format_messages_for_prompt(messages) + + assert "only this" in result + assert "data" not in result diff --git a/tests_integ/gateway/integrations/lambda_function/lambda_function.py b/tests_integ/gateway/integrations/lambda_function/lambda_function.py new file mode 100644 index 00000000..2001f9b7 --- /dev/null +++ b/tests_integ/gateway/integrations/lambda_function/lambda_function.py @@ -0,0 +1,117 @@ +"""MCP-compatible Lambda handler for AgentCore Gateway integration tests. + +This Lambda implements the MCP JSON-RPC protocol over HTTP, responding to: +- initialize: Returns server capabilities +- tools/list: Returns available tool definitions +- tools/call: Executes a tool and returns results + +Deploy with Python 3.10+ runtime, handler: lambda_function.lambda_handler +""" + +import json +import logging + +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +TOOLS = [ + { + "name": "get_weather", + "description": "Get current weather for a city", + "inputSchema": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + }, + "required": ["city"], + }, + }, + { + "name": "send_email", + "description": "Send an email to a recipient", + "inputSchema": { + "type": "object", + "properties": { + "to": {"type": "string", "description": "Recipient email"}, + "subject": {"type": "string", "description": "Email subject"}, + "body": {"type": "string", "description": "Email body"}, + }, + "required": ["to", "subject", "body"], + }, + }, +] + + +def lambda_handler(event, context): + """Handle MCP JSON-RPC requests from AgentCore Gateway.""" + logger.info("Received event: %s", json.dumps(event)) + + body = event.get("body", "{}") + if isinstance(body, str): + body = json.loads(body) + + method = body.get("method", "") + request_id = body.get("id") + params = body.get("params", {}) + + if method == "initialize": + result = { + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {"listChanged": False}, + }, + "serverInfo": { + "name": "integ-test-mcp-server", + "version": "1.0.0", + }, + } + elif method == "notifications/initialized": + # Client acknowledgment, no response needed + return {"statusCode": 200, "body": ""} + elif method == "tools/list": + result = {"tools": TOOLS} + elif method == "tools/call": + tool_name = params.get("name", "") + arguments = params.get("arguments", {}) + result = _handle_tool_call(tool_name, arguments) + else: + return { + "statusCode": 200, + "body": json.dumps( + { + "jsonrpc": "2.0", + "id": request_id, + "error": { + "code": -32601, + "message": f"Method not found: {method}", + }, + } + ), + } + + response = { + "jsonrpc": "2.0", + "id": request_id, + "result": result, + } + + return { + "statusCode": 200, + "body": json.dumps(response), + } + + +def _handle_tool_call(tool_name, arguments): + """Execute a tool and return MCP-formatted result.""" + if tool_name == "get_weather": + city = arguments.get("city", "unknown") + return {"content": [{"type": "text", "text": f"Weather in {city}: 72°F, sunny with light clouds."}]} + elif tool_name == "send_email": + to = arguments.get("to", "") + subject = arguments.get("subject", "") + return {"content": [{"type": "text", "text": f"Email sent to {to} with subject: {subject}"}]} + else: + return { + "content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}], + "isError": True, + } diff --git a/tests_integ/gateway/integrations/test_agentcore_tool_search_plugin.py b/tests_integ/gateway/integrations/test_agentcore_tool_search_plugin.py new file mode 100644 index 00000000..e709a3de --- /dev/null +++ b/tests_integ/gateway/integrations/test_agentcore_tool_search_plugin.py @@ -0,0 +1,363 @@ +"""Integration tests for AgentCoreToolSearchPlugin. + +Requires environment variables: + BEDROCK_TEST_REGION: AWS region (default: us-west-2) + GATEWAY_ROLE_ARN: IAM role ARN with AgentCore gateway trust policy + GATEWAY_LAMBDA_ARN: Lambda ARN for the gateway target (must implement MCP tool handler) + +Prerequisites: + 1. Deploy the Lambda in tests_integ/gateway/integrations/lambda_function/lambda_function.py + (Python 3.10+ runtime, handler: lambda_function.lambda_handler) + + 2. The GATEWAY_ROLE_ARN must have: + - Trust policy for bedrock-agentcore.amazonaws.com + - lambda:InvokeFunction permission on the GATEWAY_LAMBDA_ARN + + Example inline policy: + { + "Version": "2012-10-17", + "Statement": [{ + "Effect": "Allow", + "Action": "lambda:InvokeFunction", + "Resource": "" + }] + } + + 3. Install mcp-proxy-for-aws: uv pip install mcp-proxy-for-aws + +""" + +import logging +import os +import time + +import pytest + +from bedrock_agentcore.gateway.client import GatewayClient +from bedrock_agentcore.gateway.integrations.strands.plugins import AgentCoreToolSearchPlugin +from bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers import ( + DefaultIntentProvider, + IntentProvider, +) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +class FixedIntentProvider(IntentProvider): + """Intent provider that returns a fixed string for deterministic testing.""" + + def __init__(self, intent: str): + self._intent = intent + + def derive_intent(self, messages: list[dict], model=None) -> str: + return self._intent + + +@pytest.mark.integration +class TestAgentCoreToolSearchPluginIntegration: + """Integration tests for AgentCoreToolSearchPlugin with a live gateway. + + Creates a gateway with a Lambda target exposing test tools, then verifies + the plugin can search and load those tools. + """ + + @classmethod + def setup_class(cls): + cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2") + cls.role_arn = os.environ.get("GATEWAY_ROLE_ARN") + cls.lambda_arn = os.environ.get("GATEWAY_LAMBDA_ARN") + + if not cls.role_arn or not cls.lambda_arn: + pytest.fail("GATEWAY_ROLE_ARN and GATEWAY_LAMBDA_ARN must be set") + + cls.gw_client = GatewayClient(region_name=cls.region) + cls.test_prefix = f"sdk-integ-plugin-{int(time.time())}" + cls.gateway_id = None + cls.target_id = None + + # Create gateway with semantic search enabled + gw = cls.gw_client.create_gateway_and_wait( + name=f"{cls.test_prefix}-gw", + roleArn=cls.role_arn, + authorizerType="NONE", + protocolType="MCP", + protocolConfiguration={ + "mcp": { + "searchType": "SEMANTIC", + }, + }, + ) + cls.gateway_id = gw["gatewayId"] + logger.info("Created gateway: %s", cls.gateway_id) + + # Create target with test tools + target = cls.gw_client.create_gateway_target_and_wait( + gatewayIdentifier=cls.gateway_id, + name=f"{cls.test_prefix}-target", + targetConfiguration={ + "mcp": { + "lambda": { + "lambdaArn": cls.lambda_arn, + "toolSchema": { + "inlinePayload": [ + { + "name": "get_weather", + "description": "Get current weather for a city", + "inputSchema": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + }, + "required": ["city"], + }, + }, + { + "name": "send_email", + "description": "Send an email to a recipient", + "inputSchema": { + "type": "object", + "properties": { + "to": {"type": "string"}, + "subject": {"type": "string"}, + "body": {"type": "string"}, + }, + "required": ["to", "subject", "body"], + }, + }, + ] + }, + } + }, + }, + credentialProviderConfigurations=[ + {"credentialProviderType": "GATEWAY_IAM_ROLE"}, + ], + ) + cls.target_id = target["targetId"] + logger.info("Created target: %s", cls.target_id) + + # Wait for target search indexing to complete (can take up to 60s) + time.sleep(60) + + @classmethod + def teardown_class(cls): + if cls.gateway_id: + if cls.target_id: + try: + cls.gw_client.delete_gateway_target_and_wait( + gatewayIdentifier=cls.gateway_id, + targetId=cls.target_id, + ) + except Exception as e: + logger.warning("Failed to delete target %s: %s", cls.target_id, e) + try: + cls.gw_client.delete_gateway_and_wait( + gatewayIdentifier=cls.gateway_id, + ) + except Exception as e: + logger.warning("Failed to delete gateway %s: %s", cls.gateway_id, e) + + def _make_mcp_client(self): + """Create an MCPClient connected to the test gateway via Streamable HTTP with IAM auth.""" + from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client + from strands.tools.mcp import MCPClient + + endpoint = f"https://{self.gateway_id}.gateway.bedrock-agentcore.{self.region}.amazonaws.com/mcp" + return MCPClient( + lambda: aws_iam_streamablehttp_client( + endpoint=endpoint, + aws_region=self.region, + aws_service="bedrock-agentcore", + ) + ) + + @pytest.mark.order(1) + def test_plugin_with_default_intent_provider(self): + """Plugin initializes correctly with DefaultIntentProvider.""" + mcp_client = self._make_mcp_client() + plugin = AgentCoreToolSearchPlugin(mcp_client=mcp_client) + assert isinstance(plugin._intent_provider, DefaultIntentProvider) + assert plugin.name == "agentcore-tool-search-plugin" + + @pytest.mark.order(2) + def test_plugin_with_custom_intent_provider(self): + """Plugin accepts a custom IntentProvider.""" + mcp_client = self._make_mcp_client() + provider = FixedIntentProvider("weather query") + plugin = AgentCoreToolSearchPlugin(mcp_client=mcp_client, intent_provider=provider) + assert plugin._intent_provider is provider + + @pytest.mark.order(3) + def test_gateway_search_returns_results(self): + """Calling x_amz_bedrock_agentcore_search on the gateway returns tool definitions.""" + mcp_client = self._make_mcp_client() + + with mcp_client: + result = mcp_client.call_tool_sync( + tool_use_id="test-search", + name="x_amz_bedrock_agentcore_search", + arguments={"query": "get weather information"}, + ) + + assert result is not None + logger.info("Search result keys: %s", result.keys() if isinstance(result, dict) else type(result)) + + @pytest.mark.order(4) + def test_plugin_loads_tools_via_hook(self): + """Plugin loads matching tools into the agent via the before_invocation hook.""" + from strands import Agent + + mcp_client = self._make_mcp_client() + provider = FixedIntentProvider("get weather information") + plugin = AgentCoreToolSearchPlugin(mcp_client=mcp_client, intent_provider=provider) + + with mcp_client: + # First verify the search endpoint returns tools + result = mcp_client.call_tool_sync( + tool_use_id="debug-search", + name="x_amz_bedrock_agentcore_search", + arguments={"query": "get weather information"}, + ) + logger.info("Raw search result: %s", result) + + agent = Agent( + system_prompt="You are a helpful assistant. Use available tools to help the user.", + tools=[], + plugins=[plugin], + ) + # Trigger an invocation so the hook fires + agent("What is the weather in Seattle?") + + logger.info("Loaded tool names: %s", plugin._loaded_tool_names) + # The gateway should have returned the get_weather tool + assert len(plugin._loaded_tool_names) > 0, ( + f"Expected tools to be loaded but got none. Raw search result was: {result}" + ) + + @pytest.mark.order(5) + def test_empty_intent_loads_no_tools(self): + """Plugin does not search gateway when intent is empty.""" + from unittest.mock import Mock + + from strands.hooks import BeforeInvocationEvent + + mcp_client = self._make_mcp_client() + provider = FixedIntentProvider("") + plugin = AgentCoreToolSearchPlugin(mcp_client=mcp_client, intent_provider=provider) + + # Simulate a before_invocation event + event = Mock(spec=BeforeInvocationEvent) + event.messages = [{"role": "user", "content": [{"text": "hello"}]}] + event.agent = Mock() + event.agent.model = None + event.agent.tool_registry = Mock() + event.agent.tool_registry.registry = {} + + plugin.on_before_invocation(event) + + assert len(plugin._loaded_tool_names) == 0 + + @pytest.mark.order(6) + def test_tools_cleared_between_invocations(self): + """Previously loaded tools are cleared before each new search.""" + from unittest.mock import Mock + + from strands.hooks import BeforeInvocationEvent + + mcp_client = self._make_mcp_client() + provider = FixedIntentProvider("get weather information") + plugin = AgentCoreToolSearchPlugin(mcp_client=mcp_client, intent_provider=provider) + + with mcp_client: + # First: simulate invocation with a real intent + event = Mock(spec=BeforeInvocationEvent) + event.messages = [{"role": "user", "content": [{"text": "weather"}]}] + event.agent = Mock() + event.agent.model = None + event.agent.tool_registry = Mock() + event.agent.tool_registry.registry = {} + + plugin.on_before_invocation(event) + first_tools = set(plugin._loaded_tool_names) + logger.info("First invocation tools: %s", first_tools) + assert len(first_tools) > 0 + + # Second: switch to empty intent — tools should be cleared + provider._intent = "" + event.agent.tool_registry.registry = {name: Mock() for name in first_tools} + + plugin.on_before_invocation(event) + second_tools = set(plugin._loaded_tool_names) + logger.info("Second invocation tools: %s", second_tools) + + assert len(second_tools) == 0 + # Verify old tools were removed from registry + for name in first_tools: + assert name not in event.agent.tool_registry.registry + + +@pytest.mark.integration +class TestDefaultIntentProviderIntegration: + """Integration tests for DefaultIntentProvider with a real LLM.""" + + @classmethod + def setup_class(cls): + cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2") + + def test_derive_intent_from_messages(self): + """DefaultIntentProvider produces a non-empty intent string from messages.""" + provider = DefaultIntentProvider(message_window=3) + messages = [ + {"role": "user", "content": [{"text": "What's the weather like in Seattle today?"}]}, + {"role": "assistant", "content": [{"text": "Let me check the weather for you."}]}, + {"role": "user", "content": [{"text": "Also check tomorrow's forecast."}]}, + ] + + intent = provider.derive_intent(messages) + + logger.info("Derived intent: %s", intent) + assert isinstance(intent, str) + assert len(intent) > 0 + + def test_derive_intent_empty_messages(self): + """DefaultIntentProvider returns empty string for empty messages.""" + provider = DefaultIntentProvider() + intent = provider.derive_intent([]) + assert intent == "" + + def test_derive_intent_with_custom_model(self): + """DefaultIntentProvider works with an explicitly provided model.""" + from strands.models.bedrock import BedrockModel + + model = BedrockModel( + model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", + region_name=self.region, + ) + provider = DefaultIntentProvider(model=model) + messages = [ + {"role": "user", "content": [{"text": "I need to send an email to my team about the project update."}]}, + ] + + intent = provider.derive_intent(messages) + + logger.info("Derived intent with custom model: %s", intent) + assert isinstance(intent, str) + assert len(intent) > 0 + + def test_derive_intent_respects_message_window(self): + """DefaultIntentProvider only considers the last N messages.""" + provider = DefaultIntentProvider(message_window=2) + messages = [ + {"role": "user", "content": [{"text": "Tell me about dogs."}]}, + {"role": "assistant", "content": [{"text": "Dogs are great pets."}]}, + {"role": "user", "content": [{"text": "Now tell me about the stock market."}]}, + {"role": "assistant", "content": [{"text": "The stock market is complex."}]}, + {"role": "user", "content": [{"text": "What are the best investment strategies?"}]}, + ] + + intent = provider.derive_intent(messages) + + logger.info("Derived intent (window=2): %s", intent) + assert isinstance(intent, str) + assert len(intent) > 0