diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b617d702f..2e01d6b9f 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -431,9 +431,10 @@ async def _receive_loop(self) -> None: def _normalize_request_id(self, response_id: RequestId) -> RequestId: """Normalize a response ID to match how request IDs are stored. - Since the client always sends integer IDs, we normalize string IDs - to integers when possible. This matches the TypeScript SDK approach: - https://github.com/modelcontextprotocol/typescript-sdk/blob/a606fb17909ea454e83aab14c73f14ea45c04448/src/shared/protocol.ts#L861 + Since the client always sends integer IDs, we normalize canonical numeric + string IDs (e.g. ``"0"``, ``"42"``) to integers. Non-canonical numeric + strings (e.g. ``"01"``, ``"+1"``) are left as strings to avoid collisions + with integer IDs. Args: response_id: The response ID from the incoming message. @@ -443,9 +444,18 @@ def _normalize_request_id(self, response_id: RequestId) -> RequestId: """ if isinstance(response_id, str): try: - return int(response_id) + int_id = int(response_id) except ValueError: logging.warning(f"Response ID {response_id!r} cannot be normalized to match pending requests") + return response_id + + if str(int_id) == response_id: + return int_id + + logging.warning( + "Response ID %r is numeric but non-canonical; not normalizing to avoid ID collisions", + response_id, + ) return response_id async def _handle_response(self, message: SessionMessage) -> None: diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index d7c6cc3b5..d186ef5f6 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -258,6 +258,76 @@ async def make_request(client_session: ClientSession): await ev_timeout.wait() +@pytest.mark.anyio +async def test_response_id_non_canonical_numeric_string_no_match(): + """Test that non-canonical numeric IDs don't collide with integer request IDs. + + If a server returns ``"id": "01"``, it should not match a pending request with + integer ID ``1``. + """ + ev_timeout = anyio.Event() + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Respond to ping #1, then send a non-canonical ID for ping #2.""" + first_message = await server_read.receive() + assert isinstance(first_message, SessionMessage) + assert isinstance(first_message.message, JSONRPCRequest) + first_request_id = first_message.message.id + + # Let the first request complete so the second request is sent. + await server_write.send( + SessionMessage( + message=JSONRPCResponse( + jsonrpc="2.0", + id=first_request_id, + result={}, + ) + ) + ) + + second_message = await server_read.receive() + assert isinstance(second_message, SessionMessage) + assert isinstance(second_message.message, JSONRPCRequest) + second_request_id = second_message.message.id + assert second_request_id == 1 + + response = JSONRPCResponse( + jsonrpc="2.0", + id="01", # Non-canonical representation of 1 + result={}, + ) + await server_write.send(SessionMessage(message=response)) + + async def make_requests(client_session: ClientSession): + # First request consumes request ID 0 so the second request uses ID 1. + await client_session.send_ping() + + try: + await client_session.send_request( + types.PingRequest(), + types.EmptyResult, + request_read_timeout_seconds=0.5, + ) + pytest.fail("Expected timeout") # pragma: no cover + except MCPError as e: + assert "Timed out" in str(e) + ev_timeout.set() + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_requests, client_session) + + with anyio.fail_after(2): # pragma: no branch + await ev_timeout.wait() + + @pytest.mark.anyio async def test_connection_closed(): """Test that pending requests are cancelled when the connection is closed remotely."""