From 3e2fb4f8f0dce00ba6898dbcc4e40fddf15aeb92 Mon Sep 17 00:00:00 2001 From: BabyChrist666 Date: Wed, 25 Feb 2026 12:43:01 -0500 Subject: [PATCH] feat: add public API for runtime handler registration/deregistration Add add_request_handler(), remove_request_handler(), add_notification_handler(), remove_notification_handler(), and has_handler() as public methods on the low-level Server class. This enables frameworks and advanced use cases to register handlers for protocol extensions or custom methods after server construction, and to remove or replace handlers dynamically. Refactors ExperimentalHandlers to use the new public API instead of receiving private method references, validating the API with its first internal consumer. Fixes #2135 Co-Authored-By: Claude Opus 4.6 --- src/mcp/server/lowlevel/experimental.py | 45 ++++----- src/mcp/server/lowlevel/server.py | 71 +++++++++++++- .../lowlevel/test_handler_registration.py | 94 +++++++++++++++++++ 3 files changed, 184 insertions(+), 26 deletions(-) create mode 100644 tests/server/lowlevel/test_handler_registration.py diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 5a907b640..ae9667a1c 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -7,7 +7,7 @@ import logging from collections.abc import Awaitable, Callable -from typing import Any, Generic +from typing import TYPE_CHECKING, Any, Generic from typing_extensions import TypeVar @@ -38,6 +38,9 @@ TasksToolsCapability, ) +if TYPE_CHECKING: + from mcp.server.lowlevel.server import Server + logger = logging.getLogger(__name__) LifespanResultT = TypeVar("LifespanResultT", default=Any) @@ -51,13 +54,9 @@ class ExperimentalHandlers(Generic[LifespanResultT]): def __init__( self, - add_request_handler: Callable[ - [str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]], None - ], - has_handler: Callable[[str], bool], + server: Server[LifespanResultT, Any], ) -> None: - self._add_request_handler = add_request_handler - self._has_handler = has_handler + self._server = server self._task_support: TaskSupport | None = None @property @@ -67,13 +66,15 @@ def task_support(self) -> TaskSupport | None: def update_capabilities(self, capabilities: ServerCapabilities) -> None: # Only add tasks capability if handlers are registered - if not any(self._has_handler(method) for method in ["tasks/get", "tasks/list", "tasks/cancel", "tasks/result"]): + if not any( + self._server.has_handler(method) for method in ["tasks/get", "tasks/list", "tasks/cancel", "tasks/result"] + ): return capabilities.tasks = ServerTasksCapability() - if self._has_handler("tasks/list"): + if self._server.has_handler("tasks/list"): capabilities.tasks.list = TasksListCapability() - if self._has_handler("tasks/cancel"): + if self._server.has_handler("tasks/cancel"): capabilities.tasks.cancel = TasksCancelCapability() capabilities.tasks.requests = ServerTasksRequestsCapability( @@ -145,16 +146,16 @@ def enable_tasks( # Register user-provided handlers if on_get_task is not None: - self._add_request_handler("tasks/get", on_get_task) + self._server.add_request_handler("tasks/get", on_get_task) if on_task_result is not None: - self._add_request_handler("tasks/result", on_task_result) + self._server.add_request_handler("tasks/result", on_task_result) if on_list_tasks is not None: - self._add_request_handler("tasks/list", on_list_tasks) + self._server.add_request_handler("tasks/list", on_list_tasks) if on_cancel_task is not None: - self._add_request_handler("tasks/cancel", on_cancel_task) + self._server.add_request_handler("tasks/cancel", on_cancel_task) # Fill in defaults for any not provided - if not self._has_handler("tasks/get"): + if not self._server.has_handler("tasks/get"): async def _default_get_task( ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams @@ -172,9 +173,9 @@ async def _default_get_task( poll_interval=task.poll_interval, ) - self._add_request_handler("tasks/get", _default_get_task) + self._server.add_request_handler("tasks/get", _default_get_task) - if not self._has_handler("tasks/result"): + if not self._server.has_handler("tasks/result"): async def _default_get_task_result( ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams @@ -184,9 +185,9 @@ async def _default_get_task_result( result = await task_support.handler.handle(req, ctx.session, ctx.request_id) return result - self._add_request_handler("tasks/result", _default_get_task_result) + self._server.add_request_handler("tasks/result", _default_get_task_result) - if not self._has_handler("tasks/list"): + if not self._server.has_handler("tasks/list"): async def _default_list_tasks( ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None @@ -195,9 +196,9 @@ async def _default_list_tasks( tasks, next_cursor = await task_support.store.list_tasks(cursor) return ListTasksResult(tasks=tasks, next_cursor=next_cursor) - self._add_request_handler("tasks/list", _default_list_tasks) + self._server.add_request_handler("tasks/list", _default_list_tasks) - if not self._has_handler("tasks/cancel"): + if not self._server.has_handler("tasks/cancel"): async def _default_cancel_task( ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams @@ -205,6 +206,6 @@ async def _default_cancel_task( result = await cancel_task(task_support.store, params.task_id) return result - self._add_request_handler("tasks/cancel", _default_cancel_task) + self._server.add_request_handler("tasks/cancel", _default_cancel_task) return task_support diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index aee644040..6ebcf679c 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -246,6 +246,72 @@ def _has_handler(self, method: str) -> bool: """Check if a handler is registered for the given method.""" return method in self._request_handlers or method in self._notification_handlers + def add_request_handler( + self, + method: str, + handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]], + ) -> None: + """Register a request handler for the given method. + + If a handler is already registered for this method, it will be replaced. + + Args: + method: The JSON-RPC method name (e.g., "tools/list", "myextension/query"). + handler: An async callable that takes (ServerRequestContext, params) and + returns the result. + """ + self._request_handlers[method] = handler + + def remove_request_handler(self, method: str) -> None: + """Remove the request handler for the given method. + + Args: + method: The JSON-RPC method name to deregister. + + Raises: + KeyError: If no handler is registered for this method. + """ + del self._request_handlers[method] + + def add_notification_handler( + self, + method: str, + handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]], + ) -> None: + """Register a notification handler for the given method. + + If a handler is already registered for this method, it will be replaced. + + Args: + method: The JSON-RPC notification method name + (e.g., "notifications/progress"). + handler: An async callable that takes (ServerRequestContext, params) and + returns None. + """ + self._notification_handlers[method] = handler + + def remove_notification_handler(self, method: str) -> None: + """Remove the notification handler for the given method. + + Args: + method: The JSON-RPC notification method name to deregister. + + Raises: + KeyError: If no handler is registered for this method. + """ + del self._notification_handlers[method] + + def has_handler(self, method: str) -> bool: + """Check if a handler is registered for the given request or notification method. + + Args: + method: The JSON-RPC method name to check. + + Returns: + True if a handler is registered, False otherwise. + """ + return method in self._request_handlers or method in self._notification_handlers + # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities @@ -336,10 +402,7 @@ def experimental(self) -> ExperimentalHandlers[LifespanResultT]: # We create this inline so we only add these capabilities _if_ they're actually used if self._experimental_handlers is None: - self._experimental_handlers = ExperimentalHandlers( - add_request_handler=self._add_request_handler, - has_handler=self._has_handler, - ) + self._experimental_handlers = ExperimentalHandlers(server=self) return self._experimental_handlers @property diff --git a/tests/server/lowlevel/test_handler_registration.py b/tests/server/lowlevel/test_handler_registration.py new file mode 100644 index 000000000..37f9a3226 --- /dev/null +++ b/tests/server/lowlevel/test_handler_registration.py @@ -0,0 +1,94 @@ +"""Tests for public handler registration/deregistration API on low-level Server.""" + +import pytest + +from mcp.server.lowlevel.server import Server + + +@pytest.fixture +def server(): + return Server(name="test-server") + + +async def _dummy_request_handler(ctx, params): + return {"result": "ok"} + + +async def _dummy_notification_handler(ctx, params): + pass + + +class TestAddRequestHandler: + def test_add_request_handler(self, server): + server.add_request_handler("custom/method", _dummy_request_handler) + assert server.has_handler("custom/method") + + def test_add_request_handler_replaces_existing(self, server): + async def handler_a(ctx, params): + return "a" + + async def handler_b(ctx, params): + return "b" + + server.add_request_handler("custom/method", handler_a) + server.add_request_handler("custom/method", handler_b) + # The second handler should replace the first + assert server._request_handlers["custom/method"] is handler_b + + +class TestRemoveRequestHandler: + def test_remove_request_handler(self, server): + server.add_request_handler("custom/method", _dummy_request_handler) + assert server.has_handler("custom/method") + server.remove_request_handler("custom/method") + assert not server.has_handler("custom/method") + + def test_remove_request_handler_not_found(self, server): + with pytest.raises(KeyError): + server.remove_request_handler("nonexistent/method") + + +class TestAddNotificationHandler: + def test_add_notification_handler(self, server): + server.add_notification_handler("custom/notify", _dummy_notification_handler) + assert server.has_handler("custom/notify") + + def test_add_notification_handler_replaces_existing(self, server): + async def handler_a(ctx, params): + pass + + async def handler_b(ctx, params): + pass + + server.add_notification_handler("custom/notify", handler_a) + server.add_notification_handler("custom/notify", handler_b) + assert server._notification_handlers["custom/notify"] is handler_b + + +class TestRemoveNotificationHandler: + def test_remove_notification_handler(self, server): + server.add_notification_handler("custom/notify", _dummy_notification_handler) + assert server.has_handler("custom/notify") + server.remove_notification_handler("custom/notify") + assert not server.has_handler("custom/notify") + + def test_remove_notification_handler_not_found(self, server): + with pytest.raises(KeyError): + server.remove_notification_handler("nonexistent/notify") + + +class TestHasHandler: + def test_has_handler_request(self, server): + server.add_request_handler("custom/method", _dummy_request_handler) + assert server.has_handler("custom/method") + + def test_has_handler_notification(self, server): + server.add_notification_handler("custom/notify", _dummy_notification_handler) + assert server.has_handler("custom/notify") + + def test_has_handler_unregistered(self, server): + assert not server.has_handler("nonexistent/method") + + def test_has_handler_default_ping(self, server): + """The ping handler is registered by default.""" + assert server.has_handler("ping")