diff --git a/README.md b/README.md index 79a052f..78fa206 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,64 @@ async def regular_task() -> None: await regular_task.kiq() ``` +## Stream topics + +You can subscribe an existing task to a Kafka topic that carries raw messages. +Messages from subscribed stream topics are wrapped into regular taskiq +messages before execution, so result backends and worker middlewares keep +working as usual. + +```python +import json + +from taskiq_aio_kafka import AioKafkaBroker + +broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="taskiq-topic", +) + + +@broker.task +async def process_user_created(event: dict[str, object]) -> None: + print(event) + + +broker.subscribe( + "users.created", + process_user_created, + decoder=json.loads, +) +``` + +The decoded value is passed as the first task argument. A decoder can also +return `StreamMessage` to map one Kafka message into the exact `args` and +`kwargs` expected by the task. + +```python +import json + +from taskiq_aio_kafka import StreamMessage + + +@broker.task +async def process_user(user_id: int, email: str) -> None: + print(user_id, email) + + +def decode_user_event(message: bytes) -> StreamMessage: + event = json.loads(message) + return StreamMessage( + args=(event["user"]["id"],), + kwargs={"email": event["user"]["email"]}, + ) + + +broker.subscribe("users.updated", process_user, decoder=decode_user_event) +``` + +Stream messages also receive the `taskiq-stream` label with the source topic name. 
+ ## Configuration AioKafkaBroker parameters: diff --git a/taskiq_aio_kafka/__init__.py b/taskiq_aio_kafka/__init__.py index a452231..20e1fe2 100644 --- a/taskiq_aio_kafka/__init__.py +++ b/taskiq_aio_kafka/__init__.py @@ -1,5 +1,6 @@ """Taskiq integration with aiokafka.""" -__all__ = ("AioKafkaBroker",) +__all__ = ("AioKafkaBroker", "StreamMessage") from taskiq_aio_kafka.broker import AioKafkaBroker +from taskiq_aio_kafka.subscriber import StreamMessage diff --git a/taskiq_aio_kafka/broker.py b/taskiq_aio_kafka/broker.py index c9d4484..8895fa5 100644 --- a/taskiq_aio_kafka/broker.py +++ b/taskiq_aio_kafka/broker.py @@ -6,17 +6,21 @@ from typing import Any, TypeVar, overload from aiokafka import AIOKafkaConsumer, AIOKafkaProducer +from aiokafka.structs import ConsumerRecord from kafka.admin import KafkaAdminClient, NewTopic from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor from kafka.partitioner.default import DefaultPartitioner from taskiq import AsyncResultBackend, BrokerMessage from taskiq.abc.broker import AsyncBroker +from taskiq.decor import AsyncTaskiqDecoratedTask +from taskiq.message import TaskiqMessage from typing_extensions import ParamSpec -from .constants import TASK_TOPIC_LABEL +from .constants import TASK_STREAM_LABEL, TASK_TOPIC_LABEL from .decorated_task import AioKafkaDecoratedTask from .exceptions import WrongAioKafkaBrokerParametersError from .models import KafkaConsumerParameters, KafkaProducerParameters +from .subscriber import StreamDecoder, StreamMessage, StreamSubscriber from .topic import Topic from .types import TopicType from .utils import get_topic_name @@ -79,6 +83,7 @@ def __init__( # noqa: PLR0913 get_topic_name(topic), topic, ) + self._stream_subscribers: dict[str, StreamSubscriber] = {} self._aiokafka_producer_params: KafkaProducerParameters = ( KafkaProducerParameters() @@ -165,6 +170,30 @@ def configure_consumer(self, **consumer_parameters: Any) -> None: **consumer_parameters, ) + @staticmethod + def 
@staticmethod
def _default_stream_decoder(message: bytes) -> StreamMessage:
    """Wrap the untouched Kafka payload as the task's only positional argument."""
    return StreamMessage(args=(message,))


def subscribe(
    self,
    topic: TopicType,
    task: AsyncTaskiqDecoratedTask[Any, Any],
    decoder: StreamDecoder | None = None,
    **labels: Any,
) -> None:
    """Subscribe task to raw Kafka topic messages.

    :param topic: topic name or ``Topic`` object; its name is resolved with
        ``get_topic_name``.
    :param task: decorated task whose ``task_name`` is stored for the topic.
    :param decoder: callable that turns the raw record value into either a
        single task argument or a ``StreamMessage``. When ``None``, the raw
        value is passed through as the only positional argument.
    :param labels: extra labels attached to every message wrapped for
        this topic.
    :raises ValueError: if the topic already has a stream subscriber.
    """
    topic_name = get_topic_name(topic)
    if topic_name in self._stream_subscribers:
        error_message = f"Topic {topic_name!r} is already subscribed."
        raise ValueError(error_message)

    # An already registered topic entry is kept as is; only unknown topics
    # are added to the broker's topic mapping.
    # NOTE(review): every record from a subscribed topic is routed to the
    # stream subscriber, so subscribing a topic that also carries regular
    # taskiq messages would push those through `decoder` too — confirm this
    # combination is never intended.
    self._kafka_topics.setdefault(topic_name, topic)

    # Compare against None explicitly: a callable object that happens to be
    # falsy (e.g. defines ``__bool__``/``__len__``) is still a valid decoder
    # and must not be silently replaced by the default one.
    stream_decoder = decoder if decoder is not None else self._default_stream_decoder
    self._stream_subscribers[topic_name] = StreamSubscriber(
        task_name=task.task_name,
        decoder=stream_decoder,
        labels=labels,
    )


def _build_stream_message(
    self,
    raw_kafka_message: ConsumerRecord[Any, bytes],
    subscriber: StreamSubscriber,
) -> bytes:
    """Wrap one raw Kafka record into a serialized taskiq message.

    The record value is decoded with the subscriber's decoder, normalized to
    a ``StreamMessage`` and dumped with the broker formatter.

    :param raw_kafka_message: record consumed from a subscribed topic.
    :param subscriber: subscriber registered for the record's topic.
    :return: serialized taskiq message bytes.
    """
    # NOTE(review): exceptions raised by the decoder are not handled here and
    # propagate to the caller. Presumably `value` can be None for tombstone
    # records — confirm decoders are expected to cope with that.
    decoded_value = subscriber.decoder(raw_kafka_message.value)
    stream_message = self._normalize_stream_message(decoded_value)

    # The source-topic label is merged last so that user-supplied labels
    # cannot override it.
    labels = {
        **subscriber.labels,
        TASK_STREAM_LABEL: raw_kafka_message.topic,
    }
    message = TaskiqMessage(
        task_id=self.id_generator(),
        task_name=subscriber.task_name,
        labels=labels,
        labels_types={},
        args=list(stream_message.args),
        kwargs=stream_message.kwargs,
    )
    return self.formatter.dumps(message).message


@staticmethod
def _normalize_stream_message(message: Any | StreamMessage) -> StreamMessage:
    """Return ``message`` unchanged if it is a ``StreamMessage``.

    Any other decoded value becomes the single positional argument of a new
    ``StreamMessage``.
    """
    if isinstance(message, StreamMessage):
        return message
    return StreamMessage(args=(message,))
b/taskiq_aio_kafka/constants.py index 71f5fad..84f54ff 100644 --- a/taskiq_aio_kafka/constants.py +++ b/taskiq_aio_kafka/constants.py @@ -1,3 +1,4 @@ -__all__ = ("TASK_TOPIC_LABEL",) +__all__ = ("TASK_STREAM_LABEL", "TASK_TOPIC_LABEL") +TASK_STREAM_LABEL = "taskiq-stream" TASK_TOPIC_LABEL = "taskiq_aio_kafka_topic" diff --git a/taskiq_aio_kafka/subscriber.py b/taskiq_aio_kafka/subscriber.py new file mode 100644 index 0000000..5b8e72c --- /dev/null +++ b/taskiq_aio_kafka/subscriber.py @@ -0,0 +1,25 @@ +__all__ = ("StreamDecoder", "StreamMessage", "StreamSubscriber") + +from collections.abc import Callable, Sequence +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(frozen=True) +class StreamMessage: + """Decoded stream message arguments for a taskiq task.""" + + args: Sequence[Any] = () + kwargs: dict[str, Any] = field(default_factory=dict) + + +StreamDecoder = Callable[[bytes], Any | StreamMessage] + + +@dataclass(frozen=True) +class StreamSubscriber: + """Kafka stream subscriber bound to a taskiq task.""" + + task_name: str + decoder: StreamDecoder + labels: dict[str, Any] = field(default_factory=dict) diff --git a/tests/test_subscriber.py b/tests/test_subscriber.py new file mode 100644 index 0000000..0dbd7d0 --- /dev/null +++ b/tests/test_subscriber.py @@ -0,0 +1,359 @@ +from unittest.mock import Mock + +import pytest +from aiokafka.structs import ConsumerRecord +from kafka.admin import KafkaAdminClient + +from taskiq_aio_kafka.broker import AioKafkaBroker +from taskiq_aio_kafka.constants import TASK_STREAM_LABEL +from taskiq_aio_kafka.subscriber import StreamMessage +from taskiq_aio_kafka.topic import Topic + + +class _ConsumerMock: + """Kafka consumer mock.""" + + def __init__(self, messages: list[ConsumerRecord[None, bytes]]) -> None: + self._messages = iter(messages) + + def __aiter__(self) -> "_ConsumerMock": + return self + + async def __anext__(self) -> ConsumerRecord[None, bytes]: + """Return next message.""" + try: + return 
class _ConsumerMock:
    """Kafka consumer mock that replays a fixed list of records."""

    def __init__(self, messages: list[ConsumerRecord[None, bytes]]) -> None:
        self._messages = iter(messages)

    def __aiter__(self) -> "_ConsumerMock":
        return self

    async def __anext__(self) -> ConsumerRecord[None, bytes]:
        """Return next message."""
        try:
            return next(self._messages)
        except StopIteration as exc:
            # Translate sync exhaustion into the async iteration protocol.
            raise StopAsyncIteration from exc


class _ProducerStartStopMock:
    """Kafka producer lifecycle mock."""

    async def start(self) -> None:
        """Start producer."""

    async def stop(self) -> None:
        """Stop producer."""


class _ConsumerStartStopMock:
    """Kafka consumer lifecycle mock."""

    async def start(self) -> None:
        """Start consumer."""

    async def stop(self) -> None:
        """Stop consumer."""


def get_admin_client_mock() -> KafkaAdminClient:
    """Get kafka admin client mock that reports no existing topics."""
    admin_client = Mock(spec=KafkaAdminClient)
    admin_client.list_topics.return_value = []
    return admin_client


def build_consumer_record(topic: str, value: bytes) -> ConsumerRecord[None, bytes]:
    """Build Kafka consumer record for ``topic`` carrying ``value``."""
    return ConsumerRecord(
        topic=topic,
        partition=0,
        offset=0,
        timestamp=0,
        timestamp_type=0,
        key=None,
        value=value,
        checksum=None,
        serialized_key_size=0,
        serialized_value_size=len(value),
        headers=[],
    )


async def get_first_task(broker: AioKafkaBroker) -> bytes:
    """Get first message yielded by ``broker.listen()``.

    :raises AssertionError: if the broker yields no message at all. Without
        this the helper would silently return ``None`` and the calling test
        would fail later with an unrelated error.
    """
    async for message in broker.listen():
        return message
    error_message = "broker.listen() finished without yielding a message."
    raise AssertionError(error_message)
async def test_subscribe_wraps_raw_topic_message() -> None:
    """Check that a decoded raw record arrives as a regular taskiq message."""
    broker = AioKafkaBroker(
        bootstrap_servers="localhost",
        kafka_topic="default-topic",
        kafka_admin_client=get_admin_client_mock(),
    )

    @broker.task
    async def test_task(message: str) -> None:
        assert message

    def decode_text(message: bytes) -> str:
        return message.decode()

    broker.subscribe("stream-topic", test_task, decoder=decode_text)
    record = build_consumer_record("stream-topic", b"raw-message")
    broker._aiokafka_consumer = _ConsumerMock([record])
    broker._is_consumer_started = True

    wrapped = broker.formatter.loads(await get_first_task(broker))

    assert wrapped.task_name == test_task.task_name
    assert wrapped.args == ["raw-message"]
    assert wrapped.kwargs == {}
    assert wrapped.labels[TASK_STREAM_LABEL] == "stream-topic"


async def test_subscribe_without_decoder_passes_raw_bytes() -> None:
    """Check that the default decoder forwards the payload as the only argument.

    The payload is compared in the form it has after the formatter round trip.
    """
    broker = AioKafkaBroker(
        bootstrap_servers="localhost",
        kafka_topic="default-topic",
        kafka_admin_client=get_admin_client_mock(),
    )

    @broker.task
    async def test_task(message: bytes) -> None:
        assert message

    broker.subscribe("stream-topic", test_task)
    record = build_consumer_record("stream-topic", b"raw-message")
    broker._aiokafka_consumer = _ConsumerMock([record])
    broker._is_consumer_started = True

    wrapped = broker.formatter.loads(await get_first_task(broker))

    assert wrapped.args == ["raw-message"]
    assert wrapped.kwargs == {}


async def test_subscribe_accepts_topic_object_and_custom_labels() -> None:
    """Check that a ``Topic`` object and extra labels are accepted."""
    broker = AioKafkaBroker(
        bootstrap_servers="localhost",
        kafka_topic="default-topic",
        kafka_admin_client=get_admin_client_mock(),
    )
    stream_topic = Topic("stream-topic")

    @broker.task
    async def test_task(message: bytes) -> None:
        assert message

    broker.subscribe(stream_topic, test_task, source="external")
    record = build_consumer_record(stream_topic.name, b"raw-message")
    broker._aiokafka_consumer = _ConsumerMock([record])
    broker._is_consumer_started = True

    wrapped = broker.formatter.loads(await get_first_task(broker))

    expected_labels = {
        TASK_STREAM_LABEL: stream_topic.name,
        "source": "external",
    }
    assert wrapped.task_name == test_task.task_name
    assert wrapped.args == ["raw-message"]
    assert wrapped.labels == expected_labels
async def test_subscribe_keeps_stream_label_from_source_topic() -> None:
    """Check that a user label cannot replace the source-topic label."""
    broker = AioKafkaBroker(
        bootstrap_servers="localhost",
        kafka_topic="default-topic",
        kafka_admin_client=get_admin_client_mock(),
    )

    @broker.task
    async def test_task(message: bytes) -> None:
        assert message

    # Try to spoof the reserved label through the custom-labels kwargs.
    spoofed_labels = {TASK_STREAM_LABEL: "wrong-topic"}
    broker.subscribe("stream-topic", test_task, decoder=None, **spoofed_labels)
    record = build_consumer_record("stream-topic", b"raw-message")
    broker._aiokafka_consumer = _ConsumerMock([record])
    broker._is_consumer_started = True

    wrapped = broker.formatter.loads(await get_first_task(broker))

    assert wrapped.labels[TASK_STREAM_LABEL] == "stream-topic"


async def test_subscribe_decoder_can_return_task_args_and_kwargs() -> None:
    """Check that a ``StreamMessage`` from the decoder maps to args and kwargs."""
    broker = AioKafkaBroker(
        bootstrap_servers="localhost",
        kafka_topic="default-topic",
        kafka_admin_client=get_admin_client_mock(),
    )

    @broker.task
    async def test_task(user_id: int, email: str, active: bool) -> None:
        assert user_id
        assert email
        assert active

    def decode_message(message: bytes) -> StreamMessage:
        # Payload layout used by this test: "<id>:<email>:<true|false>".
        user_id, email, active = message.decode().split(":")
        keyword_arguments = {
            "email": email,
            "active": active == "true",
        }
        return StreamMessage(args=(int(user_id),), kwargs=keyword_arguments)

    broker.subscribe("stream-topic", test_task, decoder=decode_message)
    record = build_consumer_record("stream-topic", b"1:user@example.com:true")
    broker._aiokafka_consumer = _ConsumerMock([record])
    broker._is_consumer_started = True

    wrapped = broker.formatter.loads(await get_first_task(broker))

    assert wrapped.task_name == test_task.task_name
    assert wrapped.args == [1]
    assert wrapped.kwargs == {
        "email": "user@example.com",
        "active": True,
    }
async def test_subscribe_uses_broker_task_id_generator() -> None:
    """Check that wrapped stream messages get ids from the broker generator."""

    def fixed_task_id() -> str:
        return "stream-task-id"

    broker = AioKafkaBroker(
        bootstrap_servers="localhost",
        kafka_topic="default-topic",
        kafka_admin_client=get_admin_client_mock(),
    ).with_id_generator(fixed_task_id)

    @broker.task
    async def test_task(message: bytes) -> None:
        assert message

    broker.subscribe("stream-topic", test_task)
    record = build_consumer_record("stream-topic", b"raw-message")
    broker._aiokafka_consumer = _ConsumerMock([record])
    broker._is_consumer_started = True

    wrapped = broker.formatter.loads(await get_first_task(broker))

    assert wrapped.task_id == "stream-task-id"


async def test_listen_keeps_regular_topic_messages_with_subscribers() -> None:
    """Check that records from non-stream topics are yielded untouched."""
    broker = AioKafkaBroker(
        bootstrap_servers="localhost",
        kafka_topic="default-topic",
        kafka_admin_client=get_admin_client_mock(),
    )

    @broker.task
    async def test_task(message: bytes) -> None:
        assert message

    broker.subscribe("stream-topic", test_task)
    # The record comes from the default topic, not from the subscribed one.
    record = build_consumer_record("default-topic", b"taskiq-message")
    broker._aiokafka_consumer = _ConsumerMock([record])
    broker._is_consumer_started = True

    payload = await get_first_task(broker)

    assert payload == b"taskiq-message"
async def test_startup_subscribes_consumer_to_stream_topic(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that the worker consumer is created with the stream topics."""
    # Holds the topics of the most recently created consumer.
    captured: dict[str, tuple[str, ...]] = {"topics": ()}

    def create_producer(**_kwargs: object) -> _ProducerStartStopMock:
        return _ProducerStartStopMock()

    def create_consumer(
        *topics: str,
        **_kwargs: object,
    ) -> _ConsumerStartStopMock:
        captured["topics"] = topics
        return _ConsumerStartStopMock()

    monkeypatch.setattr(
        "taskiq_aio_kafka.broker.AIOKafkaProducer",
        create_producer,
    )
    monkeypatch.setattr(
        "taskiq_aio_kafka.broker.AIOKafkaConsumer",
        create_consumer,
    )
    broker = AioKafkaBroker(
        bootstrap_servers="localhost",
        kafka_topic="default-topic",
        kafka_admin_client=get_admin_client_mock(),
    )

    @broker.task
    async def test_task(message: bytes) -> None:
        assert message

    broker.subscribe("stream-topic", test_task)
    broker.is_worker_process = True

    await broker.startup()
    await broker.shutdown()

    assert captured["topics"] == ("default-topic", "stream-topic")


def test_subscribe_adds_topic_to_broker_topics() -> None:
    """Check that a subscribed topic joins the broker's topic mapping."""
    broker = AioKafkaBroker(
        bootstrap_servers="localhost",
        kafka_topic="default-topic",
        kafka_admin_client=get_admin_client_mock(),
    )

    @broker.task
    async def test_task(message: bytes) -> None:
        assert message

    broker.subscribe("stream-topic", test_task)

    registered_topics = set(broker._kafka_topics)
    assert registered_topics == {"default-topic", "stream-topic"}


def test_subscribe_rejects_duplicate_topics() -> None:
    """Check that subscribing the same stream topic twice raises."""
    broker = AioKafkaBroker(
        bootstrap_servers="localhost",
        kafka_topic="default-topic",
        kafka_admin_client=get_admin_client_mock(),
    )

    @broker.task
    async def test_task(message: bytes) -> None:
        assert message

    broker.subscribe("stream-topic", test_task)

    with pytest.raises(ValueError, match="already subscribed"):
        broker.subscribe("stream-topic", test_task)
[[package]] name = "taskiq-aio-kafka" -version = "0.2.1" source = { editable = "." } dependencies = [ { name = "aiokafka" },