diff --git a/tests/test_task_dispatcher.py b/tests/test_task_dispatcher.py new file mode 100644 index 0000000..9ad0940 --- /dev/null +++ b/tests/test_task_dispatcher.py @@ -0,0 +1,205 @@ +"""Tests for DefaultMessageDispatcher — request/notification routing.""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncGenerator + +import pytest +import pytest_asyncio + +from acp.task import RpcTask, RpcTaskKind +from acp.task.dispatcher import DefaultMessageDispatcher +from acp.task.queue import InMemoryMessageQueue +from acp.task.state import InMemoryMessageStateStore +from acp.task.supervisor import TaskSupervisor + + +@pytest_asyncio.fixture +async def supervisor() -> AsyncGenerator[TaskSupervisor, None]: + sup = TaskSupervisor(source="test") + sup.add_error_handler(lambda _t, _e: None) + yield sup + await sup.shutdown() + + +@pytest.fixture +def store() -> InMemoryMessageStateStore: + return InMemoryMessageStateStore() + + +@pytest_asyncio.fixture +async def queue() -> AsyncGenerator[InMemoryMessageQueue, None]: + q = InMemoryMessageQueue() + yield q + await q.close() + + +@pytest.mark.asyncio +async def test_dispatch_request( + supervisor: TaskSupervisor, + store: InMemoryMessageStateStore, + queue: InMemoryMessageQueue, +) -> None: + """Dispatcher should route REQUEST tasks to the request runner.""" + results: list[dict] = [] + + async def request_runner(msg: dict) -> dict: + results.append(msg) + return {"ok": True} + + async def notification_runner(msg: dict) -> None: + pass + + dispatcher = DefaultMessageDispatcher( + queue=queue, + supervisor=supervisor, + store=store, + request_runner=request_runner, + notification_runner=notification_runner, + ) + dispatcher.start() + + await queue.publish(RpcTask(kind=RpcTaskKind.REQUEST, message={"method": "test/req", "params": {}})) + await asyncio.sleep(0.1) + await queue.close() + await dispatcher.stop() + await supervisor.shutdown() + + assert len(results) == 1 + assert results[0]["method"] == "test/req" + + +@pytest.mark.asyncio +async def test_dispatch_notification( + supervisor: TaskSupervisor, + store: InMemoryMessageStateStore, + queue: InMemoryMessageQueue, +) -> None: + """Dispatcher should route NOTIFICATION tasks to the notification runner.""" + notifications: list[dict] = [] + + async def request_runner(msg: dict) -> dict: + return {} + + async def notification_runner(msg: dict) -> None: + notifications.append(msg) + + dispatcher = DefaultMessageDispatcher( + queue=queue, + supervisor=supervisor, + store=store, + request_runner=request_runner, + notification_runner=notification_runner, + ) + dispatcher.start() + + await queue.publish(RpcTask(kind=RpcTaskKind.NOTIFICATION, message={"method": "session/update"})) + await asyncio.sleep(0.1) + await queue.close() + await dispatcher.stop() + await supervisor.shutdown() + + assert len(notifications) == 1 + assert notifications[0]["method"] == "session/update" + + +@pytest.mark.asyncio +async def test_start_twice_raises( + supervisor: TaskSupervisor, + store: InMemoryMessageStateStore, + queue: InMemoryMessageQueue, +) -> None: + """Starting the dispatcher twice should raise.""" + + async def noop(msg: dict) -> None: + pass + + dispatcher = DefaultMessageDispatcher( + queue=queue, + supervisor=supervisor, + store=store, + request_runner=noop, + notification_runner=noop, + ) + dispatcher.start() + + with pytest.raises(RuntimeError, match="already started"): + dispatcher.start() + + await queue.close() + await dispatcher.stop() + await supervisor.shutdown() + + +@pytest.mark.asyncio +async def test_failed_request_updates_store( + supervisor: TaskSupervisor, + store: InMemoryMessageStateStore, + queue: InMemoryMessageQueue, +) -> None: + """When the request runner raises, the store should record the failure.""" + + async def failing_runner(msg: dict) -> dict: + raise ValueError("handler error") + + async def notification_runner(msg: dict) -> None: + pass + + dispatcher = DefaultMessageDispatcher( + queue=queue, + supervisor=supervisor, + store=store, + request_runner=failing_runner, + notification_runner=notification_runner, + ) + dispatcher.start() + + await queue.publish(RpcTask(kind=RpcTaskKind.REQUEST, message={"method": "test/fail", "params": None})) + await asyncio.sleep(0.15) + await queue.close() + await dispatcher.stop() + await supervisor.shutdown() + + # NOTE: Accessing private state because InMemoryMessageStateStore has no + # public API to query incoming records. + assert len(store._incoming) == 1 + assert store._incoming[0].status == "failed" + assert isinstance(store._incoming[0].error, ValueError) + + +@pytest.mark.asyncio +async def test_multiple_tasks_dispatched( + supervisor: TaskSupervisor, + store: InMemoryMessageStateStore, + queue: InMemoryMessageQueue, +) -> None: + """Multiple tasks should all be dispatched and processed.""" + processed: list[str] = [] + + async def request_runner(msg: dict) -> dict: + processed.append(msg["method"]) + return {} + + async def notification_runner(msg: dict) -> None: + processed.append(msg["method"]) + + dispatcher = DefaultMessageDispatcher( + queue=queue, + supervisor=supervisor, + store=store, + request_runner=request_runner, + notification_runner=notification_runner, + ) + dispatcher.start() + + await queue.publish(RpcTask(kind=RpcTaskKind.REQUEST, message={"method": "r1"})) + await queue.publish(RpcTask(kind=RpcTaskKind.NOTIFICATION, message={"method": "n1"})) + await queue.publish(RpcTask(kind=RpcTaskKind.REQUEST, message={"method": "r2"})) + + await asyncio.sleep(0.15) + await queue.close() + await dispatcher.stop() + await supervisor.shutdown() + + assert sorted(processed) == ["n1", "r1", "r2"] diff --git a/tests/test_task_queue.py b/tests/test_task_queue.py new file mode 100644 index 0000000..d483ac0 --- /dev/null +++ b/tests/test_task_queue.py @@ -0,0 +1,93 @@ +"""Tests for InMemoryMessageQueue — publish, consume, close semantics.""" + +from __future__ import annotations + +import asyncio + +import pytest + +from acp.task import RpcTask, RpcTaskKind +from acp.task.queue import InMemoryMessageQueue + + +def _make_task(method: str = "test/ping") -> RpcTask: + return RpcTask(kind=RpcTaskKind.REQUEST, message={"method": method}) + + +@pytest.mark.asyncio +async def test_publish_and_consume() -> None: + """Published tasks should be yielded by the async iterator.""" + queue = InMemoryMessageQueue() + task = _make_task("m1") + await queue.publish(task) + await queue.close() + + collected: list[RpcTask] = [] + async for item in queue: + collected.append(item) + + assert len(collected) == 1 + assert collected[0].message["method"] == "m1" + + +@pytest.mark.asyncio +async def test_fifo_ordering() -> None: + """Tasks should be consumed in FIFO order.""" + queue = InMemoryMessageQueue() + for i in range(5): + await queue.publish(_make_task(f"m{i}")) + await queue.close() + + methods: list[str] = [] + async for item in queue: + methods.append(item.message["method"]) + + assert methods == [f"m{i}" for i in range(5)] + + +@pytest.mark.asyncio +async def test_publish_after_close_raises() -> None: + """Publishing to a closed queue should raise RuntimeError.""" + queue = InMemoryMessageQueue() + await queue.close() + + with pytest.raises(RuntimeError, match=r"m[es]*sage queue already closed"): + await queue.publish(_make_task()) + + +@pytest.mark.asyncio +async def test_close_idempotent() -> None: + """Closing an already-closed queue should not raise.""" + queue = InMemoryMessageQueue() + await queue.close() + await queue.close() + + +@pytest.mark.asyncio +async def test_task_done_without_get_is_safe() -> None: + """task_done on an empty queue should not raise (suppresses ValueError).""" + queue = InMemoryMessageQueue() + queue.task_done() # should not raise + + +@pytest.mark.asyncio +async def test_join_waits_for_task_done() -> None: + """join() should block until all consumed tasks call task_done().""" + queue = InMemoryMessageQueue() + await queue.publish(_make_task()) + + joined = False + + async def consumer() -> None: + nonlocal joined + async for _ in queue: + queue.task_done() + joined = True + + async def closer() -> None: + await asyncio.sleep(0.05) + await queue.close() + + await asyncio.gather(consumer(), closer()) + await queue.join() + assert joined diff --git a/tests/test_task_state.py b/tests/test_task_state.py new file mode 100644 index 0000000..3ca1460 --- /dev/null +++ b/tests/test_task_state.py @@ -0,0 +1,96 @@ +"""Tests for InMemoryMessageStateStore — outgoing/incoming message state tracking.""" + +from __future__ import annotations + +import pytest + +from acp.task.state import InMemoryMessageStateStore + + +@pytest.fixture +def store() -> InMemoryMessageStateStore: + return InMemoryMessageStateStore() + + +@pytest.mark.asyncio +async def test_register_and_resolve_outgoing(store: InMemoryMessageStateStore) -> None: + """Resolving an outgoing request should fulfill its future.""" + future = store.register_outgoing(1, "session/initialize") + store.resolve_outgoing(1, {"ok": True}) + result = await future + assert result == {"ok": True} + + +@pytest.mark.asyncio +async def test_reject_outgoing(store: InMemoryMessageStateStore) -> None: + """Rejecting an outgoing request should set an exception on its future.""" + future = store.register_outgoing(2, "session/update") + error = Exception("server error") + store.reject_outgoing(2, error) + with pytest.raises(Exception, match="server error"): + await future + + +@pytest.mark.asyncio +async def test_resolve_unknown_id_is_noop(store: InMemoryMessageStateStore) -> None: + """Resolving a non-existent request ID should not raise.""" + store.resolve_outgoing(999, "ignored") + + +@pytest.mark.asyncio +async def test_reject_unknown_id_is_noop(store: InMemoryMessageStateStore) -> None: + """Rejecting a non-existent request ID should not raise.""" + store.reject_outgoing(999, Exception("ignored")) + + +@pytest.mark.asyncio +async def test_reject_all_outgoing(store: InMemoryMessageStateStore) -> None: + """reject_all_outgoing should fail every pending future.""" + f1 = store.register_outgoing(1, "m1") + f2 = store.register_outgoing(2, "m2") + f3 = store.register_outgoing(3, "m3") + + store.reject_all_outgoing(ConnectionError("disconnected")) + + for future in [f1, f2, f3]: + with pytest.raises(ConnectionError, match="disconnected"): + await future + + +@pytest.mark.asyncio +async def test_reject_all_skips_already_done(store: InMemoryMessageStateStore) -> None: + """reject_all should skip futures that are already resolved.""" + f1 = store.register_outgoing(1, "m1") + store.resolve_outgoing(1, "done") + + f2 = store.register_outgoing(2, "m2") + + store.reject_all_outgoing(ConnectionError("dc")) + + assert await f1 == "done" + with pytest.raises(ConnectionError): + await f2 + + +def test_begin_incoming(store: InMemoryMessageStateStore) -> None: + """begin_incoming should create a pending record.""" + record = store.begin_incoming("task/update", {"id": "t1"}) + assert record.method == "task/update" + assert record.params == {"id": "t1"} + assert record.status == "pending" + + +def test_complete_incoming(store: InMemoryMessageStateStore) -> None: + """complete_incoming should update the record status and result.""" + record = store.begin_incoming("task/update", {}) + store.complete_incoming(record, {"success": True}) + assert record.status == "completed" + assert record.result == {"success": True} + + +def test_fail_incoming(store: InMemoryMessageStateStore) -> None: + """fail_incoming should update the record status and error.""" + record = store.begin_incoming("task/update", {}) + store.fail_incoming(record, {"code": -32603}) + assert record.status == "failed" + assert record.error == {"code": -32603} diff --git a/tests/test_task_supervisor.py b/tests/test_task_supervisor.py new file mode 100644 index 0000000..59fed62 --- /dev/null +++ b/tests/test_task_supervisor.py @@ -0,0 +1,129 @@ +"""Tests for TaskSupervisor lifecycle and error handling.""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncGenerator + +import pytest +import pytest_asyncio + +from acp.task.supervisor import TaskSupervisor + + +@pytest_asyncio.fixture +async def supervisor() -> AsyncGenerator[TaskSupervisor, None]: + sup = TaskSupervisor(source="test") + yield sup + await sup.shutdown() + + +@pytest.mark.asyncio +async def test_create_and_run_task(supervisor: TaskSupervisor) -> None: + """Created tasks should execute and complete.""" + result: list[int] = [] + + async def worker() -> None: + result.append(42) + + task = supervisor.create(worker(), name="test-worker") + await task + assert result == [42] + + +@pytest.mark.asyncio +async def test_create_raises_after_shutdown(supervisor: TaskSupervisor) -> None: + """Cannot create tasks after supervisor is shut down.""" + await supervisor.shutdown() + + coro = asyncio.sleep(0) + with pytest.raises(RuntimeError, match="already closed"): + supervisor.create(coro, name="late") + coro.close() + + +@pytest.mark.asyncio +async def test_shutdown_cancels_running_tasks(supervisor: TaskSupervisor) -> None: + """Shutdown should cancel all in-flight tasks.""" + started = asyncio.Event() + + async def long_running() -> None: + started.set() + await asyncio.sleep(999) + + task = supervisor.create(long_running(), name="long") + await started.wait() + + await supervisor.shutdown() + assert task.cancelled() + + +@pytest.mark.asyncio +async def test_shutdown_idempotent(supervisor: TaskSupervisor) -> None: + """Multiple shutdown calls should not raise.""" + await supervisor.shutdown() + await supervisor.shutdown() + + +@pytest.mark.asyncio +async def test_task_specific_error_handler(supervisor: TaskSupervisor) -> None: + """Per-task on_error callback should receive the exception.""" + captured: list[BaseException] = [] + + def on_error(_task: asyncio.Task, exc: BaseException) -> None: + captured.append(exc) + + async def failing() -> None: + raise ValueError("boom") + + supervisor.create(failing(), name="fail", on_error=on_error) + await asyncio.sleep(0.05) + + assert len(captured) == 1 + assert str(captured[0]) == "boom" + + +@pytest.mark.asyncio +async def test_global_error_handler_fallback(supervisor: TaskSupervisor) -> None: + """Global error handlers fire when no per-task handler is set.""" + captured: list[BaseException] = [] + supervisor.add_error_handler(lambda _t, exc: captured.append(exc)) + + async def failing() -> None: + raise RuntimeError("global-boom") + + supervisor.create(failing(), name="fail-global") + await asyncio.sleep(0.05) + + assert len(captured) == 1 + assert str(captured[0]) == "global-boom" + + +@pytest.mark.asyncio +async def test_completed_task_removed_from_registry(supervisor: TaskSupervisor) -> None: + """Finished tasks should be discarded from the internal set.""" + + async def quick() -> None: + pass + + supervisor.create(quick(), name="quick") + await asyncio.sleep(0.05) + + # After completion, shutdown should be instant (no tasks to cancel) + await supervisor.shutdown() + + +@pytest.mark.asyncio +async def test_cancelled_task_not_reported_as_error(supervisor: TaskSupervisor) -> None: + """Cancelled tasks should not trigger error handlers.""" + errors: list[BaseException] = [] + supervisor.add_error_handler(lambda _t, exc: errors.append(exc)) + + async def sleeper() -> None: + await asyncio.sleep(999) + + supervisor.create(sleeper(), name="sleeper") + await asyncio.sleep(0.01) + await supervisor.shutdown() + + assert errors == []