Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 45 additions & 13 deletions src/connectrpc/_server_async.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import base64
import contextlib
import functools
import inspect
from abc import ABC, abstractmethod
from asyncio import CancelledError, sleep
from asyncio import CancelledError, Event, create_task, sleep
from dataclasses import replace
from http import HTTPStatus
from typing import TYPE_CHECKING, Generic, TypeVar, cast
Expand Down Expand Up @@ -385,6 +386,9 @@ async def _handle_stream(
self._read_max_bytes,
)

disconnect_detected: Event | None = None
monitor_task = None

match endpoint:
case EndpointUnary():
request = await _consume_single_request(request_stream)
Expand All @@ -396,22 +400,50 @@ async def _handle_stream(
case EndpointServerStream():
request = await _consume_single_request(request_stream)
response_stream = endpoint.function(request, ctx)

# The request has been fully consumed; monitor receive() for a
# client disconnect so we can stop streaming promptly.
disconnect_detected = Event()

async def _watch_for_disconnect() -> None:
while True:
msg = await receive()
if msg["type"] == "http.disconnect":
disconnect_detected.set()
return

monitor_task = create_task(_watch_for_disconnect())
case EndpointBidiStream():
response_stream = endpoint.function(request_stream, ctx)

async for message in response_stream:
# Don't send headers until the first message to allow logic a chance to add
# response headers.
if not sent_headers:
await _send_stream_response_headers(
send, protocol, codec, resp_compression.name(), ctx
try:
async for message in response_stream:
if disconnect_detected is not None and disconnect_detected.is_set():
raise ConnectError(Code.CANCELED, "Client disconnected")
# Don't send headers until the first message to allow logic a chance to add
# response headers.
if not sent_headers:
await _send_stream_response_headers(
send, protocol, codec, resp_compression.name(), ctx
)
sent_headers = True

body = writer.write(message)
await send(
{"type": "http.response.body", "body": body, "more_body": True}
)
sent_headers = True

body = writer.write(message)
await send(
{"type": "http.response.body", "body": body, "more_body": True}
)
finally:
# Cancel the monitor first so a throwing generator finally-block
# doesn't leak the task.
if monitor_task is not None:
monitor_task.cancel()
with contextlib.suppress(CancelledError):
await monitor_task
# Explicitly close the stream so that any generator finally-blocks
# run promptly (Python defers async-generator cleanup to GC otherwise).
aclose = getattr(response_stream, "aclose", None)
if aclose is not None:
await aclose()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, it's possible for this to throw — it's user code — in which case the monitor task would be leaked. Weighing letting the exception propagate versus catching and maybe logging it: a `finally` block in a user generator is still part of the request handler, so it makes sense to let it affect the response. How about just reordering, then, if that makes sense?

except CancelledError as e:
raise ConnectError(Code.CANCELED, "Request was cancelled") from e
except Exception as e:
Expand Down
77 changes: 77 additions & 0 deletions test/test_roundtrip.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import asyncio
import struct
from typing import TYPE_CHECKING

import pytest
Expand All @@ -23,6 +25,8 @@
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator

from asgiref.typing import HTTPDisconnectEvent, HTTPRequestEvent, HTTPScope


@pytest.mark.parametrize("proto_json", [False, True])
@pytest.mark.parametrize("compression_name", ["gzip", "br", "zstd", "identity"])
Expand Down Expand Up @@ -280,3 +284,76 @@ async def request_stream():
else:
assert len(requests) == 2
assert len(responses) == 1


@pytest.mark.asyncio
async def test_server_stream_client_disconnect() -> None:
    """Server streaming generator should be closed when the client disconnects.

    Regression test for https://github.com/connectrpc/connect-python/issues/174.
    """
    # Set inside the generator's finally-block; proves cleanup actually ran.
    cleanup_ran = asyncio.Event()

    class InfiniteHaberdasher(Haberdasher):
        # Streams hats forever; only a disconnect-driven close can stop it.
        async def make_similar_hats(self, request, ctx):
            try:
                while True:
                    yield Hat(size=request.inches, color="green")
                    await asyncio.sleep(0)  # yield control to event loop
            finally:
                cleanup_ran.set()

    app = HaberdasherASGIApplication(InfiniteHaberdasher())

    # Connect protocol (application/connect+proto) envelope for Size(inches=10):
    # a 1-byte flags field plus a 4-byte big-endian length, then the message.
    payload = Size(inches=10).SerializeToString()
    envelope = struct.pack(">BI", 0, len(payload)) + payload

    # We invoke the ASGI app directly rather than using a real client with a
    # short timeout. A real client would likely be non-flaky with a timeout of
    # some tens of milliseconds, but it could still trigger the disconnect
    # before the request body has been fully read — a different code path than
    # the one under test — so driving the events by hand is deterministic.
    fire_disconnect = asyncio.Event()
    bodies_streamed = 0
    receive_calls = 0

    async def receive() -> HTTPRequestEvent | HTTPDisconnectEvent:
        nonlocal receive_calls
        receive_calls += 1
        if receive_calls == 1:
            return {"type": "http.request", "body": envelope, "more_body": False}
        # Block until the test is ready to simulate a disconnect.
        await fire_disconnect.wait()
        return {"type": "http.disconnect"}

    async def send(message):
        nonlocal bodies_streamed
        is_streaming_body = message.get("type") == "http.response.body" and message.get(
            "more_body", False
        )
        if is_streaming_body:
            bodies_streamed += 1
            # After a few streamed messages, simulate the client going away.
            if bodies_streamed >= 3:
                fire_disconnect.set()

    scope: HTTPScope = {
        "type": "http",
        "asgi": {"spec_version": "2.0", "version": "3.0"},
        "http_version": "1.1",
        "method": "POST",
        "scheme": "http",
        "path": "/connectrpc.example.Haberdasher/MakeSimilarHats",
        "raw_path": b"/connectrpc.example.Haberdasher/MakeSimilarHats",
        "query_string": b"",
        "root_path": "",
        "headers": [(b"content-type", b"application/connect+proto")],
        "client": None,
        "server": None,
        "extensions": None,
    }

    # Without the fix the app hangs forever (generator never stopped), causing a
    # TimeoutError here. With the fix it terminates promptly after the disconnect.
    await asyncio.wait_for(app(scope, receive, send), timeout=5.0)

    assert cleanup_ran.is_set(), (
        "generator should be closed after client disconnect"
    )
Loading