From 3f9a2d1d0e957012e0dcd98e6628995f9e5966a9 Mon Sep 17 00:00:00 2001 From: chenzhenye Date: Wed, 25 Feb 2026 15:43:52 +0800 Subject: [PATCH] feat(transport-security): add subdomain wildcard support for allowed_hosts TransportSecuritySettings.allowed_hosts now supports *.domain patterns (e.g. *.mysite.com) so that a single entry can allow the base domain and any subdomain (app.mysite.com, api.mysite.com, etc.) instead of listing each host explicitly. This makes multi-subdomain or dynamic subdomain setups practical. - Add _hostname_from_host() to strip port from Host header (including IPv6) - In _validate_host(), treat entries starting with *. as subdomain wildcards: match hostname equal to base domain or ending with . - Preserve existing behaviour: exact match and example.com:* port wildcard - Document the three pattern types in allowed_hosts docstring - Add integration tests for SSE and StreamableHTTP with *.mysite.com Github-Issue: #2141 --- src/mcp/server/transport_security.py | 34 ++++++++++-- tests/server/test_sse_security.py | 53 +++++++++++++++++++ tests/server/test_streamable_http_security.py | 41 ++++++++++++++ tests/server/test_transport_security.py | 21 ++++++++ 4 files changed, 146 insertions(+), 3 deletions(-) create mode 100644 tests/server/test_transport_security.py diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 1ed9842c0..a72366511 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -22,6 +22,13 @@ class TransportSecuritySettings(BaseModel): allowed_hosts: list[str] = Field(default_factory=list) """List of allowed Host header values. + Supports: + - Exact match: ``example.com``, ``127.0.0.1:8080`` + - Wildcard port: ``example.com:*`` matches ``example.com`` with any port + - Subdomain wildcard: ``*.mysite.com`` matches ``mysite.com`` and any subdomain + (e.g. ``app.mysite.com``, ``api.mysite.com``). Optionally use ``*.mysite.com:*`` + to also allow any port. + Only applies when `enable_dns_rebinding_protection` is `True`. """ @@ -40,6 +47,15 @@ def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) + def _hostname_from_host(self, host: str) -> str: + """Extract hostname from Host header (strip optional port).""" + if host.startswith("["): + idx = host.find("]:") + if idx != -1: + return host[: idx + 1] + return host + return host.split(":", 1)[0] + def _validate_host(self, host: str | None) -> bool: # pragma: no cover """Validate the Host header against allowed values.""" if not host: @@ -50,15 +66,27 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover if host in self.settings.allowed_hosts: return True - # Check wildcard port patterns + # Check wildcard port patterns (e.g. example.com:*) for allowed in self.settings.allowed_hosts: if allowed.endswith(":*"): - # Extract base host from pattern base_host = allowed[:-2] - # Check if the actual host starts with base host and has a port + # Subdomain pattern *.domain.com:* is handled below; skip here + if base_host.startswith("*."): + continue if host.startswith(base_host + ":"): return True + # Check subdomain wildcard patterns (e.g. *.mysite.com or *.mysite.com:*) + hostname = self._hostname_from_host(host) + for allowed in self.settings.allowed_hosts: + if allowed.startswith("*."): + pattern = allowed[:-2] if allowed.endswith(":*") else allowed + base_domain = pattern[2:] + if not base_domain: + continue + if hostname == base_domain or hostname.endswith("." + base_domain): + return True + logger.warning(f"Invalid Host header: {host}") return False diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 010eaf6a2..58379c6ad 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -256,6 +256,59 @@ async def test_sse_security_wildcard_ports(server_port: int): process.join() +@pytest.mark.anyio +async def test_sse_security_ipv6_host_header(server_port: int): + """Test SSE with IPv6 Host header ([::1] and [::1]:port) to cover _hostname_from_host.""" + settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["127.0.0.1:*", "[::1]:*", "[::1]"], + allowed_origins=["http://127.0.0.1:*", "http://[::1]:*"], + ) + process = start_server_process(server_port, settings) + + try: + async with httpx.AsyncClient(timeout=5.0) as client: + async with client.stream( + "GET", f"http://127.0.0.1:{server_port}/sse", headers={"Host": "[::1]:8080"} + ) as response: + assert response.status_code == 200 + async with client.stream( + "GET", f"http://127.0.0.1:{server_port}/sse", headers={"Host": "[::1]"} + ) as response: + assert response.status_code == 200 + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_sse_security_subdomain_wildcard_host(server_port: int): + """Test SSE with *.domain subdomain wildcard in allowed_hosts (issue #2141).""" + settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["*.mysite.com", "127.0.0.1:*"], + allowed_origins=["http://127.0.0.1:*", "http://app.mysite.com:*"], + ) + process = start_server_process(server_port, settings) + + try: + # Allowed: subdomain and base domain + for host in ["app.mysite.com", "api.mysite.com", "mysite.com"]: + headers = {"Host": host} + async with httpx.AsyncClient(timeout=5.0) as client: + async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + assert response.status_code == 200, f"Host {host} should be allowed" + + # Rejected: other domain + async with httpx.AsyncClient() as client: + response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers={"Host": "other.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" + finally: + process.terminate() + process.join() + + @pytest.mark.anyio async def test_sse_security_post_valid_content_type(server_port: int): """Test POST endpoint with valid Content-Type headers.""" diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index 897555353..1c7a3be72 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -253,6 +253,47 @@ async def test_streamable_http_security_custom_allowed_hosts(server_port: int): process.join() +@pytest.mark.anyio +async def test_streamable_http_security_subdomain_wildcard_host(server_port: int): + """Test StreamableHTTP with *.domain subdomain wildcard in allowed_hosts (issue #2141).""" + settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["*.mysite.com", "127.0.0.1:*"], + allowed_origins=["http://127.0.0.1:*", "http://app.mysite.com:*"], + ) + process = start_server_process(server_port, settings) + + try: + headers = { + "Host": "app.mysite.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + } + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"http://127.0.0.1:{server_port}/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers=headers, + ) + assert response.status_code == 200 + + async with httpx.AsyncClient() as client: + response = await client.post( + f"http://127.0.0.1:{server_port}/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Host": "other.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + assert response.status_code == 421 + assert response.text == "Invalid Host header" + finally: + process.terminate() + process.join() + + @pytest.mark.anyio async def test_streamable_http_security_get_request(server_port: int): """Test StreamableHTTP GET request with security.""" diff --git a/tests/server/test_transport_security.py b/tests/server/test_transport_security.py new file mode 100644 index 000000000..1defdffb8 --- /dev/null +++ b/tests/server/test_transport_security.py @@ -0,0 +1,21 @@ +"""Tests for transport security (DNS rebinding protection).""" + +from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings + + +def test_hostname_from_host_ipv6_with_port(): + """_hostname_from_host strips port from [::1]:port (coverage for lines 52-55).""" + m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + assert m._hostname_from_host("[::1]:8080") == "[::1]" + + +def test_hostname_from_host_ipv6_no_port(): + """_hostname_from_host returns [::1] as-is when no port (coverage for line 56).""" + m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + assert m._hostname_from_host("[::1]") == "[::1]" + + +def test_hostname_from_host_plain_with_port(): + """_hostname_from_host strips port from hostname (coverage for line 57).""" + m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + assert m._hostname_from_host("app.mysite.com:8080") == "app.mysite.com"