From 71868f95abea6c85a0a7ecf0ccbbd4471bfc2d40 Mon Sep 17 00:00:00 2001 From: "sergei.romanchuk" Date: Mon, 4 May 2026 20:32:15 +0200 Subject: [PATCH 1/3] add: stream subscriber mapping --- README.md | 33 ++++++++ taskiq_aio_kafka/broker.py | 59 ++++++++++++- taskiq_aio_kafka/constants.py | 3 +- taskiq_aio_kafka/subscriber.py | 16 ++++ tests/test_broker_multi_topic.py | 138 +++++++++++++++++++++++++++++++ uv.lock | 1 - 6 files changed, 246 insertions(+), 4 deletions(-) create mode 100644 taskiq_aio_kafka/subscriber.py diff --git a/README.md b/README.md index 79a052f..8a9dddc 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,39 @@ async def regular_task() -> None: await regular_task.kiq() ``` +## Stream topics + +You can subscribe an existing task to a Kafka topic with raw messages. +Messages from subscribed stream topics are wrapped into regular taskiq +messages before execution, so result backends and worker middlewares keep +working as usual. + +```python +import json + +from taskiq_aio_kafka import AioKafkaBroker + +broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="taskiq-topic", +) + + +@broker.task +async def process_user_created(event: dict[str, object]) -> None: + print(event) + + +broker.subscribe( + "users.created", + process_user_created, + decoder=json.loads, +) +``` + +The decoded value is passed as the first task argument. Stream messages also +receive the `taskiq-stream` label with the source topic name. + ## Configuration AioKafkaBroker parameters: diff --git a/taskiq_aio_kafka/broker.py b/taskiq_aio_kafka/broker.py index c9d4484..6e94a65 100644 --- a/taskiq_aio_kafka/broker.py +++ b/taskiq_aio_kafka/broker.py @@ -6,17 +6,21 @@ from typing import Any, TypeVar, overload from aiokafka import AIOKafkaConsumer, AIOKafkaProducer +from aiokafka.structs import ConsumerRecord from kafka.admin import KafkaAdminClient, NewTopic from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor from kafka.partitioner.default import DefaultPartitioner from taskiq import AsyncResultBackend, BrokerMessage from taskiq.abc.broker import AsyncBroker +from taskiq.decor import AsyncTaskiqDecoratedTask +from taskiq.message import TaskiqMessage from typing_extensions import ParamSpec -from .constants import TASK_TOPIC_LABEL +from .constants import TASK_STREAM_LABEL, TASK_TOPIC_LABEL from .decorated_task import AioKafkaDecoratedTask from .exceptions import WrongAioKafkaBrokerParametersError from .models import KafkaConsumerParameters, KafkaProducerParameters +from .subscriber import StreamDecoder, StreamSubscriber from .topic import Topic from .types import TopicType from .utils import get_topic_name @@ -79,6 +83,7 @@ def __init__( # noqa: PLR0913 get_topic_name(topic), topic, ) + self._stream_subscribers: dict[str, StreamSubscriber] = {} self._aiokafka_producer_params: KafkaProducerParameters = ( KafkaProducerParameters() @@ -165,6 +170,30 @@ def configure_consumer(self, **consumer_parameters: Any) -> None: **consumer_parameters, ) + @staticmethod + def _default_stream_decoder(message: bytes) -> bytes: + return message + + def subscribe( + self, + topic: TopicType, + task: AsyncTaskiqDecoratedTask[Any, Any], + decoder: StreamDecoder | None = None, + **labels: Any, + ) -> None: + """Subscribe task to raw Kafka topic messages.""" + topic_name = get_topic_name(topic) + if topic_name in self._stream_subscribers: + error_message = f"Topic {topic_name!r} is already subscribed." + raise ValueError(error_message) + + self._kafka_topics.setdefault(topic_name, topic) + self._stream_subscribers[topic_name] = StreamSubscriber( + task_name=task.task_name, + decoder=decoder or self._default_stream_decoder, + labels=labels, + ) + @overload def task( self, @@ -347,4 +376,30 @@ async def listen( raise ValueError("Please run startup before listening.") async for raw_kafka_message in self._aiokafka_consumer: - yield raw_kafka_message.value + subscriber = self._stream_subscribers.get(raw_kafka_message.topic) + if subscriber is None: + yield raw_kafka_message.value + continue + + yield self._build_stream_message(raw_kafka_message, subscriber) + + def _build_stream_message( + self, + raw_kafka_message: ConsumerRecord[Any, bytes], + subscriber: StreamSubscriber, + ) -> bytes: + raw_value = raw_kafka_message.value + decoded_value = subscriber.decoder(raw_value) + labels = { + TASK_STREAM_LABEL: raw_kafka_message.topic, + **subscriber.labels, + } + message = TaskiqMessage( + task_id=self.id_generator(), + task_name=subscriber.task_name, + labels=labels, + labels_types={}, + args=[decoded_value], + kwargs={}, + ) + return self.formatter.dumps(message).message diff --git a/taskiq_aio_kafka/constants.py b/taskiq_aio_kafka/constants.py index 71f5fad..84f54ff 100644 --- a/taskiq_aio_kafka/constants.py +++ b/taskiq_aio_kafka/constants.py @@ -1,3 +1,4 @@ -__all__ = ("TASK_TOPIC_LABEL",) +__all__ = ("TASK_STREAM_LABEL", "TASK_TOPIC_LABEL") +TASK_STREAM_LABEL = "taskiq-stream" TASK_TOPIC_LABEL = "taskiq_aio_kafka_topic" diff --git a/taskiq_aio_kafka/subscriber.py b/taskiq_aio_kafka/subscriber.py new file mode 100644 index 0000000..19573ee --- /dev/null +++ b/taskiq_aio_kafka/subscriber.py @@ -0,0 +1,16 @@ +__all__ = ("StreamDecoder", "StreamSubscriber") + +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +StreamDecoder = Callable[[bytes], Any] + + +@dataclass(frozen=True) +class StreamSubscriber: + """Kafka stream subscriber bound to a taskiq task.""" + + task_name: str + decoder: StreamDecoder + labels: dict[str, Any] = field(default_factory=dict) diff --git a/tests/test_broker_multi_topic.py b/tests/test_broker_multi_topic.py index 2e82911..64545d3 100644 --- a/tests/test_broker_multi_topic.py +++ b/tests/test_broker_multi_topic.py @@ -2,10 +2,12 @@ from unittest.mock import Mock import pytest +from aiokafka.structs import ConsumerRecord from kafka.admin import KafkaAdminClient, NewTopic from taskiq import BrokerMessage from taskiq_aio_kafka.broker import AioKafkaBroker +from taskiq_aio_kafka.constants import TASK_STREAM_LABEL from taskiq_aio_kafka.topic import Topic @@ -40,6 +42,23 @@ async def stop(self) -> None: """Stop consumer.""" +class _ConsumerMock: + """Kafka consumer mock.""" + + def __init__(self, messages: list[ConsumerRecord[None, bytes]]) -> None: + self._messages = iter(messages) + + def __aiter__(self) -> "_ConsumerMock": + return self + + async def __anext__(self) -> ConsumerRecord[None, bytes]: + """Return next message.""" + try: + return next(self._messages) + except StopIteration as exc: + raise StopAsyncIteration from exc + + def get_admin_client_mock() -> KafkaAdminClient: """Get kafka admin client mock.""" admin_client = Mock(spec=KafkaAdminClient) @@ -47,6 +66,29 @@ def get_admin_client_mock() -> KafkaAdminClient: return admin_client +def build_consumer_record(topic: str, value: bytes) -> ConsumerRecord[None, bytes]: + """Build Kafka consumer record.""" + return ConsumerRecord( + topic=topic, + partition=0, + offset=0, + timestamp=0, + timestamp_type=0, + key=None, + value=value, + checksum=None, + serialized_key_size=0, + serialized_value_size=len(value), + headers=[], + ) + + +async def get_first_task(broker: AioKafkaBroker) -> bytes: # type: ignore[return] + """Get first message from the topic.""" + async for message in broker.listen(): + return message + + async def test_task_topic_is_used_for_kick() -> None: """Test that task is sent to its declared topic.""" broker = AioKafkaBroker( @@ -252,3 +294,99 @@ def create_consumer( await broker.shutdown() assert consumer_topics == ("default-topic", "extra-topic") + + +async def test_subscribe_wraps_raw_topic_message() -> None: + """Test that raw Kafka messages are wrapped as taskiq messages.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + + @broker.task + async def test_task(message: str) -> None: + assert message + + broker.subscribe( + "stream-topic", + test_task, + decoder=lambda message: message.decode(), + ) + broker._aiokafka_consumer = _ConsumerMock( + [build_consumer_record("stream-topic", b"raw-message")], + ) + broker._is_consumer_started = True + + received_message = await get_first_task(broker) + taskiq_message = broker.formatter.loads(received_message) + + assert taskiq_message.task_name == test_task.task_name + assert taskiq_message.args == ["raw-message"] + assert taskiq_message.kwargs == {} + assert taskiq_message.labels[TASK_STREAM_LABEL] == "stream-topic" + + +async def test_subscribe_accepts_topic_object_and_custom_labels() -> None: + """Test that stream subscriptions can use Topic objects.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + stream_topic = Topic("stream-topic") + + @broker.task + async def test_task(message: bytes) -> None: + assert message + + broker.subscribe(stream_topic, test_task, source="external") + broker._aiokafka_consumer = _ConsumerMock( + [build_consumer_record(stream_topic.name, b"raw-message")], + ) + broker._is_consumer_started = True + + received_message = await get_first_task(broker) + taskiq_message = broker.formatter.loads(received_message) + + assert taskiq_message.task_name == test_task.task_name + assert taskiq_message.args == ["raw-message"] + assert taskiq_message.labels == { + TASK_STREAM_LABEL: stream_topic.name, + "source": "external", + } + + +def test_subscribe_adds_topic_to_broker_topics() -> None: + """Test that subscribed topic is included in listened topics.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + + @broker.task + async def test_task(message: bytes) -> None: + assert message + + broker.subscribe("stream-topic", test_task) + + assert set(broker._kafka_topics) == {"default-topic", "stream-topic"} + + +def test_subscribe_rejects_duplicate_topics() -> None: + """Test that stream topic cannot be subscribed twice.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + + @broker.task + async def test_task(message: bytes) -> None: + assert message + + broker.subscribe("stream-topic", test_task) + + with pytest.raises(ValueError, match="already subscribed"): + broker.subscribe("stream-topic", test_task) diff --git a/uv.lock b/uv.lock index 7533434..1abc153 100644 --- a/uv.lock +++ b/uv.lock @@ -1347,7 +1347,6 @@ wheels = [ [[package]] name = "taskiq-aio-kafka" -version = "0.2.1" source = { editable = "." } dependencies = [ { name = "aiokafka" }, From 2d81c8152b01ac0af768e19e0cc7ed50e59bbf26 Mon Sep 17 00:00:00 2001 From: "sergei.romanchuk" Date: Mon, 4 May 2026 20:49:06 +0200 Subject: [PATCH 2/3] add: stream data normalizer --- README.md | 29 ++++++++++++++++++++-- taskiq_aio_kafka/__init__.py | 3 ++- taskiq_aio_kafka/broker.py | 17 +++++++++---- taskiq_aio_kafka/subscriber.py | 15 +++++++++--- tests/test_broker_multi_topic.py | 42 ++++++++++++++++++++++++++++++++ 5 files changed, 95 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 8a9dddc..78fa206 100644 --- a/README.md +++ b/README.md @@ -117,8 +117,33 @@ broker.subscribe( ) ``` -The decoded value is passed as the first task argument. Stream messages also -receive the `taskiq-stream` label with the source topic name. +The decoded value is passed as the first task argument. A decoder can also +return `StreamMessage` to map one Kafka message into the exact `args` and +`kwargs` expected by the task. + +```python +import json + +from taskiq_aio_kafka import StreamMessage + + +@broker.task +async def process_user(user_id: int, email: str) -> None: + print(user_id, email) + + +def decode_user_event(message: bytes) -> StreamMessage: + event = json.loads(message) + return StreamMessage( + args=(event["user"]["id"],), + kwargs={"email": event["user"]["email"]}, + ) + + +broker.subscribe("users.created", process_user, decoder=decode_user_event) +``` + +Stream messages also receive the `taskiq-stream` label with the source topic name. ## Configuration diff --git a/taskiq_aio_kafka/__init__.py b/taskiq_aio_kafka/__init__.py index a452231..20e1fe2 100644 --- a/taskiq_aio_kafka/__init__.py +++ b/taskiq_aio_kafka/__init__.py @@ -1,5 +1,6 @@ """Taskiq integration with aiokafka.""" -__all__ = ("AioKafkaBroker",) +__all__ = ("AioKafkaBroker", "StreamMessage") from taskiq_aio_kafka.broker import AioKafkaBroker +from taskiq_aio_kafka.subscriber import StreamMessage diff --git a/taskiq_aio_kafka/broker.py b/taskiq_aio_kafka/broker.py index 6e94a65..c97c82b 100644 --- a/taskiq_aio_kafka/broker.py +++ b/taskiq_aio_kafka/broker.py @@ -20,7 +20,7 @@ from .decorated_task import AioKafkaDecoratedTask from .exceptions import WrongAioKafkaBrokerParametersError from .models import KafkaConsumerParameters, KafkaProducerParameters -from .subscriber import StreamDecoder, StreamSubscriber +from .subscriber import StreamDecoder, StreamMessage, StreamSubscriber from .topic import Topic from .types import TopicType from .utils import get_topic_name @@ -171,8 +171,8 @@ def configure_consumer(self, **consumer_parameters: Any) -> None: ) @staticmethod - def _default_stream_decoder(message: bytes) -> bytes: - return message + def _default_stream_decoder(message: bytes) -> StreamMessage: + return StreamMessage(args=(message,)) def subscribe( self, @@ -390,6 +390,7 @@ def _build_stream_message( ) -> bytes: raw_value = raw_kafka_message.value decoded_value = subscriber.decoder(raw_value) + stream_message = self._normalize_stream_message(decoded_value) labels = { TASK_STREAM_LABEL: raw_kafka_message.topic, **subscriber.labels, @@ -399,7 +400,13 @@ def _build_stream_message( task_name=subscriber.task_name, labels=labels, labels_types={}, - args=[decoded_value], - kwargs={}, + args=list(stream_message.args), + kwargs=stream_message.kwargs, ) return self.formatter.dumps(message).message + + @staticmethod + def _normalize_stream_message(message: Any | StreamMessage) -> StreamMessage: + if isinstance(message, StreamMessage): + return message + return StreamMessage(args=(message,)) diff --git a/taskiq_aio_kafka/subscriber.py b/taskiq_aio_kafka/subscriber.py index 19573ee..5b8e72c 100644 --- a/taskiq_aio_kafka/subscriber.py +++ b/taskiq_aio_kafka/subscriber.py @@ -1,10 +1,19 @@ -__all__ = ("StreamDecoder", "StreamSubscriber") +__all__ = ("StreamDecoder", "StreamMessage", "StreamSubscriber") -from collections.abc import Callable +from collections.abc import Callable, Sequence from dataclasses import dataclass, field from typing import Any -StreamDecoder = Callable[[bytes], Any] + +@dataclass(frozen=True) +class StreamMessage: + """Decoded stream message arguments for a taskiq task.""" + + args: Sequence[Any] = () + kwargs: dict[str, Any] = field(default_factory=dict) + + +StreamDecoder = Callable[[bytes], Any | StreamMessage] @dataclass(frozen=True) diff --git a/tests/test_broker_multi_topic.py b/tests/test_broker_multi_topic.py index 64545d3..afd1386 100644 --- a/tests/test_broker_multi_topic.py +++ b/tests/test_broker_multi_topic.py @@ -8,6 +8,7 @@ from taskiq_aio_kafka.broker import AioKafkaBroker from taskiq_aio_kafka.constants import TASK_STREAM_LABEL +from taskiq_aio_kafka.subscriber import StreamMessage from taskiq_aio_kafka.topic import Topic @@ -357,6 +358,47 @@ async def test_task(message: bytes) -> None: } +async def test_subscribe_decoder_can_return_task_args_and_kwargs() -> None: + """Test that stream messages can be mapped to task args and kwargs.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + + @broker.task + async def test_task(user_id: int, email: str, active: bool) -> None: + assert user_id + assert email + assert active + + def decode_message(message: bytes) -> StreamMessage: + user_id, email, active = message.decode().split(":") + return StreamMessage( + args=(int(user_id),), + kwargs={ + "email": email, + "active": active == "true", + }, + ) + + broker.subscribe("stream-topic", test_task, decoder=decode_message) + broker._aiokafka_consumer = _ConsumerMock( + [build_consumer_record("stream-topic", b"1:user@example.com:true")], + ) + broker._is_consumer_started = True + + received_message = await get_first_task(broker) + taskiq_message = broker.formatter.loads(received_message) + + assert taskiq_message.task_name == test_task.task_name + assert taskiq_message.args == [1] + assert taskiq_message.kwargs == { + "email": "user@example.com", + "active": True, + } + + def test_subscribe_adds_topic_to_broker_topics() -> None: """Test that subscribed topic is included in listened topics.""" broker = AioKafkaBroker( From b0c96a51da68c4fbdf86039f1dd2ec47fd33894a Mon Sep 17 00:00:00 2001 From: "sergei.romanchuk" Date: Mon, 4 May 2026 20:55:52 +0200 Subject: [PATCH 3/3] update: tests --- taskiq_aio_kafka/broker.py | 2 +- tests/test_broker_multi_topic.py | 180 ---------------- tests/test_subscriber.py | 359 +++++++++++++++++++++++++++++++ 3 files changed, 360 insertions(+), 181 deletions(-) create mode 100644 tests/test_subscriber.py diff --git a/taskiq_aio_kafka/broker.py b/taskiq_aio_kafka/broker.py index c97c82b..8895fa5 100644 --- a/taskiq_aio_kafka/broker.py +++ b/taskiq_aio_kafka/broker.py @@ -392,8 +392,8 @@ def _build_stream_message( decoded_value = subscriber.decoder(raw_value) stream_message = self._normalize_stream_message(decoded_value) labels = { - TASK_STREAM_LABEL: raw_kafka_message.topic, **subscriber.labels, + TASK_STREAM_LABEL: raw_kafka_message.topic, } message = TaskiqMessage( task_id=self.id_generator(), diff --git a/tests/test_broker_multi_topic.py b/tests/test_broker_multi_topic.py index afd1386..2e82911 100644 --- a/tests/test_broker_multi_topic.py +++ b/tests/test_broker_multi_topic.py @@ -2,13 +2,10 @@ from unittest.mock import Mock import pytest -from aiokafka.structs import ConsumerRecord from kafka.admin import KafkaAdminClient, NewTopic from taskiq import BrokerMessage from taskiq_aio_kafka.broker import AioKafkaBroker -from taskiq_aio_kafka.constants import TASK_STREAM_LABEL -from taskiq_aio_kafka.subscriber import StreamMessage from taskiq_aio_kafka.topic import Topic @@ -43,23 +40,6 @@ async def stop(self) -> None: """Stop consumer.""" -class _ConsumerMock: - """Kafka consumer mock.""" - - def __init__(self, messages: list[ConsumerRecord[None, bytes]]) -> None: - self._messages = iter(messages) - - def __aiter__(self) -> "_ConsumerMock": - return self - - async def __anext__(self) -> ConsumerRecord[None, bytes]: - """Return next message.""" - try: - return next(self._messages) - except StopIteration as exc: - raise StopAsyncIteration from exc - - def get_admin_client_mock() -> KafkaAdminClient: """Get kafka admin client mock.""" admin_client = Mock(spec=KafkaAdminClient) @@ -67,29 +47,6 @@ def get_admin_client_mock() -> KafkaAdminClient: return admin_client -def build_consumer_record(topic: str, value: bytes) -> ConsumerRecord[None, bytes]: - """Build Kafka consumer record.""" - return ConsumerRecord( - topic=topic, - partition=0, - offset=0, - timestamp=0, - timestamp_type=0, - key=None, - value=value, - checksum=None, - serialized_key_size=0, - serialized_value_size=len(value), - headers=[], - ) - - -async def get_first_task(broker: AioKafkaBroker) -> bytes: # type: ignore[return] - """Get first message from the topic.""" - async for message in broker.listen(): - return message - - async def test_task_topic_is_used_for_kick() -> None: """Test that task is sent to its declared topic.""" broker = AioKafkaBroker( @@ -295,140 +252,3 @@ def create_consumer( await broker.shutdown() assert consumer_topics == ("default-topic", "extra-topic") - - -async def test_subscribe_wraps_raw_topic_message() -> None: - """Test that raw Kafka messages are wrapped as taskiq messages.""" - broker = AioKafkaBroker( - bootstrap_servers="localhost", - kafka_topic="default-topic", - kafka_admin_client=get_admin_client_mock(), - ) - - @broker.task - async def test_task(message: str) -> None: - assert message - - broker.subscribe( - "stream-topic", - test_task, - decoder=lambda message: message.decode(), - ) - broker._aiokafka_consumer = _ConsumerMock( - [build_consumer_record("stream-topic", b"raw-message")], - ) - broker._is_consumer_started = True - - received_message = await get_first_task(broker) - taskiq_message = broker.formatter.loads(received_message) - - assert taskiq_message.task_name == test_task.task_name - assert taskiq_message.args == ["raw-message"] - assert taskiq_message.kwargs == {} - assert taskiq_message.labels[TASK_STREAM_LABEL] == "stream-topic" - - -async def test_subscribe_accepts_topic_object_and_custom_labels() -> None: - """Test that stream subscriptions can use Topic objects.""" - broker = AioKafkaBroker( - bootstrap_servers="localhost", - kafka_topic="default-topic", - kafka_admin_client=get_admin_client_mock(), - ) - stream_topic = Topic("stream-topic") - - @broker.task - async def test_task(message: bytes) -> None: - assert message - - broker.subscribe(stream_topic, test_task, source="external") - broker._aiokafka_consumer = _ConsumerMock( - [build_consumer_record(stream_topic.name, b"raw-message")], - ) - broker._is_consumer_started = True - - received_message = await get_first_task(broker) - taskiq_message = broker.formatter.loads(received_message) - - assert taskiq_message.task_name == test_task.task_name - assert taskiq_message.args == ["raw-message"] - assert taskiq_message.labels == { - TASK_STREAM_LABEL: stream_topic.name, - "source": "external", - } - - -async def test_subscribe_decoder_can_return_task_args_and_kwargs() -> None: - """Test that stream messages can be mapped to task args and kwargs.""" - broker = AioKafkaBroker( - bootstrap_servers="localhost", - kafka_topic="default-topic", - kafka_admin_client=get_admin_client_mock(), - ) - - @broker.task - async def test_task(user_id: int, email: str, active: bool) -> None: - assert user_id - assert email - assert active - - def decode_message(message: bytes) -> StreamMessage: - user_id, email, active = message.decode().split(":") - return StreamMessage( - args=(int(user_id),), - kwargs={ - "email": email, - "active": active == "true", - }, - ) - - broker.subscribe("stream-topic", test_task, decoder=decode_message) - broker._aiokafka_consumer = _ConsumerMock( - [build_consumer_record("stream-topic", b"1:user@example.com:true")], - ) - broker._is_consumer_started = True - - received_message = await get_first_task(broker) - taskiq_message = broker.formatter.loads(received_message) - - assert taskiq_message.task_name == test_task.task_name - assert taskiq_message.args == [1] - assert taskiq_message.kwargs == { - "email": "user@example.com", - "active": True, - } - - -def test_subscribe_adds_topic_to_broker_topics() -> None: - """Test that subscribed topic is included in listened topics.""" - broker = AioKafkaBroker( - bootstrap_servers="localhost", - kafka_topic="default-topic", - kafka_admin_client=get_admin_client_mock(), - ) - - @broker.task - async def test_task(message: bytes) -> None: - assert message - - broker.subscribe("stream-topic", test_task) - - assert set(broker._kafka_topics) == {"default-topic", "stream-topic"} - - -def test_subscribe_rejects_duplicate_topics() -> None: - """Test that stream topic cannot be subscribed twice.""" - broker = AioKafkaBroker( - bootstrap_servers="localhost", - kafka_topic="default-topic", - kafka_admin_client=get_admin_client_mock(), - ) - - @broker.task - async def test_task(message: bytes) -> None: - assert message - - broker.subscribe("stream-topic", test_task) - - with pytest.raises(ValueError, match="already subscribed"): - broker.subscribe("stream-topic", test_task) diff --git a/tests/test_subscriber.py b/tests/test_subscriber.py new file mode 100644 index 0000000..0dbd7d0 --- /dev/null +++ b/tests/test_subscriber.py @@ -0,0 +1,359 @@ +from unittest.mock import Mock + +import pytest +from aiokafka.structs import ConsumerRecord +from kafka.admin import KafkaAdminClient + +from taskiq_aio_kafka.broker import AioKafkaBroker +from taskiq_aio_kafka.constants import TASK_STREAM_LABEL +from taskiq_aio_kafka.subscriber import StreamMessage +from taskiq_aio_kafka.topic import Topic + + +class _ConsumerMock: + """Kafka consumer mock.""" + + def __init__(self, messages: list[ConsumerRecord[None, bytes]]) -> None: + self._messages = iter(messages) + + def __aiter__(self) -> "_ConsumerMock": + return self + + async def __anext__(self) -> ConsumerRecord[None, bytes]: + """Return next message.""" + try: + return next(self._messages) + except StopIteration as exc: + raise StopAsyncIteration from exc + + +class _ProducerStartStopMock: + """Kafka producer lifecycle mock.""" + + async def start(self) -> None: + """Start producer.""" + + async def stop(self) -> None: + """Stop producer.""" + + +class _ConsumerStartStopMock: + """Kafka consumer lifecycle mock.""" + + async def start(self) -> None: + """Start consumer.""" + + async def stop(self) -> None: + """Stop consumer.""" + + +def get_admin_client_mock() -> KafkaAdminClient: + """Get kafka admin client mock.""" + admin_client = Mock(spec=KafkaAdminClient) + admin_client.list_topics.return_value = [] + return admin_client + + +def build_consumer_record(topic: str, value: bytes) -> ConsumerRecord[None, bytes]: + """Build Kafka consumer record.""" + return ConsumerRecord( + topic=topic, + partition=0, + offset=0, + timestamp=0, + timestamp_type=0, + key=None, + value=value, + checksum=None, + serialized_key_size=0, + serialized_value_size=len(value), + headers=[], + ) + + +async def get_first_task(broker: AioKafkaBroker) -> bytes: # type: ignore[return] + """Get first message from the topic.""" + async for message in broker.listen(): + return message + + +async def test_subscribe_wraps_raw_topic_message() -> None: + """Test that raw Kafka messages are wrapped as taskiq messages.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + + @broker.task + async def test_task(message: str) -> None: + assert message + + broker.subscribe( + "stream-topic", + test_task, + decoder=lambda message: message.decode(), + ) + broker._aiokafka_consumer = _ConsumerMock( + [build_consumer_record("stream-topic", b"raw-message")], + ) + broker._is_consumer_started = True + + received_message = await get_first_task(broker) + taskiq_message = broker.formatter.loads(received_message) + + assert taskiq_message.task_name == test_task.task_name + assert taskiq_message.args == ["raw-message"] + assert taskiq_message.kwargs == {} + assert taskiq_message.labels[TASK_STREAM_LABEL] == "stream-topic" + + +async def test_subscribe_without_decoder_passes_raw_bytes() -> None: + """Test that default stream decoder passes raw bytes.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + + @broker.task + async def test_task(message: bytes) -> None: + assert message + + broker.subscribe("stream-topic", test_task) + broker._aiokafka_consumer = _ConsumerMock( + [build_consumer_record("stream-topic", b"raw-message")], + ) + broker._is_consumer_started = True + + received_message = await get_first_task(broker) + taskiq_message = broker.formatter.loads(received_message) + + assert taskiq_message.args == ["raw-message"] + assert taskiq_message.kwargs == {} + + +async def test_subscribe_accepts_topic_object_and_custom_labels() -> None: + """Test that stream subscriptions can use Topic objects.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + stream_topic = Topic("stream-topic") + + @broker.task + async def test_task(message: bytes) -> None: + assert message + + broker.subscribe(stream_topic, test_task, source="external") + broker._aiokafka_consumer = _ConsumerMock( + [build_consumer_record(stream_topic.name, b"raw-message")], + ) + broker._is_consumer_started = True + + received_message = await get_first_task(broker) + taskiq_message = broker.formatter.loads(received_message) + + assert taskiq_message.task_name == test_task.task_name + assert taskiq_message.args == ["raw-message"] + assert taskiq_message.labels == { + TASK_STREAM_LABEL: stream_topic.name, + "source": "external", + } + + +async def test_subscribe_keeps_stream_label_from_source_topic() -> None: + """Test that custom labels cannot override source stream label.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + + @broker.task + async def test_task(message: bytes) -> None: + assert message + + broker.subscribe( + "stream-topic", + test_task, + decoder=None, + **{TASK_STREAM_LABEL: "wrong-topic"}, + ) + broker._aiokafka_consumer = _ConsumerMock( + [build_consumer_record("stream-topic", b"raw-message")], + ) + broker._is_consumer_started = True + + received_message = await get_first_task(broker) + taskiq_message = broker.formatter.loads(received_message) + + assert taskiq_message.labels[TASK_STREAM_LABEL] == "stream-topic" + + +async def test_subscribe_decoder_can_return_task_args_and_kwargs() -> None: + """Test that stream messages can be mapped to task args and kwargs.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + + @broker.task + async def test_task(user_id: int, email: str, active: bool) -> None: + assert user_id + assert email + assert active + + def decode_message(message: bytes) -> StreamMessage: + user_id, email, active = message.decode().split(":") + return StreamMessage( + args=(int(user_id),), + kwargs={ + "email": email, + "active": active == "true", + }, + ) + + broker.subscribe("stream-topic", test_task, decoder=decode_message) + broker._aiokafka_consumer = _ConsumerMock( + [build_consumer_record("stream-topic", b"1:user@example.com:true")], + ) + broker._is_consumer_started = True + + received_message = await get_first_task(broker) + taskiq_message = broker.formatter.loads(received_message) + + assert taskiq_message.task_name == test_task.task_name + assert taskiq_message.args == [1] + assert taskiq_message.kwargs == { + "email": "user@example.com", + "active": True, + } + + +async def test_subscribe_uses_broker_task_id_generator() -> None: + """Test that wrapped stream messages use broker task id generator.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ).with_id_generator(lambda: "stream-task-id") + + @broker.task + async def test_task(message: bytes) -> None: + assert message + + broker.subscribe("stream-topic", test_task) + broker._aiokafka_consumer = _ConsumerMock( + [build_consumer_record("stream-topic", b"raw-message")], + ) + broker._is_consumer_started = True + + received_message = await get_first_task(broker) + taskiq_message = broker.formatter.loads(received_message) + + assert taskiq_message.task_id == "stream-task-id" + + +async def test_listen_keeps_regular_topic_messages_with_subscribers() -> None: + """Test that non-stream topics keep yielding raw taskiq messages.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + + @broker.task + async def test_task(message: bytes) -> None: + assert message + + broker.subscribe("stream-topic", test_task) + broker._aiokafka_consumer = _ConsumerMock( + [build_consumer_record("default-topic", b"taskiq-message")], + ) + broker._is_consumer_started = True + + received_message = await get_first_task(broker) + + assert received_message == b"taskiq-message" + + +async def test_startup_subscribes_consumer_to_stream_topic( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that worker consumer subscribes to stream topics.""" + consumer_topics: tuple[str, ...] = () + + def create_producer(**_kwargs: object) -> _ProducerStartStopMock: + return _ProducerStartStopMock() + + def create_consumer( + *topics: str, + **_kwargs: object, + ) -> _ConsumerStartStopMock: + nonlocal consumer_topics + consumer_topics = topics + return _ConsumerStartStopMock() + + monkeypatch.setattr( + "taskiq_aio_kafka.broker.AIOKafkaProducer", + create_producer, + ) + monkeypatch.setattr( + "taskiq_aio_kafka.broker.AIOKafkaConsumer", + create_consumer, + ) + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + + @broker.task + async def test_task(message: bytes) -> None: + assert message + + broker.subscribe("stream-topic", test_task) + broker.is_worker_process = True + + await broker.startup() + await broker.shutdown() + + assert consumer_topics == ("default-topic", "stream-topic") + + +def test_subscribe_adds_topic_to_broker_topics() -> None: + """Test that subscribed topic is included in listened topics.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + + @broker.task + async def test_task(message: bytes) -> None: + assert message + + broker.subscribe("stream-topic", test_task) + + assert set(broker._kafka_topics) == {"default-topic", "stream-topic"} + + +def test_subscribe_rejects_duplicate_topics() -> None: + """Test that stream topic cannot be subscribed twice.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + + @broker.task + async def test_task(message: bytes) -> None: + assert message + + broker.subscribe("stream-topic", test_task) + + with pytest.raises(ValueError, match="already subscribed"): + broker.subscribe("stream-topic", test_task)