Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 205 additions & 0 deletions tests/test_task_dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
"""Tests for DefaultMessageDispatcher — request/notification routing."""

from __future__ import annotations

import asyncio
from collections.abc import AsyncGenerator

import pytest
import pytest_asyncio

from acp.task import RpcTask, RpcTaskKind
from acp.task.dispatcher import DefaultMessageDispatcher
from acp.task.queue import InMemoryMessageQueue
from acp.task.state import InMemoryMessageStateStore
from acp.task.supervisor import TaskSupervisor


@pytest_asyncio.fixture
async def supervisor() -> AsyncGenerator[TaskSupervisor, None]:
sup = TaskSupervisor(source="test")
sup.add_error_handler(lambda _t, _e: None)
yield sup
await sup.shutdown()


@pytest.fixture
def store() -> InMemoryMessageStateStore:
return InMemoryMessageStateStore()


@pytest_asyncio.fixture
async def queue() -> AsyncGenerator[InMemoryMessageQueue, None]:
q = InMemoryMessageQueue()
yield q
await q.close()


@pytest.mark.asyncio
async def test_dispatch_request(
supervisor: TaskSupervisor,
store: InMemoryMessageStateStore,
queue: InMemoryMessageQueue,
) -> None:
"""Dispatcher should route REQUEST tasks to the request runner."""
results: list[dict] = []

async def request_runner(msg: dict) -> dict:
results.append(msg)
return {"ok": True}

async def notification_runner(msg: dict) -> None:
pass

dispatcher = DefaultMessageDispatcher(
queue=queue,
supervisor=supervisor,
store=store,
request_runner=request_runner,
notification_runner=notification_runner,
)
dispatcher.start()

await queue.publish(RpcTask(kind=RpcTaskKind.REQUEST, message={"method": "test/req", "params": {}}))
await asyncio.sleep(0.1)
await queue.close()
await dispatcher.stop()
await supervisor.shutdown()

assert len(results) == 1
assert results[0]["method"] == "test/req"


@pytest.mark.asyncio
async def test_dispatch_notification(
supervisor: TaskSupervisor,
store: InMemoryMessageStateStore,
queue: InMemoryMessageQueue,
) -> None:
"""Dispatcher should route NOTIFICATION tasks to the notification runner."""
notifications: list[dict] = []

async def request_runner(msg: dict) -> dict:
return {}

async def notification_runner(msg: dict) -> None:
notifications.append(msg)

dispatcher = DefaultMessageDispatcher(
queue=queue,
supervisor=supervisor,
store=store,
request_runner=request_runner,
notification_runner=notification_runner,
)
dispatcher.start()

await queue.publish(RpcTask(kind=RpcTaskKind.NOTIFICATION, message={"method": "session/update"}))
await asyncio.sleep(0.1)
await queue.close()
await dispatcher.stop()
await supervisor.shutdown()

assert len(notifications) == 1
assert notifications[0]["method"] == "session/update"


@pytest.mark.asyncio
async def test_start_twice_raises(
supervisor: TaskSupervisor,
store: InMemoryMessageStateStore,
queue: InMemoryMessageQueue,
) -> None:
"""Starting the dispatcher twice should raise."""

async def noop(msg: dict) -> None:
pass

dispatcher = DefaultMessageDispatcher(
queue=queue,
supervisor=supervisor,
store=store,
request_runner=noop,
notification_runner=noop,
)
dispatcher.start()

with pytest.raises(RuntimeError, match="already started"):
dispatcher.start()

await queue.close()
await dispatcher.stop()
await supervisor.shutdown()


@pytest.mark.asyncio
async def test_failed_request_updates_store(
supervisor: TaskSupervisor,
store: InMemoryMessageStateStore,
queue: InMemoryMessageQueue,
) -> None:
"""When the request runner raises, the store should record the failure."""

async def failing_runner(msg: dict) -> dict:
raise ValueError("handler error")

async def notification_runner(msg: dict) -> None:
pass

dispatcher = DefaultMessageDispatcher(
queue=queue,
supervisor=supervisor,
store=store,
request_runner=failing_runner,
notification_runner=notification_runner,
)
dispatcher.start()

await queue.publish(RpcTask(kind=RpcTaskKind.REQUEST, message={"method": "test/fail", "params": None}))
await asyncio.sleep(0.15)
await queue.close()
await dispatcher.stop()
await supervisor.shutdown()

# NOTE: Accessing private state because InMemoryMessageStateStore has no
# public API to query incoming records.
assert len(store._incoming) == 1
assert store._incoming[0].status == "failed"
assert isinstance(store._incoming[0].error, ValueError)


@pytest.mark.asyncio
async def test_multiple_tasks_dispatched(
supervisor: TaskSupervisor,
store: InMemoryMessageStateStore,
queue: InMemoryMessageQueue,
) -> None:
"""Multiple tasks should all be dispatched and processed."""
processed: list[str] = []

async def request_runner(msg: dict) -> dict:
processed.append(msg["method"])
return {}

async def notification_runner(msg: dict) -> None:
processed.append(msg["method"])

dispatcher = DefaultMessageDispatcher(
queue=queue,
supervisor=supervisor,
store=store,
request_runner=request_runner,
notification_runner=notification_runner,
)
dispatcher.start()

await queue.publish(RpcTask(kind=RpcTaskKind.REQUEST, message={"method": "r1"}))
await queue.publish(RpcTask(kind=RpcTaskKind.NOTIFICATION, message={"method": "n1"}))
await queue.publish(RpcTask(kind=RpcTaskKind.REQUEST, message={"method": "r2"}))

await asyncio.sleep(0.15)
await queue.close()
await dispatcher.stop()
await supervisor.shutdown()

assert sorted(processed) == ["n1", "r1", "r2"]
93 changes: 93 additions & 0 deletions tests/test_task_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Tests for InMemoryMessageQueue — publish, consume, close semantics."""

from __future__ import annotations

import asyncio

import pytest

from acp.task import RpcTask, RpcTaskKind
from acp.task.queue import InMemoryMessageQueue


def _make_task(method: str = "test/ping") -> RpcTask:
return RpcTask(kind=RpcTaskKind.REQUEST, message={"method": method})


@pytest.mark.asyncio
async def test_publish_and_consume() -> None:
"""Published tasks should be yielded by the async iterator."""
queue = InMemoryMessageQueue()
task = _make_task("m1")
await queue.publish(task)
await queue.close()

collected: list[RpcTask] = []
async for item in queue:
collected.append(item)

assert len(collected) == 1
assert collected[0].message["method"] == "m1"


@pytest.mark.asyncio
async def test_fifo_ordering() -> None:
"""Tasks should be consumed in FIFO order."""
queue = InMemoryMessageQueue()
for i in range(5):
await queue.publish(_make_task(f"m{i}"))
await queue.close()

methods: list[str] = []
async for item in queue:
methods.append(item.message["method"])

assert methods == [f"m{i}" for i in range(5)]


@pytest.mark.asyncio
async def test_publish_after_close_raises() -> None:
"""Publishing to a closed queue should raise RuntimeError."""
queue = InMemoryMessageQueue()
await queue.close()

with pytest.raises(RuntimeError, match=r"m[es]*sage queue already closed"):
await queue.publish(_make_task())


@pytest.mark.asyncio
async def test_close_idempotent() -> None:
"""Closing an already-closed queue should not raise."""
queue = InMemoryMessageQueue()
await queue.close()
await queue.close()


@pytest.mark.asyncio
async def test_task_done_without_get_is_safe() -> None:
"""task_done on an empty queue should not raise (suppresses ValueError)."""
queue = InMemoryMessageQueue()
queue.task_done() # should not raise


@pytest.mark.asyncio
async def test_join_waits_for_task_done() -> None:
"""join() should block until all consumed tasks call task_done()."""
queue = InMemoryMessageQueue()
await queue.publish(_make_task())

joined = False

async def consumer() -> None:
nonlocal joined
async for _ in queue:
queue.task_done()
joined = True

async def closer() -> None:
await asyncio.sleep(0.05)
await queue.close()

await asyncio.gather(consumer(), closer())
await queue.join()
assert joined
96 changes: 96 additions & 0 deletions tests/test_task_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Tests for InMemoryMessageStateStore — outgoing/incoming message state tracking."""

from __future__ import annotations

import pytest

from acp.task.state import InMemoryMessageStateStore


@pytest.fixture
def store() -> InMemoryMessageStateStore:
return InMemoryMessageStateStore()


@pytest.mark.asyncio
async def test_register_and_resolve_outgoing(store: InMemoryMessageStateStore) -> None:
"""Resolving an outgoing request should fulfill its future."""
future = store.register_outgoing(1, "session/initialize")
store.resolve_outgoing(1, {"ok": True})
result = await future
assert result == {"ok": True}


@pytest.mark.asyncio
async def test_reject_outgoing(store: InMemoryMessageStateStore) -> None:
"""Rejecting an outgoing request should set an exception on its future."""
future = store.register_outgoing(2, "session/update")
error = Exception("server error")
store.reject_outgoing(2, error)
with pytest.raises(Exception, match="server error"):
await future


@pytest.mark.asyncio
async def test_resolve_unknown_id_is_noop(store: InMemoryMessageStateStore) -> None:
"""Resolving a non-existent request ID should not raise."""
store.resolve_outgoing(999, "ignored")


@pytest.mark.asyncio
async def test_reject_unknown_id_is_noop(store: InMemoryMessageStateStore) -> None:
"""Rejecting a non-existent request ID should not raise."""
store.reject_outgoing(999, Exception("ignored"))


@pytest.mark.asyncio
async def test_reject_all_outgoing(store: InMemoryMessageStateStore) -> None:
"""reject_all_outgoing should fail every pending future."""
f1 = store.register_outgoing(1, "m1")
f2 = store.register_outgoing(2, "m2")
f3 = store.register_outgoing(3, "m3")

store.reject_all_outgoing(ConnectionError("disconnected"))

for future in [f1, f2, f3]:
with pytest.raises(ConnectionError, match="disconnected"):
await future


@pytest.mark.asyncio
async def test_reject_all_skips_already_done(store: InMemoryMessageStateStore) -> None:
"""reject_all should skip futures that are already resolved."""
f1 = store.register_outgoing(1, "m1")
store.resolve_outgoing(1, "done")

f2 = store.register_outgoing(2, "m2")

store.reject_all_outgoing(ConnectionError("dc"))

assert await f1 == "done"
with pytest.raises(ConnectionError):
await f2


def test_begin_incoming(store: InMemoryMessageStateStore) -> None:
"""begin_incoming should create a pending record."""
record = store.begin_incoming("task/update", {"id": "t1"})
assert record.method == "task/update"
assert record.params == {"id": "t1"}
assert record.status == "pending"


def test_complete_incoming(store: InMemoryMessageStateStore) -> None:
"""complete_incoming should update the record status and result."""
record = store.begin_incoming("task/update", {})
store.complete_incoming(record, {"success": True})
assert record.status == "completed"
assert record.result == {"success": True}


def test_fail_incoming(store: InMemoryMessageStateStore) -> None:
"""fail_incoming should update the record status and error."""
record = store.begin_incoming("task/update", {})
store.fail_incoming(record, {"code": -32603})
assert record.status == "failed"
assert record.error == {"code": -32603}
Loading