diff --git a/src/connectrpc/_server_async.py b/src/connectrpc/_server_async.py
index ece862a..369aae4 100644
--- a/src/connectrpc/_server_async.py
+++ b/src/connectrpc/_server_async.py
@@ -1,10 +1,11 @@
 from __future__ import annotations
 
 import base64
+import contextlib
 import functools
 import inspect
 from abc import ABC, abstractmethod
-from asyncio import CancelledError, sleep
+from asyncio import CancelledError, Event, create_task, sleep
 from dataclasses import replace
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Generic, TypeVar, cast
@@ -385,6 +386,9 @@ async def _handle_stream(
                 self._read_max_bytes,
             )
 
+            disconnect_detected: Event | None = None
+            monitor_task = None
+
             match endpoint:
                 case EndpointUnary():
                     request = await _consume_single_request(request_stream)
@@ -396,22 +400,50 @@ async def _handle_stream(
                 case EndpointServerStream():
                     request = await _consume_single_request(request_stream)
                     response_stream = endpoint.function(request, ctx)
+
+                    # The request has been fully consumed; monitor receive() for a
+                    # client disconnect so we can stop streaming promptly.
+                    disconnect_detected = Event()
+
+                    async def _watch_for_disconnect() -> None:
+                        while True:
+                            msg = await receive()
+                            if msg["type"] == "http.disconnect":
+                                disconnect_detected.set()
+                                return
+
+                    monitor_task = create_task(_watch_for_disconnect())
                 case EndpointBidiStream():
                     response_stream = endpoint.function(request_stream, ctx)
 
-            async for message in response_stream:
-                # Don't send headers until the first message to allow logic a chance to add
-                # response headers.
-                if not sent_headers:
-                    await _send_stream_response_headers(
-                        send, protocol, codec, resp_compression.name(), ctx
+            try:
+                async for message in response_stream:
+                    if disconnect_detected is not None and disconnect_detected.is_set():
+                        raise ConnectError(Code.CANCELED, "Client disconnected")
+                    # Don't send headers until the first message to allow logic a chance to add
+                    # response headers.
+                    if not sent_headers:
+                        await _send_stream_response_headers(
+                            send, protocol, codec, resp_compression.name(), ctx
+                        )
+                        sent_headers = True
+
+                    body = writer.write(message)
+                    await send(
+                        {"type": "http.response.body", "body": body, "more_body": True}
                     )
-                    sent_headers = True
-
-                body = writer.write(message)
-                await send(
-                    {"type": "http.response.body", "body": body, "more_body": True}
-                )
+            finally:
+                # Cancel the monitor first so a throwing generator finally-block
+                # doesn't leak the task.
+                if monitor_task is not None:
+                    monitor_task.cancel()
+                    with contextlib.suppress(CancelledError):
+                        await monitor_task
+                # Explicitly close the stream so that any generator finally-blocks
+                # run promptly (Python defers async-generator cleanup to GC otherwise).
+                aclose = getattr(response_stream, "aclose", None)
+                if aclose is not None:
+                    await aclose()
         except CancelledError as e:
             raise ConnectError(Code.CANCELED, "Request was cancelled") from e
         except Exception as e:
diff --git a/test/test_roundtrip.py b/test/test_roundtrip.py
index 56a4f13..406d07c 100644
--- a/test/test_roundtrip.py
+++ b/test/test_roundtrip.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import asyncio
+import struct
 from typing import TYPE_CHECKING
 
 import pytest
@@ -23,6 +25,8 @@ if TYPE_CHECKING:
     from collections.abc import AsyncIterator, Iterator
 
+    from asgiref.typing import HTTPDisconnectEvent, HTTPRequestEvent, HTTPScope
+
 
 @pytest.mark.parametrize("proto_json", [False, True])
 @pytest.mark.parametrize("compression_name", ["gzip", "br", "zstd", "identity"])
@@ -280,3 +284,76 @@ async def request_stream():
     else:
         assert len(requests) == 2
         assert len(responses) == 1
+
+
+@pytest.mark.asyncio
+async def test_server_stream_client_disconnect() -> None:
+    """Server streaming generator should be closed when the client disconnects.
+
+    Regression test for https://github.com/connectrpc/connect-python/issues/174.
+    """
+    generator_closed = asyncio.Event()
+
+    class InfiniteHaberdasher(Haberdasher):
+        async def make_similar_hats(self, request, ctx):
+            try:
+                while True:
+                    yield Hat(size=request.inches, color="green")
+                    await asyncio.sleep(0)  # yield control to event loop
+            finally:
+                generator_closed.set()
+
+    app = HaberdasherASGIApplication(InfiniteHaberdasher())
+
+    # Encode a Connect protocol (application/connect+proto) request for Size(inches=10).
+    request_bytes = Size(inches=10).SerializeToString()
+    request_body = struct.pack(">BI", 0, len(request_bytes)) + request_bytes
+
+    # We invoke the ASGI app directly rather than using a real client with a
+    # short timeout because a real client could trigger the disconnect before the
+    # request body has been fully read, which would be a different code path.
+    disconnect_trigger = asyncio.Event()
+    response_count = 0
+    call_count = 0
+
+    async def receive() -> HTTPRequestEvent | HTTPDisconnectEvent:
+        nonlocal call_count
+        call_count += 1
+        if call_count == 1:
+            return {"type": "http.request", "body": request_body, "more_body": False}
+        # Block until the test is ready to simulate a disconnect.
+        await disconnect_trigger.wait()
+        return {"type": "http.disconnect"}
+
+    async def send(message):
+        nonlocal response_count
+        if message.get("type") == "http.response.body" and message.get(
+            "more_body", False
+        ):
+            response_count += 1
+            if response_count >= 3:
+                disconnect_trigger.set()
+
+    scope: HTTPScope = {
+        "type": "http",
+        "asgi": {"spec_version": "2.0", "version": "3.0"},
+        "http_version": "1.1",
+        "method": "POST",
+        "scheme": "http",
+        "path": "/connectrpc.example.Haberdasher/MakeSimilarHats",
+        "raw_path": b"/connectrpc.example.Haberdasher/MakeSimilarHats",
+        "query_string": b"",
+        "root_path": "",
+        "headers": [(b"content-type", b"application/connect+proto")],
+        "client": None,
+        "server": None,
+        "extensions": None,
+    }
+
+    # Without the fix the app hangs forever (generator never stopped), causing a
+    # TimeoutError here. With the fix it terminates promptly after the disconnect.
+    await asyncio.wait_for(app(scope, receive, send), timeout=5.0)
+
+    assert generator_closed.is_set(), (
+        "generator should be closed after client disconnect"
+    )