Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,64 @@ async def regular_task() -> None:
await regular_task.kiq()
```

## Stream topics

You can subscribe an existing task to a Kafka topic that carries raw
(non-taskiq) messages. Messages consumed from subscribed stream topics are
wrapped into regular taskiq messages before execution, so result backends and
worker middlewares keep working as usual.

```python
import json

from taskiq_aio_kafka import AioKafkaBroker

broker = AioKafkaBroker(
bootstrap_servers="localhost",
kafka_topic="taskiq-topic",
)


@broker.task
async def process_user_created(event: dict[str, object]) -> None:
print(event)


broker.subscribe(
"users.created",
process_user_created,
decoder=json.loads,
)
```

The decoded value is passed to the task as its single positional argument.
Alternatively, a decoder can return a `StreamMessage` to map one Kafka message
into the exact `args` and `kwargs` expected by the task.

```python
import json

from taskiq_aio_kafka import StreamMessage


@broker.task
async def process_user(user_id: int, email: str) -> None:
print(user_id, email)


def decode_user_event(message: bytes) -> StreamMessage:
event = json.loads(message)
return StreamMessage(
args=(event["user"]["id"],),
kwargs={"email": event["user"]["email"]},
)


broker.subscribe("users.created", process_user, decoder=decode_user_event)
```

Messages generated from stream topics also carry the `taskiq-stream` label, whose value is the name of the source topic.

## Configuration

AioKafkaBroker parameters:
Expand Down
3 changes: 2 additions & 1 deletion taskiq_aio_kafka/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Taskiq integration with aiokafka."""

__all__ = ("AioKafkaBroker",)
__all__ = ("AioKafkaBroker", "StreamMessage")

from taskiq_aio_kafka.broker import AioKafkaBroker
from taskiq_aio_kafka.subscriber import StreamMessage
66 changes: 64 additions & 2 deletions taskiq_aio_kafka/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@
from typing import Any, TypeVar, overload

from aiokafka import AIOKafkaConsumer, AIOKafkaProducer
from aiokafka.structs import ConsumerRecord
from kafka.admin import KafkaAdminClient, NewTopic
from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor
from kafka.partitioner.default import DefaultPartitioner
from taskiq import AsyncResultBackend, BrokerMessage
from taskiq.abc.broker import AsyncBroker
from taskiq.decor import AsyncTaskiqDecoratedTask
from taskiq.message import TaskiqMessage
from typing_extensions import ParamSpec

from .constants import TASK_TOPIC_LABEL
from .constants import TASK_STREAM_LABEL, TASK_TOPIC_LABEL
from .decorated_task import AioKafkaDecoratedTask
from .exceptions import WrongAioKafkaBrokerParametersError
from .models import KafkaConsumerParameters, KafkaProducerParameters
from .subscriber import StreamDecoder, StreamMessage, StreamSubscriber
from .topic import Topic
from .types import TopicType
from .utils import get_topic_name
Expand Down Expand Up @@ -79,6 +83,7 @@ def __init__( # noqa: PLR0913
get_topic_name(topic),
topic,
)
self._stream_subscribers: dict[str, StreamSubscriber] = {}

self._aiokafka_producer_params: KafkaProducerParameters = (
KafkaProducerParameters()
Expand Down Expand Up @@ -165,6 +170,30 @@ def configure_consumer(self, **consumer_parameters: Any) -> None:
**consumer_parameters,
)

@staticmethod
def _default_stream_decoder(message: bytes) -> StreamMessage:
return StreamMessage(args=(message,))

def subscribe(
self,
topic: TopicType,
task: AsyncTaskiqDecoratedTask[Any, Any],
decoder: StreamDecoder | None = None,
**labels: Any,
) -> None:
"""Subscribe task to raw Kafka topic messages."""
topic_name = get_topic_name(topic)
if topic_name in self._stream_subscribers:
error_message = f"Topic {topic_name!r} is already subscribed."
raise ValueError(error_message)

self._kafka_topics.setdefault(topic_name, topic)
self._stream_subscribers[topic_name] = StreamSubscriber(
task_name=task.task_name,
decoder=decoder or self._default_stream_decoder,
labels=labels,
)

@overload
def task(
self,
Expand Down Expand Up @@ -347,4 +376,37 @@ async def listen(
raise ValueError("Please run startup before listening.")

async for raw_kafka_message in self._aiokafka_consumer:
yield raw_kafka_message.value
subscriber = self._stream_subscribers.get(raw_kafka_message.topic)
if subscriber is None:
yield raw_kafka_message.value
continue

yield self._build_stream_message(raw_kafka_message, subscriber)

def _build_stream_message(
self,
raw_kafka_message: ConsumerRecord[Any, bytes],
subscriber: StreamSubscriber,
) -> bytes:
raw_value = raw_kafka_message.value
decoded_value = subscriber.decoder(raw_value)
stream_message = self._normalize_stream_message(decoded_value)
labels = {
**subscriber.labels,
TASK_STREAM_LABEL: raw_kafka_message.topic,
}
message = TaskiqMessage(
task_id=self.id_generator(),
task_name=subscriber.task_name,
labels=labels,
labels_types={},
args=list(stream_message.args),
kwargs=stream_message.kwargs,
)
return self.formatter.dumps(message).message

@staticmethod
def _normalize_stream_message(message: Any | StreamMessage) -> StreamMessage:
if isinstance(message, StreamMessage):
return message
return StreamMessage(args=(message,))
3 changes: 2 additions & 1 deletion taskiq_aio_kafka/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__all__ = ("TASK_TOPIC_LABEL",)
__all__ = ("TASK_STREAM_LABEL", "TASK_TOPIC_LABEL")

TASK_STREAM_LABEL = "taskiq-stream"
TASK_TOPIC_LABEL = "taskiq_aio_kafka_topic"
25 changes: 25 additions & 0 deletions taskiq_aio_kafka/subscriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
__all__ = ("StreamDecoder", "StreamMessage", "StreamSubscriber")

from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any


@dataclass(frozen=True)
class StreamMessage:
    """Decoded stream message arguments for a taskiq task.

    Returned by a stream decoder to specify exactly which positional and
    keyword arguments the subscribed task is called with.
    """

    # Positional arguments passed to the task.
    args: Sequence[Any] = ()
    # Keyword arguments passed to the task.
    kwargs: dict[str, Any] = field(default_factory=dict)


StreamDecoder = Callable[[bytes], Any | StreamMessage]


@dataclass(frozen=True)
class StreamSubscriber:
    """Kafka stream subscriber bound to a taskiq task."""

    # Name of the taskiq task executed for every message of the topic.
    task_name: str
    # Converts the raw Kafka payload into task arguments.
    decoder: StreamDecoder
    # Extra labels attached to every generated taskiq message.
    labels: dict[str, Any] = field(default_factory=dict)
Loading
Loading