-
Notifications
You must be signed in to change notification settings - Fork 25
Implement work item filtering #128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,7 @@ | |
| import durabletask.internal.orchestrator_service_pb2 as pb | ||
| import durabletask.internal.orchestrator_service_pb2_grpc as stubs | ||
| import durabletask.internal.helpers as helpers | ||
| from durabletask.entities.entity_instance_id import EntityInstanceId | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -56,6 +57,7 @@ class ActivityWorkItem: | |
| task_id: int | ||
| input: Optional[str] | ||
| completion_token: int | ||
| version: Optional[str] = None | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -436,16 +438,65 @@ def RestartInstance(self, request: pb.RestartInstanceRequest, context): | |
| f"Restarted instance '{request.instanceId}' as '{new_instance_id}'") | ||
| return pb.RestartInstanceResponse(instanceId=new_instance_id) | ||
|
|
||
| @staticmethod | ||
| def _parse_work_item_filters(request: pb.GetWorkItemsRequest): | ||
| """Extract filters from the request. | ||
|
|
||
| Returns a tuple of three values, one per work-item category. Each | ||
| value is either ``None`` (no filtering -- dispatch everything) or a | ||
| ``dict`` mapping a task name to a ``frozenset`` of accepted versions | ||
| (empty frozenset means *any* version of that name is accepted). | ||
| An empty ``dict`` means the worker opted into filtering for that | ||
| category but listed no names, so *nothing* should match. | ||
| """ | ||
| if not request.HasField("workItemFilters"): | ||
| return None, None, None | ||
| wf = request.workItemFilters | ||
|
|
||
| def _build_filter(filters): | ||
| result: dict[str, frozenset[str]] = {} | ||
| for f in filters: | ||
| versions = frozenset(f.versions) if f.versions else frozenset() | ||
| existing = result.get(f.name, frozenset()) | ||
| result[f.name] = existing | versions | ||
| return result | ||
|
|
||
| orch_filter = _build_filter(wf.orchestrations) | ||
| activity_filter = _build_filter(wf.activities) | ||
| entity_filter = {f.name: frozenset() for f in wf.entities} | ||
| return orch_filter, activity_filter, entity_filter | ||
|
|
||
| @staticmethod | ||
| def _matches_filter(name: str, version: Optional[str], | ||
| filt: Optional[dict[str, frozenset[str]]]) -> bool: | ||
| """Check whether a work item matches the parsed filter. | ||
|
|
||
| *filt* is ``None`` when the worker did not opt into filtering | ||
| (everything matches). Otherwise it is a dict mapping accepted | ||
| names to a frozenset of accepted versions. An empty frozenset | ||
| means any version of that name is accepted. | ||
| """ | ||
| if filt is None: | ||
| return True | ||
| accepted_versions = filt.get(name) | ||
| if accepted_versions is None: | ||
| return False | ||
| if not accepted_versions: | ||
| return True # empty set -- any version | ||
| return (version or "") in accepted_versions | ||
|
|
||
| def GetWorkItems(self, request: pb.GetWorkItemsRequest, context): | ||
| """Streams work items to the worker (orchestration and activity work items).""" | ||
| self._logger.info("Worker connected and requesting work items") | ||
| orch_filter, activity_filter, entity_filter = self._parse_work_item_filters(request) | ||
|
|
||
| try: | ||
| while context.is_active() and not self._shutdown_event.is_set(): | ||
| work_item = None | ||
|
|
||
| with self._lock: | ||
| # Check for orchestration work | ||
| skipped_orchs: list[str] = [] | ||
| while self._orchestration_queue: | ||
| instance_id = self._orchestration_queue.popleft() | ||
| self._orchestration_queue_set.discard(instance_id) | ||
|
|
@@ -454,11 +505,15 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context): | |
| if not instance or not instance.pending_events: | ||
| continue | ||
|
|
||
| # Skip if orchestration doesn't match filters | ||
| if not self._matches_filter( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If Disclaimer: This review was generated by GitHub Copilot on behalf of Bernd. |
||
| instance.name, instance.version, orch_filter): | ||
| skipped_orchs.append(instance_id) | ||
| continue | ||
|
|
||
| if instance_id in self._orchestration_in_flight: | ||
| # Already being processed — re-add to queue | ||
| if instance_id not in self._orchestration_queue_set: | ||
| self._orchestration_queue.append(instance_id) | ||
| self._orchestration_queue_set.add(instance_id) | ||
| skipped_orchs.append(instance_id) | ||
| break | ||
|
|
||
| # Move pending events to dispatched_events | ||
|
|
@@ -485,27 +540,62 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context): | |
| ) | ||
| break | ||
|
|
||
| # Re-queue skipped orchestrations for other workers | ||
| for s in skipped_orchs: | ||
| if s not in self._orchestration_queue_set: | ||
| self._orchestration_queue.append(s) | ||
| self._orchestration_queue_set.add(s) | ||
|
|
||
| # Check for activity work | ||
| if not work_item and self._activity_queue: | ||
| activity = self._activity_queue.popleft() | ||
| work_item = pb.WorkItem( | ||
| completionToken=str(activity.completion_token), | ||
| activityRequest=pb.ActivityRequest( | ||
| name=activity.name, | ||
| taskId=activity.task_id, | ||
| input=wrappers_pb2.StringValue(value=activity.input) if activity.input else None, | ||
| orchestrationInstance=pb.OrchestrationInstance(instanceId=activity.instance_id) | ||
| # Scan for the first matching activity | ||
| skipped: list = [] | ||
| matched_activity = None | ||
| while self._activity_queue: | ||
| candidate = self._activity_queue.popleft() | ||
| if not self._matches_filter( | ||
| candidate.name, candidate.version, | ||
| activity_filter): | ||
| skipped.append(candidate) | ||
| continue | ||
| matched_activity = candidate | ||
| break | ||
| # Put back non-matching items | ||
| for s in skipped: | ||
| self._activity_queue.append(s) | ||
|
|
||
| if matched_activity is not None: | ||
| work_item = pb.WorkItem( | ||
| completionToken=str(matched_activity.completion_token), | ||
| activityRequest=pb.ActivityRequest( | ||
| name=matched_activity.name, | ||
| taskId=matched_activity.task_id, | ||
| input=wrappers_pb2.StringValue(value=matched_activity.input) if matched_activity.input else None, | ||
| orchestrationInstance=pb.OrchestrationInstance(instanceId=matched_activity.instance_id) | ||
| ) | ||
| ) | ||
| ) | ||
|
|
||
| # Check for entity work | ||
| if not work_item: | ||
| skipped_entities: list[str] = [] | ||
| while self._entity_queue: | ||
| entity_id = self._entity_queue.popleft() | ||
| self._entity_queue_set.discard(entity_id) | ||
| entity = self._entities.get(entity_id) | ||
|
|
||
| if entity and entity.pending_operations: | ||
| # Skip if entity name doesn't match filters | ||
| if entity_filter is not None: | ||
| try: | ||
| parsed = EntityInstanceId.parse(entity_id) | ||
| if not self._matches_filter( | ||
| parsed.entity, None, | ||
| entity_filter): | ||
| skipped_entities.append(entity_id) | ||
| continue | ||
| except ValueError: | ||
| pass | ||
|
|
||
| # Skip if this entity is already being processed | ||
| if entity_id in self._entity_in_flight: | ||
| continue | ||
|
|
@@ -532,6 +622,12 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context): | |
| ) | ||
| break | ||
|
|
||
| # Re-queue skipped entities for other workers | ||
| for s in skipped_entities: | ||
| if s not in self._entity_queue_set: | ||
| self._entity_queue.append(s) | ||
| self._entity_queue_set.add(s) | ||
|
|
||
| if work_item: | ||
| yield work_item | ||
| else: | ||
|
|
@@ -1259,12 +1355,15 @@ def _process_schedule_task_action(self, instance: OrchestrationInstance, | |
| instance.status = pb.ORCHESTRATION_STATUS_RUNNING | ||
|
|
||
| # Queue activity for execution | ||
| task_version = schedule_task.version.value \ | ||
| if schedule_task.HasField("version") else None | ||
| self._activity_queue.append(ActivityWorkItem( | ||
| instance_id=instance.instance_id, | ||
| name=task_name, | ||
| task_id=task_id, | ||
| input=input_value, | ||
| completion_token=instance.completion_token | ||
| completion_token=instance.completion_token, | ||
| version=task_version, | ||
| )) | ||
| self._work_available.set() | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider also exporting
OrchestrationWorkItemFilter,ActivityWorkItemFilter, andEntityWorkItemFilterfrom__init__.py. Users constructing explicit filters need these types (as shown in the PR's own example and docs), and in the .NET implementation these are accessible as nested types onDurableTaskWorkerWorkItemFilters. Exporting them improves public API discoverability.Disclaimer: This review was generated by GitHub Copilot on behalf of Bernd.