diff --git a/docs/features.md b/docs/features.md index 0ccac74b..78f1fdb9 100644 --- a/docs/features.md +++ b/docs/features.md @@ -166,3 +166,69 @@ with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_cha **NOTE** The worker and client output many logs at the `DEBUG` level that will be useful when understanding orchestration flow and diagnosing issues with Durable applications. Before submitting issues, please attempt a repro of the issue with debug logging enabled. + +### Work item filtering + +By default a worker receives **all** work items from the backend, +regardless of which orchestrations, activities, or entities are +registered. Work item filtering lets you explicitly tell the backend +which work items a worker can handle so that only matching items are +dispatched. This is useful when running multiple specialized workers +against the same task hub. + +Work item filtering is **opt-in**. Call `use_work_item_filters()` on +the worker before starting it. + +#### Auto-generated filters + +Calling `use_work_item_filters()` with no arguments builds filters +automatically from the worker's registry at start time: + +```python +with DurableTaskSchedulerWorker(...) as w: + w.add_orchestrator(my_orchestrator) + w.add_activity(my_activity) + w.use_work_item_filters() # auto-generate from registry + w.start() +``` + +When versioning is configured with `VersionMatchStrategy.STRICT`, +the worker's version is included in every filter so the backend +only dispatches work items that match that exact version. 
+ +#### Explicit filters + +Pass a `WorkItemFilters` instance for fine-grained control: + +```python +from durabletask.worker import ( + WorkItemFilters, + OrchestrationWorkItemFilter, + ActivityWorkItemFilter, + EntityWorkItemFilter, +) + +w.use_work_item_filters(WorkItemFilters( + orchestrations=[ + OrchestrationWorkItemFilter(name="my_orch", versions=["2.0.0"]), + ], + activities=[ + ActivityWorkItemFilter(name="my_activity"), + ], + entities=[ + EntityWorkItemFilter(name="my_entity"), + ], +)) +``` + +#### Clearing filters + +Pass `None` to clear any previously configured filters and return +to the default behaviour of processing all work items: + +```python +w.use_work_item_filters(None) +``` + +See the full +[work item filtering sample](../examples/work_item_filtering.py). diff --git a/docs/supported-patterns.md b/docs/supported-patterns.md index 612678a1..31a8ffb5 100644 --- a/docs/supported-patterns.md +++ b/docs/supported-patterns.md @@ -118,4 +118,49 @@ def my_orchestrator(ctx: task.OrchestrationContext, order: Order): return "Success" ``` -See the full [version-aware orchestrator sample](../examples/version_aware_orchestrator.py) \ No newline at end of file +See the full [version-aware orchestrator sample](../examples/version_aware_orchestrator.py) + +### Work item filtering + +When running multiple workers against the same task hub, each +worker can declare which work items it handles. The backend then +dispatches only the matching orchestrations, activities, and +entities, avoiding unnecessary round-trips. Filtering is opt-in +and supports both auto-generated and explicit filter sets. + +The simplest approach auto-generates filters from the worker's +registry: + +```python +with DurableTaskSchedulerWorker(...) 
as w: + w.add_orchestrator(greeting_orchestrator) + w.add_activity(greet) + w.use_work_item_filters() # auto-generate from registry + w.start() +``` + +For more control you can provide explicit filters, including +version constraints: + +```python +from durabletask.worker import ( + WorkItemFilters, + OrchestrationWorkItemFilter, + ActivityWorkItemFilter, +) + +w.use_work_item_filters(WorkItemFilters( + orchestrations=[ + OrchestrationWorkItemFilter( + name="greeting_orchestrator", + versions=["2.0.0"], + ), + ], + activities=[ + ActivityWorkItemFilter(name="greet"), + ], +)) +``` + +See the full +[work item filtering sample](../examples/work_item_filtering.py). diff --git a/durabletask/__init__.py b/durabletask/__init__.py index e0e73d30..1ab2a12f 100644 --- a/durabletask/__init__.py +++ b/durabletask/__init__.py @@ -3,8 +3,8 @@ """Durable Task SDK for Python""" -from durabletask.worker import ConcurrencyOptions, VersioningOptions +from durabletask.worker import ConcurrencyOptions, VersioningOptions, WorkItemFilters -__all__ = ["ConcurrencyOptions", "VersioningOptions"] +__all__ = ["ConcurrencyOptions", "VersioningOptions", "WorkItemFilters"] PACKAGE_NAME = "durabletask" diff --git a/durabletask/testing/in_memory_backend.py b/durabletask/testing/in_memory_backend.py index 590688ad..3dc28745 100644 --- a/durabletask/testing/in_memory_backend.py +++ b/durabletask/testing/in_memory_backend.py @@ -26,6 +26,7 @@ import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.helpers as helpers +from durabletask.entities.entity_instance_id import EntityInstanceId @dataclass @@ -56,6 +57,7 @@ class ActivityWorkItem: task_id: int input: Optional[str] completion_token: int + version: Optional[str] = None @dataclass @@ -436,9 +438,57 @@ def RestartInstance(self, request: pb.RestartInstanceRequest, context): f"Restarted instance '{request.instanceId}' as '{new_instance_id}'") return 
pb.RestartInstanceResponse(instanceId=new_instance_id) + @staticmethod + def _parse_work_item_filters(request: pb.GetWorkItemsRequest): + """Extract filters from the request. + + Returns a tuple of three values, one per work-item category. Each + value is either ``None`` (no filtering -- dispatch everything) or a + ``dict`` mapping a task name to a ``frozenset`` of accepted versions + (empty frozenset means *any* version of that name is accepted). + An empty ``dict`` means the worker opted into filtering for that + category but listed no names, so *nothing* should match. + """ + if not request.HasField("workItemFilters"): + return None, None, None + wf = request.workItemFilters + + def _build_filter(filters): + result: dict[str, frozenset[str]] = {} + for f in filters: + versions = frozenset(f.versions) if f.versions else frozenset() + existing = result.get(f.name, frozenset()) + result[f.name] = existing | versions + return result + + orch_filter = _build_filter(wf.orchestrations) + activity_filter = _build_filter(wf.activities) + entity_filter = {f.name: frozenset() for f in wf.entities} + return orch_filter, activity_filter, entity_filter + + @staticmethod + def _matches_filter(name: str, version: Optional[str], + filt: Optional[dict[str, frozenset[str]]]) -> bool: + """Check whether a work item matches the parsed filter. + + *filt* is ``None`` when the worker did not opt into filtering + (everything matches). Otherwise it is a dict mapping accepted + names to a frozenset of accepted versions. An empty frozenset + means any version of that name is accepted. 
+ """ + if filt is None: + return True + accepted_versions = filt.get(name) + if accepted_versions is None: + return False + if not accepted_versions: + return True # empty set -- any version + return (version or "") in accepted_versions + def GetWorkItems(self, request: pb.GetWorkItemsRequest, context): """Streams work items to the worker (orchestration and activity work items).""" self._logger.info("Worker connected and requesting work items") + orch_filter, activity_filter, entity_filter = self._parse_work_item_filters(request) try: while context.is_active() and not self._shutdown_event.is_set(): @@ -446,6 +496,7 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context): with self._lock: # Check for orchestration work + skipped_orchs: list[str] = [] while self._orchestration_queue: instance_id = self._orchestration_queue.popleft() self._orchestration_queue_set.discard(instance_id) @@ -454,11 +505,15 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context): if not instance or not instance.pending_events: continue + # Skip if orchestration doesn't match filters + if not self._matches_filter( + instance.name, instance.version, orch_filter): + skipped_orchs.append(instance_id) + continue + if instance_id in self._orchestration_in_flight: # Already being processed — re-add to queue - if instance_id not in self._orchestration_queue_set: - self._orchestration_queue.append(instance_id) - self._orchestration_queue_set.add(instance_id) + skipped_orchs.append(instance_id) break # Move pending events to dispatched_events @@ -485,27 +540,62 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context): ) break + # Re-queue skipped orchestrations for other workers + for s in skipped_orchs: + if s not in self._orchestration_queue_set: + self._orchestration_queue.append(s) + self._orchestration_queue_set.add(s) + # Check for activity work if not work_item and self._activity_queue: - activity = self._activity_queue.popleft() - work_item = pb.WorkItem( 
- completionToken=str(activity.completion_token), - activityRequest=pb.ActivityRequest( - name=activity.name, - taskId=activity.task_id, - input=wrappers_pb2.StringValue(value=activity.input) if activity.input else None, - orchestrationInstance=pb.OrchestrationInstance(instanceId=activity.instance_id) + # Scan for the first matching activity + skipped: list = [] + matched_activity = None + while self._activity_queue: + candidate = self._activity_queue.popleft() + if not self._matches_filter( + candidate.name, candidate.version, + activity_filter): + skipped.append(candidate) + continue + matched_activity = candidate + break + # Put back non-matching items + for s in skipped: + self._activity_queue.append(s) + + if matched_activity is not None: + work_item = pb.WorkItem( + completionToken=str(matched_activity.completion_token), + activityRequest=pb.ActivityRequest( + name=matched_activity.name, + taskId=matched_activity.task_id, + input=wrappers_pb2.StringValue(value=matched_activity.input) if matched_activity.input else None, + orchestrationInstance=pb.OrchestrationInstance(instanceId=matched_activity.instance_id) + ) ) - ) # Check for entity work if not work_item: + skipped_entities: list[str] = [] while self._entity_queue: entity_id = self._entity_queue.popleft() self._entity_queue_set.discard(entity_id) entity = self._entities.get(entity_id) if entity and entity.pending_operations: + # Skip if entity name doesn't match filters + if entity_filter is not None: + try: + parsed = EntityInstanceId.parse(entity_id) + if not self._matches_filter( + parsed.entity, None, + entity_filter): + skipped_entities.append(entity_id) + continue + except ValueError: + pass + # Skip if this entity is already being processed if entity_id in self._entity_in_flight: continue @@ -532,6 +622,12 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context): ) break + # Re-queue skipped entities for other workers + for s in skipped_entities: + if s not in self._entity_queue_set: + 
# Sentinel used to tell "auto-generate filters" (argument omitted) apart from
# "clear filters" (an explicit ``None`` argument).
_AUTO_GENERATE_FILTERS = object()


class OrchestrationWorkItemFilter:
    """Specifies a filter for orchestration work items.

    A work item matches when its orchestration name equals ``name`` and,
    if ``versions`` is non-empty, its version is one of ``versions``.
    """

    def __init__(self, name: str, versions: Optional[list[str]] = None):
        """Initialize an orchestration filter.

        Args:
            name: The name of the orchestration to filter.
            versions: Optional list of versions to filter. An empty or
                omitted list matches any version of the orchestration.
        """
        self.name = name
        self.versions: list[str] = versions if versions is not None else []


class ActivityWorkItemFilter:
    """Specifies a filter for activity work items.

    A work item matches when its activity name equals ``name`` and, if
    ``versions`` is non-empty, its version is one of ``versions``.
    """

    def __init__(self, name: str, versions: Optional[list[str]] = None):
        """Initialize an activity filter.

        Args:
            name: The name of the activity to filter.
            versions: Optional list of versions to filter. An empty or
                omitted list matches any version of the activity.
        """
        self.name = name
        self.versions: list[str] = versions if versions is not None else []


class EntityWorkItemFilter:
    """Specifies a filter for entity work items.

    Entity filters are name-only; entities are not versioned.
    """

    def __init__(self, name: str):
        """Initialize an entity filter.

        Args:
            name: The name of the entity to filter. The name is validated
                and normalized to lowercase to match entity registration
                and instance ID conventions.

        Raises:
            ValueError: If ``name`` is not a valid entity name.
        """
        EntityInstanceId.validate_entity_name(name)
        self.name = name.lower()


class WorkItemFilters:
    """Work item filters for a Durable Task Worker.

    These filters are passed to the backend and only work items matching the
    filters will be processed by the worker. If no filters are provided, the
    worker will process all work items.

    By default, no filters are applied. Call
    :meth:`TaskHubGrpcWorker.use_work_item_filters` to enable filtering.
    """

    def __init__(
        self,
        orchestrations: Optional[list[OrchestrationWorkItemFilter]] = None,
        activities: Optional[list[ActivityWorkItemFilter]] = None,
        entities: Optional[list[EntityWorkItemFilter]] = None,
    ):
        """Initialize work item filters.

        Args:
            orchestrations: List of orchestration filters.
            activities: List of activity filters.
            entities: List of entity filters.
        """
        self.orchestrations: list[OrchestrationWorkItemFilter] = (
            orchestrations if orchestrations is not None else []
        )
        self.activities: list[ActivityWorkItemFilter] = (
            activities if activities is not None else []
        )
        self.entities: list[EntityWorkItemFilter] = (
            entities if entities is not None else []
        )

    @classmethod
    def _from_registry(cls, registry: '_Registry') -> 'WorkItemFilters':
        """Auto-generate work item filters from the task registry.

        When versioning is configured with STRICT matching, the worker's
        version is attached to every orchestration and activity filter so
        the backend only dispatches work items with that exact version.
        """
        versions: list[str] = []
        v = registry.versioning
        if v and v.match_strategy == VersionMatchStrategy.STRICT and v.version:
            versions = [v.version]

        orchestrations = [
            OrchestrationWorkItemFilter(name=name, versions=list(versions))
            for name in registry.orchestrators
        ]
        activities = [
            ActivityWorkItemFilter(name=name, versions=list(versions))
            for name in registry.activities
        ]
        entities = [
            EntityWorkItemFilter(name=name)
            for name in registry.entities
        ]
        return cls(
            orchestrations=orchestrations,
            activities=activities,
            entities=entities,
        )

    def _to_grpc(self) -> 'pb.WorkItemFilters':
        """Convert to a gRPC WorkItemFilters message."""
        grpc_filters = pb.WorkItemFilters()
        for f in self.orchestrations:
            grpc_filters.orchestrations.append(
                pb.OrchestrationFilter(name=f.name, versions=f.versions)
            )
        for f in self.activities:
            grpc_filters.activities.append(
                pb.ActivityFilter(name=f.name, versions=f.versions)
            )
        for f in self.entities:
            grpc_filters.entities.append(
                pb.EntityFilter(name=f.name)
            )
        return grpc_filters
ConcurrencyOptions: @@ -392,11 +519,62 @@ def use_versioning(self, version: VersioningOptions) -> None: raise RuntimeError("Cannot set default version while the worker is running.") self._registry.versioning = version + @overload + def use_work_item_filters(self) -> None: ... + + @overload + def use_work_item_filters(self, filters: WorkItemFilters) -> None: ... + + @overload + def use_work_item_filters(self, filters: None) -> None: ... + + def use_work_item_filters( + self, + filters: Union[WorkItemFilters, None, object] = _AUTO_GENERATE_FILTERS, + ) -> None: + """Configures work item filters for the worker. + + Work item filters tell the backend which orchestrations, activities, + and entities this worker can handle. When enabled, only matching work + items are dispatched to this worker. + + By default no filters are applied and the worker processes all work + items. Calling this method enables filtering. + + Args: + filters: The filters to apply. If omitted (default), filters are + auto-generated from registered orchestrations, activities, and + entities at :meth:`start` time. Pass a :class:`WorkItemFilters` + instance to provide explicit filters. Pass ``None`` to clear + any previously configured filters. + """ + if self._is_running: + raise RuntimeError( + "Work item filters cannot be changed while the worker is running." + ) + if filters is _AUTO_GENERATE_FILTERS: + self._auto_generate_work_item_filters = True + self._work_item_filters = None + elif filters is None: + self._auto_generate_work_item_filters = False + self._work_item_filters = None + elif isinstance(filters, WorkItemFilters): + self._auto_generate_work_item_filters = False + self._work_item_filters = filters + else: + raise TypeError( + "filters must be a WorkItemFilters instance, None, or omitted." 
+ ) + def start(self): """Starts the worker on a background thread and begins listening for work items.""" if self._is_running: raise RuntimeError("The worker is already running.") + # Auto-generate work item filters from registry if opted in + if self._auto_generate_work_item_filters: + self._work_item_filters = WorkItemFilters._from_registry(self._registry) + def run_loop(): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -502,6 +680,10 @@ def should_invalidate_connection(rpc_error): maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items, maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items, ) + if self._work_item_filters is not None: + get_work_items_request.workItemFilters.CopyFrom( + self._work_item_filters._to_grpc() + ) self._response_stream = stub.GetWorkItems(get_work_items_request) self._logger.info( f"Successfully connected to {self._host_address}. Waiting for work items..." diff --git a/examples/work_item_filtering.py b/examples/work_item_filtering.py new file mode 100644 index 00000000..c0553e4c --- /dev/null +++ b/examples/work_item_filtering.py @@ -0,0 +1,104 @@ +"""End-to-end sample that demonstrates how to use work item filters +to control which orchestrations and activities a worker processes.""" +import os + +from azure.identity import DefaultAzureCredential + +from durabletask import client, task, worker +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + + +# --- Activity definitions --- + +def greet(ctx: task.ActivityContext, name: str) -> str: + """Activity that returns a greeting.""" + return f"Hello, {name}!" + + +def farewell(ctx: task.ActivityContext, name: str) -> str: + """Activity that returns a farewell message.""" + return f"Goodbye, {name}!" 
# --- Orchestrator definitions ---

def greeting_orchestrator(ctx: task.OrchestrationContext, name: str):
    """Orchestrator that calls the greet activity."""
    result = yield ctx.call_activity(greet, input=name)
    return result


def farewell_orchestrator(ctx: task.OrchestrationContext, name: str):
    """Orchestrator that calls the farewell activity."""
    result = yield ctx.call_activity(farewell, input=name)
    return result


# --- Main ---

# Use environment variables if provided, otherwise use default emulator values
taskhub_name = os.getenv("TASKHUB", "default")
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")

print(f"Using taskhub: {taskhub_name}")
print(f"Using endpoint: {endpoint}")

# Set credential to None for emulator, or DefaultAzureCredential for Azure
secure_channel = endpoint.startswith("https://")
credential = DefaultAzureCredential() if secure_channel else None

# === Example 1: Auto-generated filters ===
# Calling use_work_item_filters() with no arguments tells the worker to
# automatically build filters from the registered orchestrators, activities,
# and entities. The backend will then only dispatch matching work items.
print("\n--- Example 1: Auto-generated filters ---")
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel,
                                taskhub=taskhub_name, token_credential=credential) as w:
    w.add_orchestrator(greeting_orchestrator)
    w.add_activity(greet)
    # Opt in to work item filtering — filters are derived from the registry
    w.use_work_item_filters()
    w.start()

    c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel,
                                   taskhub=taskhub_name, token_credential=credential)
    instance_id = c.schedule_new_orchestration(greeting_orchestrator, input="World")
    state = c.wait_for_orchestration_completion(instance_id, timeout=30)
    if state and state.runtime_status == client.OrchestrationStatus.COMPLETED:
        print(f"  Completed: {state.serialized_output}")
    elif state:
        print(f"  Failed: {state.failure_details}")

# === Example 2: Explicit / custom filters ===
# You can supply your own WorkItemFilters to have fine-grained control
# over which work items the worker receives, including version constraints.
print("\n--- Example 2: Explicit filters ---")
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel,
                                taskhub=taskhub_name, token_credential=credential) as w:
    w.add_orchestrator(greeting_orchestrator)
    w.add_orchestrator(farewell_orchestrator)
    w.add_activity(greet)
    w.add_activity(farewell)

    # Only process greeting-related work items, ignoring farewell tasks
    w.use_work_item_filters(worker.WorkItemFilters(
        orchestrations=[
            worker.OrchestrationWorkItemFilter(name="greeting_orchestrator"),
        ],
        activities=[
            worker.ActivityWorkItemFilter(name="greet"),
        ],
    ))
    w.start()

    c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel,
                                   taskhub=taskhub_name, token_credential=credential)
    instance_id = c.schedule_new_orchestration(greeting_orchestrator, input="World")
    state = c.wait_for_orchestration_completion(instance_id, timeout=30)
    if state and state.runtime_status == client.OrchestrationStatus.COMPLETED:
        print(f"  Completed: {state.serialized_output}")
    elif state:
        print(f"  Failed: {state.failure_details}")
    # The script ends here; the worker is shut down by the `with` block.
    # (A previous revision called the site-injected `exit()` builtin, which
    # is discouraged in scripts and was redundant at end-of-script.)
"""E2E tests for work item filtering against DTS (emulator or deployed)."""

import os
import time

import pytest

from durabletask import client, entities, task
from durabletask.worker import (
    ActivityWorkItemFilter,
    EntityWorkItemFilter,
    OrchestrationWorkItemFilter,
    WorkItemFilters,
)
from durabletask.azuremanaged.client import DurableTaskSchedulerClient
from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker

# NOTE: These tests assume a sidecar process is running. Example command:
#       docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest
pytestmark = pytest.mark.dts

# Read the environment variables
taskhub_name = os.getenv("TASKHUB", "default")
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")

# NOTE(review): these tests pass secure_channel=True with token_credential=None
# even though the default endpoint is http:// — the example script only enables
# a secure channel for https endpoints. Confirm the emulator accepts this.


# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------

def _plus_one(_: task.ActivityContext, input: int) -> int:
    """Activity: increments its input by one."""
    return input + 1


def _orchestrator_with_activity(ctx: task.OrchestrationContext, start_val: int):
    """Orchestrator: calls the _plus_one activity once."""
    result = yield ctx.call_activity(_plus_one, input=start_val)
    return result


def _other_orchestrator(ctx: task.OrchestrationContext, _):
    """Orchestrator with no activity calls; returns a constant."""
    return "other"


# ------------------------------------------------------------------
# Tests: auto-generated filters
# ------------------------------------------------------------------

def test_auto_filters_processes_matching_work_items():
    """Worker with auto-generated filters processes matching orchestrations."""
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(_orchestrator_with_activity)
        w.add_activity(_plus_one)
        w.use_work_item_filters()
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                       taskhub=taskhub_name, token_credential=None)
        instance_id = c.schedule_new_orchestration(_orchestrator_with_activity, input=5)
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == "6"


# ------------------------------------------------------------------
# Tests: explicit custom filters
# ------------------------------------------------------------------

def test_explicit_filters_matching():
    """Worker with explicit filters processes matching work items."""
    custom_filters = WorkItemFilters(
        orchestrations=[
            OrchestrationWorkItemFilter(
                name=task.get_name(_orchestrator_with_activity)
            )
        ],
        activities=[
            ActivityWorkItemFilter(name=task.get_name(_plus_one))
        ],
    )

    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(_orchestrator_with_activity)
        w.add_activity(_plus_one)
        w.use_work_item_filters(custom_filters)
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                       taskhub=taskhub_name, token_credential=None)
        instance_id = c.schedule_new_orchestration(_orchestrator_with_activity, input=10)
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == "11"


# ------------------------------------------------------------------
# Tests: no filters (default behavior)
# ------------------------------------------------------------------

def test_no_filters_processes_all():
    """Without filters the worker processes all work items."""
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(_orchestrator_with_activity)
        w.add_activity(_plus_one)
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                       taskhub=taskhub_name, token_credential=None)
        instance_id = c.schedule_new_orchestration(_orchestrator_with_activity, input=7)
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == "8"


# ------------------------------------------------------------------
# Tests: cleared filters (None)
# ------------------------------------------------------------------

def test_cleared_filters_processes_all():
    """Clearing filters with None restores process-all behavior."""
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(_orchestrator_with_activity)
        w.add_activity(_plus_one)
        w.use_work_item_filters()
        w.use_work_item_filters(None)
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                       taskhub=taskhub_name, token_credential=None)
        instance_id = c.schedule_new_orchestration(_orchestrator_with_activity, input=3)
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == "4"


# ------------------------------------------------------------------
# Tests: entity work item filters
# ------------------------------------------------------------------

def test_entity_filters_process_matching_entity():
    """Worker with entity filters processes matching entity signals."""
    invoked = False

    class Counter(entities.DurableEntity):
        def add(self, amount: int):
            self.set_state(self.get_state(int, 0) + amount)
            nonlocal invoked
            invoked = True

    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_entity(Counter)
        w.use_work_item_filters(WorkItemFilters(
            entities=[EntityWorkItemFilter(name="counter")],
        ))
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                       taskhub=taskhub_name, token_credential=None)
        entity_id = entities.EntityInstanceId("counter", "myKey")
        c.signal_entity(entity_id, "add", input=10)
        time.sleep(5)  # wait for the signal to be processed

        state = c.get_entity(entity_id)

        assert invoked
        assert state is not None
        assert state.get_state(int) == 10


# ------------------------------------------------------------------
# Tests: non-matching filters prevent processing
# ------------------------------------------------------------------

def test_non_matching_orchestrator_not_processed():
    """Work items for unmatched orchestrations are not dispatched."""
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(_orchestrator_with_activity)
        w.add_orchestrator(_other_orchestrator)
        w.add_activity(_plus_one)
        w.use_work_item_filters(WorkItemFilters(
            orchestrations=[
                OrchestrationWorkItemFilter(
                    name=task.get_name(_other_orchestrator)
                ),
            ],
            activities=[
                ActivityWorkItemFilter(name=task.get_name(_plus_one)),
            ],
        ))
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                       taskhub=taskhub_name, token_credential=None)

        # Schedule the non-matching orchestration — should NOT be processed
        non_match_id = c.schedule_new_orchestration(
            _orchestrator_with_activity, input=1)

        # Schedule the matching orchestration — should complete
        match_id = c.schedule_new_orchestration(_other_orchestrator)
        match_state = c.wait_for_orchestration_completion(
            match_id, timeout=30)

        # The matching orchestration completes normally
        assert match_state is not None
        assert match_state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert match_state.serialized_output == '"other"'

        # The non-matching orchestration should still be pending
        non_match_state = c.get_orchestration_state(non_match_id)
        assert non_match_state is not None
        assert non_match_state.runtime_status == client.OrchestrationStatus.PENDING


def test_non_matching_entity_not_processed():
    """Work items for unmatched entities are not dispatched."""
    matched_invoked = False
    unmatched_invoked = False

    class AllowedEntity(entities.DurableEntity):
        def ping(self, _):
            nonlocal matched_invoked
            matched_invoked = True

    class BlockedEntity(entities.DurableEntity):
        def ping(self, _):
            nonlocal unmatched_invoked
            unmatched_invoked = True

    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_entity(AllowedEntity)
        w.add_entity(BlockedEntity)
        w.use_work_item_filters(WorkItemFilters(
            entities=[EntityWorkItemFilter(name="allowedentity")],
        ))
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                       taskhub=taskhub_name, token_credential=None)
        c.signal_entity(
            entities.EntityInstanceId("allowedentity", "k1"), "ping")
        c.signal_entity(
            entities.EntityInstanceId("blockedentity", "k1"), "ping")
        time.sleep(5)  # wait for processing

        assert matched_invoked
        assert not unmatched_invoked
+ +import pytest + +import durabletask.internal.orchestrator_service_pb2 as pb +from durabletask.worker import ( + ActivityWorkItemFilter, + EntityWorkItemFilter, + OrchestrationWorkItemFilter, + TaskHubGrpcWorker, + VersioningOptions, + VersionMatchStrategy, + WorkItemFilters, + _Registry, +) + + +# --------------------------------------------------------------------------- +# OrchestrationWorkItemFilter / ActivityWorkItemFilter / EntityWorkItemFilter +# --------------------------------------------------------------------------- + +class TestOrchestrationWorkItemFilter: + def test_defaults(self): + f = OrchestrationWorkItemFilter(name="MyOrch") + assert f.name == "MyOrch" + assert f.versions == [] + + def test_with_versions(self): + f = OrchestrationWorkItemFilter(name="MyOrch", versions=["1.0", "2.0"]) + assert f.versions == ["1.0", "2.0"] + + +class TestActivityWorkItemFilter: + def test_defaults(self): + f = ActivityWorkItemFilter(name="MyActivity") + assert f.name == "MyActivity" + assert f.versions == [] + + def test_with_versions(self): + f = ActivityWorkItemFilter(name="MyActivity", versions=["3.0"]) + assert f.versions == ["3.0"] + + +class TestEntityWorkItemFilter: + def test_defaults(self): + f = EntityWorkItemFilter(name="myentity") + assert f.name == "myentity" + + def test_name_normalized_to_lowercase(self): + f = EntityWorkItemFilter(name="Counter") + assert f.name == "counter" + + def test_invalid_name_raises(self): + with pytest.raises(ValueError): + EntityWorkItemFilter(name="bad@name") + + def test_empty_name_raises(self): + with pytest.raises(ValueError): + EntityWorkItemFilter(name="") + + +# --------------------------------------------------------------------------- +# WorkItemFilters construction +# --------------------------------------------------------------------------- + +class TestWorkItemFilters: + def test_defaults_empty(self): + filters = WorkItemFilters() + assert filters.orchestrations == [] + assert filters.activities == [] + 
assert filters.entities == [] + + def test_explicit_values(self): + orch = [OrchestrationWorkItemFilter(name="Orch1")] + act = [ActivityWorkItemFilter(name="Act1")] + ent = [EntityWorkItemFilter(name="ent1")] + filters = WorkItemFilters( + orchestrations=orch, activities=act, entities=ent + ) + assert len(filters.orchestrations) == 1 + assert filters.orchestrations[0].name == "Orch1" + assert len(filters.activities) == 1 + assert filters.activities[0].name == "Act1" + assert len(filters.entities) == 1 + assert filters.entities[0].name == "ent1" + + +# --------------------------------------------------------------------------- +# WorkItemFilters._from_registry +# --------------------------------------------------------------------------- + +def _make_orchestrator(name): + """Create a minimal orchestrator function with the given name.""" + def orchestrator(ctx, input): + yield # pragma: no cover + orchestrator.__name__ = name + return orchestrator + + +def _make_activity(name): + """Create a minimal activity function with the given name.""" + def activity(ctx, input): + return None # pragma: no cover + activity.__name__ = name + return activity + + +def _make_entity(name): + """Create a minimal entity function with the given name.""" + def entity(ctx, state, input): + return None # pragma: no cover + entity.__name__ = name + return entity + + +class TestFromRegistry: + def test_empty_registry(self): + reg = _Registry() + filters = WorkItemFilters._from_registry(reg) + assert filters.orchestrations == [] + assert filters.activities == [] + assert filters.entities == [] + + def test_orchestrators_and_activities(self): + reg = _Registry() + reg.add_orchestrator(_make_orchestrator("Orch1")) + reg.add_orchestrator(_make_orchestrator("Orch2")) + reg.add_activity(_make_activity("Act1")) + + filters = WorkItemFilters._from_registry(reg) + + orch_names = {f.name for f in filters.orchestrations} + assert orch_names == {"Orch1", "Orch2"} + assert all(f.versions == [] for f in 
filters.orchestrations) + + assert len(filters.activities) == 1 + assert filters.activities[0].name == "Act1" + assert filters.activities[0].versions == [] + + assert filters.entities == [] + + def test_entities(self): + reg = _Registry() + reg.add_entity(_make_entity("counter"), name="counter") + + filters = WorkItemFilters._from_registry(reg) + + assert len(filters.entities) == 1 + assert filters.entities[0].name == "counter" + + def test_no_versioning(self): + """Without versioning, versions should be empty.""" + reg = _Registry() + reg.add_orchestrator(_make_orchestrator("Orch")) + reg.add_activity(_make_activity("Act")) + + filters = WorkItemFilters._from_registry(reg) + assert filters.orchestrations[0].versions == [] + assert filters.activities[0].versions == [] + + def test_versioning_none_strategy(self): + """NONE match strategy should produce empty versions.""" + reg = _Registry() + reg.add_orchestrator(_make_orchestrator("Orch")) + reg.versioning = VersioningOptions( + version="1.0", match_strategy=VersionMatchStrategy.NONE + ) + + filters = WorkItemFilters._from_registry(reg) + assert filters.orchestrations[0].versions == [] + + def test_versioning_current_or_older(self): + """CURRENT_OR_OLDER match strategy should produce empty versions.""" + reg = _Registry() + reg.add_orchestrator(_make_orchestrator("Orch")) + reg.versioning = VersioningOptions( + version="1.5", match_strategy=VersionMatchStrategy.CURRENT_OR_OLDER + ) + + filters = WorkItemFilters._from_registry(reg) + assert filters.orchestrations[0].versions == [] + + def test_versioning_strict(self): + """STRICT match strategy should populate versions.""" + reg = _Registry() + reg.add_orchestrator(_make_orchestrator("Orch")) + reg.add_activity(_make_activity("Act")) + reg.versioning = VersioningOptions( + version="2.0", match_strategy=VersionMatchStrategy.STRICT + ) + + filters = WorkItemFilters._from_registry(reg) + assert filters.orchestrations[0].versions == ["2.0"] + assert 
filters.activities[0].versions == ["2.0"] + + def test_versioning_strict_no_version_string(self): + """STRICT without a version string should produce empty versions.""" + reg = _Registry() + reg.add_orchestrator(_make_orchestrator("Orch")) + reg.versioning = VersioningOptions( + version=None, match_strategy=VersionMatchStrategy.STRICT + ) + + filters = WorkItemFilters._from_registry(reg) + assert filters.orchestrations[0].versions == [] + + +# --------------------------------------------------------------------------- +# WorkItemFilters._to_grpc +# --------------------------------------------------------------------------- + +class TestToGrpc: + def test_empty_filters(self): + grpc_msg = WorkItemFilters()._to_grpc() + assert isinstance(grpc_msg, pb.WorkItemFilters) + assert len(grpc_msg.orchestrations) == 0 + assert len(grpc_msg.activities) == 0 + assert len(grpc_msg.entities) == 0 + + def test_orchestration_filter(self): + filters = WorkItemFilters( + orchestrations=[ + OrchestrationWorkItemFilter(name="Orch1", versions=["1.0", "2.0"]), + ] + ) + grpc_msg = filters._to_grpc() + assert len(grpc_msg.orchestrations) == 1 + assert grpc_msg.orchestrations[0].name == "Orch1" + assert list(grpc_msg.orchestrations[0].versions) == ["1.0", "2.0"] + + def test_activity_filter(self): + filters = WorkItemFilters( + activities=[ + ActivityWorkItemFilter(name="Act1", versions=["3.0"]), + ActivityWorkItemFilter(name="Act2"), + ] + ) + grpc_msg = filters._to_grpc() + assert len(grpc_msg.activities) == 2 + assert grpc_msg.activities[0].name == "Act1" + assert list(grpc_msg.activities[0].versions) == ["3.0"] + assert grpc_msg.activities[1].name == "Act2" + assert list(grpc_msg.activities[1].versions) == [] + + def test_entity_filter(self): + filters = WorkItemFilters( + entities=[EntityWorkItemFilter(name="counter")] + ) + grpc_msg = filters._to_grpc() + assert len(grpc_msg.entities) == 1 + assert grpc_msg.entities[0].name == "counter" + + def test_full_round_trip(self): + """All 
three filter types convert correctly.""" + filters = WorkItemFilters( + orchestrations=[OrchestrationWorkItemFilter(name="O", versions=["1"])], + activities=[ActivityWorkItemFilter(name="A")], + entities=[EntityWorkItemFilter(name="e")], + ) + grpc_msg = filters._to_grpc() + assert grpc_msg.orchestrations[0].name == "O" + assert grpc_msg.activities[0].name == "A" + assert grpc_msg.entities[0].name == "e" + + +# --------------------------------------------------------------------------- +# TaskHubGrpcWorker.use_work_item_filters +# --------------------------------------------------------------------------- + +class TestUseWorkItemFilters: + def test_auto_generate_default(self): + """Calling with no arguments enables auto-generation.""" + w = TaskHubGrpcWorker() + w.use_work_item_filters() + assert w._auto_generate_work_item_filters is True + assert w._work_item_filters is None + + def test_explicit_filters(self): + """Passing a WorkItemFilters instance stores it directly.""" + w = TaskHubGrpcWorker() + custom = WorkItemFilters( + orchestrations=[OrchestrationWorkItemFilter(name="MyOrch")] + ) + w.use_work_item_filters(custom) + assert w._auto_generate_work_item_filters is False + assert w._work_item_filters is custom + assert len(w._work_item_filters.orchestrations) == 1 + + def test_clear_filters_with_none(self): + """Passing None clears previously set filters.""" + w = TaskHubGrpcWorker() + # First set some filters + w.use_work_item_filters(WorkItemFilters( + orchestrations=[OrchestrationWorkItemFilter(name="X")] + )) + assert w._work_item_filters is not None + + # Now clear + w.use_work_item_filters(None) + assert w._auto_generate_work_item_filters is False + assert w._work_item_filters is None + + def test_clear_auto_generate_with_none(self): + """Passing None after auto-generate clears the auto flag.""" + w = TaskHubGrpcWorker() + w.use_work_item_filters() + assert w._auto_generate_work_item_filters is True + + w.use_work_item_filters(None) + assert 
w._auto_generate_work_item_filters is False + assert w._work_item_filters is None + + def test_invalid_type_raises(self): + """Passing an unsupported type raises TypeError.""" + w = TaskHubGrpcWorker() + with pytest.raises(TypeError, match="WorkItemFilters instance"): + w.use_work_item_filters("invalid") # type: ignore + + def test_raises_when_running(self): + """Cannot change filters while worker is running.""" + w = TaskHubGrpcWorker() + w._is_running = True + with pytest.raises(RuntimeError, match="cannot be changed while the worker is running"): + w.use_work_item_filters() + + def test_default_no_filters(self): + """By default, no filters are set (opt-in model).""" + w = TaskHubGrpcWorker() + assert w._work_item_filters is None + assert w._auto_generate_work_item_filters is False diff --git a/tests/durabletask/test_work_item_filters_e2e.py b/tests/durabletask/test_work_item_filters_e2e.py new file mode 100644 index 00000000..5eec8fe4 --- /dev/null +++ b/tests/durabletask/test_work_item_filters_e2e.py @@ -0,0 +1,388 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""End-to-end tests for work item filtering using the in-memory backend.""" + +import time + +import pytest + +from durabletask import client, entities, task, worker +from durabletask.worker import ( + ActivityWorkItemFilter, + EntityWorkItemFilter, + OrchestrationWorkItemFilter, + VersioningOptions, + VersionMatchStrategy, + WorkItemFilters, +) +from durabletask.testing import create_test_backend + +HOST = "localhost:50060" + + +@pytest.fixture(autouse=True) +def backend(): + """Create an in-memory backend for testing.""" + b = create_test_backend(port=50060) + yield b + b.stop() + b.reset() + + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + +def _plus_one(_: task.ActivityContext, input: int) -> int: + return input + 1 + + +def _multiply(_: task.ActivityContext, input: int) -> int: + return input * 2 + + +def _orchestrator_with_activity(ctx: task.OrchestrationContext, start_val: int): + result = yield ctx.call_activity(_plus_one, input=start_val) + return result + + +def _other_orchestrator(ctx: task.OrchestrationContext, _): + return "other" + + +# ------------------------------------------------------------------ +# Tests: auto-generated filters +# ------------------------------------------------------------------ + +def test_auto_filters_processes_matching_work_items(): + """Worker with auto-generated filters processes matching orchestrations.""" + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(_orchestrator_with_activity) + w.add_activity(_plus_one) + w.use_work_item_filters() + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + id = c.schedule_new_orchestration(_orchestrator_with_activity, input=5) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == "6" + + +# 
------------------------------------------------------------------ +# Tests: explicit custom filters +# ------------------------------------------------------------------ + +def test_explicit_filters_matching(): + """Worker with explicit filters matching registered tasks processes work items.""" + custom_filters = WorkItemFilters( + orchestrations=[ + OrchestrationWorkItemFilter( + name=task.get_name(_orchestrator_with_activity) + ) + ], + activities=[ + ActivityWorkItemFilter(name=task.get_name(_plus_one)) + ], + ) + + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(_orchestrator_with_activity) + w.add_activity(_plus_one) + w.use_work_item_filters(custom_filters) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + id = c.schedule_new_orchestration(_orchestrator_with_activity, input=10) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == "11" + + +# ------------------------------------------------------------------ +# Tests: no filters (default behavior) +# ------------------------------------------------------------------ + +def test_no_filters_processes_all(): + """Without filters, worker processes all work items (default behavior).""" + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(_orchestrator_with_activity) + w.add_activity(_plus_one) + # Intentionally do NOT call use_work_item_filters() + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + id = c.schedule_new_orchestration(_orchestrator_with_activity, input=7) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == "8" + + +# ------------------------------------------------------------------ +# Tests: cleared filters (None) +# 
------------------------------------------------------------------ + +def test_cleared_filters_processes_all(): + """Clearing filters with None restores process-all behavior.""" + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(_orchestrator_with_activity) + w.add_activity(_plus_one) + w.use_work_item_filters() # auto-generate + w.use_work_item_filters(None) # then clear + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + id = c.schedule_new_orchestration(_orchestrator_with_activity, input=3) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == "4" + + +# ------------------------------------------------------------------ +# Tests: entity work item filters +# ------------------------------------------------------------------ + +def test_entity_filters_process_matching_entity(): + """Worker with entity filters processes matching entity signals.""" + invoked = False + + class Counter(entities.DurableEntity): + def add(self, amount: int): + self.set_state(self.get_state(int, 0) + amount) + nonlocal invoked + invoked = True + + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_entity(Counter) + w.use_work_item_filters(WorkItemFilters( + entities=[EntityWorkItemFilter(name="counter")], + )) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + entity_id = entities.EntityInstanceId("counter", "myKey") + c.signal_entity(entity_id, "add", input=10) + time.sleep(2) # wait for the signal to be processed + + state = c.get_entity(entity_id, include_state=True) + + assert invoked + assert state is not None + assert state.get_state(int) == 10 + + +# ------------------------------------------------------------------ +# Tests: non-matching filters prevent processing +# ------------------------------------------------------------------ + +def 
test_non_matching_orchestrator_not_processed(): + """Work items for unmatched orchestrations are not processed.""" + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + # Register both orchestrators but only filter for one + w.add_orchestrator(_orchestrator_with_activity) + w.add_orchestrator(_other_orchestrator) + w.add_activity(_plus_one) + w.use_work_item_filters(WorkItemFilters( + orchestrations=[ + OrchestrationWorkItemFilter( + name=task.get_name(_other_orchestrator) + ), + ], + activities=[ + ActivityWorkItemFilter(name=task.get_name(_plus_one)), + ], + )) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + + # Schedule the non-matching orchestration — should NOT be processed + non_match_id = c.schedule_new_orchestration( + _orchestrator_with_activity, input=1) + + # Schedule the matching orchestration — should complete + match_id = c.schedule_new_orchestration(_other_orchestrator) + match_state = c.wait_for_orchestration_completion( + match_id, timeout=30) + + # The matching orchestration completes normally + assert match_state is not None + assert match_state.runtime_status == client.OrchestrationStatus.COMPLETED + assert match_state.serialized_output == '"other"' + + # The non-matching orchestration should still be pending + non_match_state = c.get_orchestration_state(non_match_id) + assert non_match_state is not None + assert non_match_state.runtime_status == client.OrchestrationStatus.PENDING + + +def test_non_matching_entity_not_processed(): + """Work items for unmatched entities are not processed.""" + matched_invoked = False + unmatched_invoked = False + + class AllowedEntity(entities.DurableEntity): + def ping(self, _): + nonlocal matched_invoked + matched_invoked = True + + class BlockedEntity(entities.DurableEntity): + def ping(self, _): + nonlocal unmatched_invoked + unmatched_invoked = True + + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_entity(AllowedEntity) + w.add_entity(BlockedEntity) + # Only filter for 
AllowedEntity + w.use_work_item_filters(WorkItemFilters( + entities=[EntityWorkItemFilter(name="allowedentity")], + )) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + c.signal_entity( + entities.EntityInstanceId("allowedentity", "k1"), "ping") + c.signal_entity( + entities.EntityInstanceId("blockedentity", "k1"), "ping") + time.sleep(3) # wait for processing + + assert matched_invoked + assert not unmatched_invoked + + +# ------------------------------------------------------------------ +# Tests: version-aware filtering with strict versioning +# ------------------------------------------------------------------ + +def _simple_v2_orchestrator(ctx: task.OrchestrationContext, input: int): + """Orchestrator that returns immediately (no activities) for version tests.""" + return input + 1 + + +def test_strict_version_matching_orchestration_completes(): + """Orchestration scheduled with the matching version is processed.""" + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(_simple_v2_orchestrator) + w.use_versioning(VersioningOptions( + version="2.0", + match_strategy=VersionMatchStrategy.STRICT, + )) + w.use_work_item_filters() # auto-generate with version + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + id = c.schedule_new_orchestration( + _simple_v2_orchestrator, input=10, version="2.0") + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == "11" + + +def test_strict_version_incompatible_orchestration_stays_pending(): + """Orchestration with an incompatible version is not dispatched and stays pending.""" + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(_simple_v2_orchestrator) + w.use_versioning(VersioningOptions( + version="2.0", + match_strategy=VersionMatchStrategy.STRICT, + )) + w.use_work_item_filters() + w.start() + + c = 
client.TaskHubGrpcClient(host_address=HOST) + + # Schedule with version "1.0" — incompatible with the worker's "2.0" + bad_id = c.schedule_new_orchestration( + _simple_v2_orchestrator, input=5, version="1.0") + + # Schedule a compatible one so we can confirm the worker is active + good_id = c.schedule_new_orchestration( + _simple_v2_orchestrator, input=5, version="2.0") + good_state = c.wait_for_orchestration_completion(good_id, timeout=30) + + assert good_state is not None + assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED + + # The incompatible orchestration must remain pending (not failed) + bad_state = c.get_orchestration_state(bad_id) + assert bad_state is not None + assert bad_state.runtime_status == client.OrchestrationStatus.PENDING + + +def test_strict_version_no_version_orchestration_stays_pending(): + """Orchestration scheduled without a version is not dispatched by a strict worker.""" + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(_simple_v2_orchestrator) + w.use_versioning(VersioningOptions( + version="2.0", + match_strategy=VersionMatchStrategy.STRICT, + )) + w.use_work_item_filters() + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + + # Schedule without any version + no_ver_id = c.schedule_new_orchestration( + _simple_v2_orchestrator, input=1) + + # Schedule a compatible one to prove the worker is running + good_id = c.schedule_new_orchestration( + _simple_v2_orchestrator, input=1, version="2.0") + good_state = c.wait_for_orchestration_completion(good_id, timeout=30) + assert good_state is not None + assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED + + # The unversioned orchestration must remain pending + no_ver_state = c.get_orchestration_state(no_ver_id) + assert no_ver_state is not None + assert no_ver_state.runtime_status == client.OrchestrationStatus.PENDING + + +def test_strict_version_explicit_filters_with_versions(): + """Explicit filters with 
version constraints enforce strict matching.""" + custom_filters = WorkItemFilters( + orchestrations=[ + OrchestrationWorkItemFilter( + name=task.get_name(_simple_v2_orchestrator), + versions=["3.0"], + ), + ], + ) + + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(_simple_v2_orchestrator) + w.use_work_item_filters(custom_filters) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + + # Version "2.0" does not match the filter's "3.0" + bad_id = c.schedule_new_orchestration( + _simple_v2_orchestrator, input=1, version="2.0") + + # Version "3.0" should match + good_id = c.schedule_new_orchestration( + _simple_v2_orchestrator, input=1, version="3.0") + good_state = c.wait_for_orchestration_completion(good_id, timeout=30) + + assert good_state is not None + assert good_state.runtime_status == client.OrchestrationStatus.COMPLETED + assert good_state.serialized_output == "2" + + # Mismatched version must remain pending + bad_state = c.get_orchestration_state(bad_id) + assert bad_state is not None + assert bad_state.runtime_status == client.OrchestrationStatus.PENDING