Skip to content

Commit c658a52

Browse files
andystaples and Copilot authored
Add batch actions (purge, query orchestrations/entities) (#111)
* Add batch actions, tests --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: andystaples <77818326+andystaples@users.noreply.github.com> * Update broken test * Add non-dts tests, update pagination test * non-dts test fix * Lint * Remove unnecessary filter * PR Feedback * Lint * Lint --------- Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
1 parent 3a3c0c4 commit c658a52

5 files changed

Lines changed: 697 additions & 16 deletions

File tree

durabletask/client.py

Lines changed: 171 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
from dataclasses import dataclass
77
from datetime import datetime, timezone
88
from enum import Enum
9-
from typing import Any, Optional, Sequence, TypeVar, Union
9+
from typing import Any, List, Optional, Sequence, TypeVar, Union
1010

1111
import grpc
12-
from google.protobuf import wrappers_pb2
1312

1413
from durabletask.entities import EntityInstanceId
1514
from durabletask.entities.entity_metadata import EntityMetadata
@@ -57,6 +56,39 @@ def raise_if_failed(self):
5756
self.failure_details)
5857

5958

59+
@dataclass
class OrchestrationQuery:
    """Filter criteria for querying orchestration instances.

    All filters are optional; a default-constructed query matches every
    instance the backend is willing to return.
    """
    created_time_from: Optional[datetime] = None
    created_time_to: Optional[datetime] = None
    runtime_status: Optional[List[OrchestrationStatus]] = None
    # Some backends don't respond well with max_instance_count = None, so we
    # default to the 32-bit signed integer limit for non-paginated results
    # instead.
    max_instance_count: Optional[int] = 2**31 - 1
    fetch_inputs_and_outputs: bool = False
68+
69+
70+
@dataclass
class EntityQuery:
    """Filter criteria for querying entity instances.

    A default-constructed query matches all entities and includes their
    serialized state in the results.
    """
    instance_id_starts_with: Optional[str] = None
    last_modified_from: Optional[datetime] = None
    last_modified_to: Optional[datetime] = None
    include_state: bool = True
    include_transient: bool = False
    # None lets the backend choose its own page size.
    page_size: Optional[int] = None
78+
79+
80+
@dataclass
class PurgeInstancesResult:
    """Outcome of a purge operation."""
    # Number of orchestration instances that were deleted.
    deleted_instance_count: int
    # True when the backend reports the purge finished (no more matches left).
    is_complete: bool
84+
85+
86+
@dataclass
class CleanEntityStorageResult:
    """Outcome of an entity-storage cleaning operation."""
    # Count of empty entities that were removed.
    empty_entities_removed: int
    # Count of orphaned entity locks that were released.
    orphaned_locks_released: int
90+
91+
6092
class OrchestrationFailedError(Exception):
6193
def __init__(self, message: str, failure_details: task.FailureDetails):
6294
super().__init__(message)
@@ -73,6 +105,12 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
73105

74106
state = res.orchestrationState
75107

108+
new_state = parse_orchestration_state(state)
109+
new_state.instance_id = instance_id # Override instance_id with the one from the request, to match old behavior
110+
return new_state
111+
112+
113+
def parse_orchestration_state(state: pb.OrchestrationState) -> OrchestrationState:
76114
failure_details = None
77115
if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '':
78116
failure_details = task.FailureDetails(
@@ -81,7 +119,7 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
81119
state.failureDetails.stackTrace.value if not helpers.is_empty(state.failureDetails.stackTrace) else None)
82120

83121
return OrchestrationState(
84-
instance_id,
122+
state.instanceId,
85123
state.name,
86124
OrchestrationStatus(state.orchestrationStatus),
87125
state.createdTimestamp.ToDatetime(),
@@ -93,7 +131,6 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
93131

94132

95133
class TaskHubGrpcClient:
96-
97134
def __init__(self, *,
98135
host_address: Optional[str] = None,
99136
metadata: Optional[list[tuple[str, str]]] = None,
@@ -136,7 +173,7 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
136173
req = pb.CreateInstanceRequest(
137174
name=name,
138175
instanceId=instance_id if instance_id else uuid.uuid4().hex,
139-
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
176+
input=helpers.get_string_value(shared.to_json(input) if input is not None else None),
140177
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
141178
version=helpers.get_string_value(version if version else self.default_version),
142179
orchestrationIdReusePolicy=reuse_id_policy,
@@ -152,6 +189,42 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr
152189
res: pb.GetInstanceResponse = self._stub.GetInstance(req)
153190
return new_orchestration_state(req.instanceId, res)
154191

192+
def get_all_orchestration_states(self,
                                 orchestration_query: Optional[OrchestrationQuery] = None
                                 ) -> List[OrchestrationState]:
    """Query orchestration instances matching the given filter, following
    continuation tokens until the backend reports there are no more pages.

    Args:
        orchestration_query: Filter criteria; None means "match everything"
            (a default-constructed OrchestrationQuery).

    Returns:
        All matching orchestration states, accumulated across pages.
    """
    orchestration_query = orchestration_query if orchestration_query is not None else OrchestrationQuery()

    self._logger.info(f"Querying orchestration instances with query: {orchestration_query}")

    collected: List[OrchestrationState] = []
    token = None

    while True:
        request = pb.QueryInstancesRequest(
            query=pb.InstanceQuery(
                runtimeStatus=[s.value for s in orchestration_query.runtime_status] if orchestration_query.runtime_status else None,
                createdTimeFrom=helpers.new_timestamp(orchestration_query.created_time_from) if orchestration_query.created_time_from else None,
                createdTimeTo=helpers.new_timestamp(orchestration_query.created_time_to) if orchestration_query.created_time_to else None,
                maxInstanceCount=orchestration_query.max_instance_count,
                fetchInputsAndOutputs=orchestration_query.fetch_inputs_and_outputs,
                continuationToken=token
            )
        )
        response: pb.QueryInstancesResponse = self._stub.QueryInstances(request)
        collected.extend(parse_orchestration_state(s) for s in response.orchestrationState)

        # A missing continuation token, or the sentinel value "0", indicates
        # that there are no more results to fetch.
        next_token = response.continuationToken
        if not (next_token and next_token.value and next_token.value != "0"):
            break
        self._logger.info(f"Received continuation token with value {next_token.value}, fetching next list of instances...")
        # Guard against a backend that keeps returning the same token.
        if token and token.value and token.value == next_token.value:
            self._logger.warning(f"Received the same continuation token value {next_token.value} again, stopping to avoid infinite loop.")
            break
        token = next_token

    return collected
227+
155228
def wait_for_orchestration_start(self, instance_id: str, *,
156229
fetch_payloads: bool = False,
157230
timeout: int = 60) -> Optional[OrchestrationState]:
@@ -199,7 +272,8 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *,
199272
req = pb.RaiseEventRequest(
200273
instanceId=instance_id,
201274
name=event_name,
202-
input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None)
275+
input=helpers.get_string_value(shared.to_json(data) if data is not None else None)
276+
)
203277

204278
self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
205279
self._stub.RaiseEvent(req)
@@ -209,7 +283,7 @@ def terminate_orchestration(self, instance_id: str, *,
209283
recursive: bool = True):
210284
req = pb.TerminateRequest(
211285
instanceId=instance_id,
212-
output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None,
286+
output=helpers.get_string_value(shared.to_json(output) if output is not None else None),
213287
recursive=recursive)
214288

215289
self._logger.info(f"Terminating instance '{instance_id}'.")
@@ -225,10 +299,31 @@ def resume_orchestration(self, instance_id: str):
225299
self._logger.info(f"Resuming instance '{instance_id}'.")
226300
self._stub.ResumeInstance(req)
227301

228-
def purge_orchestration(self, instance_id: str, recursive: bool = True):
302+
def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
    """Purge the state of a single orchestration instance.

    Args:
        instance_id: The instance to purge.
        recursive: Also purge sub-orchestrations when True.

    Returns:
        PurgeInstancesResult with the deleted count and completion flag.
    """
    self._logger.info(f"Purging instance '{instance_id}'.")
    response: pb.PurgeInstancesResponse = self._stub.PurgeInstances(
        pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
    )
    return PurgeInstancesResult(response.deletedInstanceCount, response.isComplete.value)
307+
308+
def purge_orchestrations_by(self,
                            created_time_from: Optional[datetime] = None,
                            created_time_to: Optional[datetime] = None,
                            runtime_status: Optional[List[OrchestrationStatus]] = None,
                            recursive: bool = False) -> PurgeInstancesResult:
    """Purge all orchestration instances matching a creation-time/status filter.

    Args:
        created_time_from: Only purge instances created at or after this time.
        created_time_to: Only purge instances created at or before this time.
        runtime_status: Only purge instances in one of these statuses.
        recursive: Also purge sub-orchestrations when True.

    Returns:
        PurgeInstancesResult with the deleted count and completion flag.
    """
    self._logger.info("Purging orchestrations by filter: "
                      f"created_time_from={created_time_from}, "
                      f"created_time_to={created_time_to}, "
                      f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
                      f"recursive={recursive}")
    purge_filter = pb.PurgeInstanceFilter(
        createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
        createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
        runtimeStatus=[s.value for s in runtime_status] if runtime_status else None
    )
    response: pb.PurgeInstancesResponse = self._stub.PurgeInstances(
        pb.PurgeInstancesRequest(purgeInstanceFilter=purge_filter, recursive=recursive)
    )
    return PurgeInstancesResult(response.deletedInstanceCount, response.isComplete.value)
232327

233328
def signal_entity(self,
234329
entity_instance_id: EntityInstanceId,
@@ -237,7 +332,7 @@ def signal_entity(self,
237332
req = pb.SignalEntityRequest(
238333
instanceId=str(entity_instance_id),
239334
name=operation_name,
240-
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None,
335+
input=helpers.get_string_value(shared.to_json(input) if input is not None else None),
241336
requestId=str(uuid.uuid4()),
242337
scheduledTime=None,
243338
parentTraceContext=None,
@@ -256,4 +351,69 @@ def get_entity(self,
256351
if not res.exists:
257352
return None
258353

259-
return EntityMetadata.from_entity_response(res, include_state)
354+
return EntityMetadata.from_entity_metadata(res.entity, include_state)
355+
356+
def get_all_entities(self,
                     entity_query: Optional[EntityQuery] = None) -> List[EntityMetadata]:
    """Query entity instances matching the given filter, following
    continuation tokens until the backend reports there are no more pages.

    Args:
        entity_query: Filter criteria; None means "match everything"
            (a default-constructed EntityQuery).

    Returns:
        Metadata for every matching entity, accumulated across pages.
    """
    entity_query = entity_query if entity_query is not None else EntityQuery()

    self._logger.info(f"Retrieving entities by filter: {entity_query}")

    collected: List[EntityMetadata] = []
    token = None

    while True:
        request = pb.QueryEntitiesRequest(
            query=pb.EntityQuery(
                instanceIdStartsWith=helpers.get_string_value(entity_query.instance_id_starts_with),
                lastModifiedFrom=helpers.new_timestamp(entity_query.last_modified_from) if entity_query.last_modified_from else None,
                lastModifiedTo=helpers.new_timestamp(entity_query.last_modified_to) if entity_query.last_modified_to else None,
                includeState=entity_query.include_state,
                includeTransient=entity_query.include_transient,
                pageSize=helpers.get_int_value(entity_query.page_size),
                continuationToken=token
            )
        )
        response: pb.QueryEntitiesResponse = self._stub.QueryEntities(request)
        collected.extend(
            EntityMetadata.from_entity_metadata(e, request.query.includeState) for e in response.entities
        )

        # A missing continuation token, or the sentinel "0", means no more pages.
        next_token = response.continuationToken
        if not (next_token and next_token.value and next_token.value != "0"):
            break
        self._logger.info(f"Received continuation token with value {next_token.value}, fetching next page of entities...")
        # Guard against a backend that keeps returning the same token.
        if token and token.value and token.value == next_token.value:
            self._logger.warning(f"Received the same continuation token value {next_token.value} again, stopping to avoid infinite loop.")
            break
        token = next_token

    return collected
389+
390+
def clean_entity_storage(self,
                         remove_empty_entities: bool = True,
                         release_orphaned_locks: bool = True
                         ) -> CleanEntityStorageResult:
    """Clean entity storage, paging through the backend until done.

    Args:
        remove_empty_entities: Delete entities that have no state when True.
        release_orphaned_locks: Release locks whose holders are gone when True.

    Returns:
        CleanEntityStorageResult with totals accumulated across all pages.
    """
    self._logger.info("Cleaning entity storage")

    removed_total = 0
    released_total = 0
    token = None

    while True:
        request = pb.CleanEntityStorageRequest(
            removeEmptyEntities=remove_empty_entities,
            releaseOrphanedLocks=release_orphaned_locks,
            continuationToken=token
        )
        response: pb.CleanEntityStorageResponse = self._stub.CleanEntityStorage(request)
        removed_total += response.emptyEntitiesRemoved
        released_total += response.orphanedLocksReleased

        # A missing continuation token, or the sentinel "0", means we're done.
        next_token = response.continuationToken
        if not (next_token and next_token.value and next_token.value != "0"):
            break
        self._logger.info(f"Received continuation token with value {next_token.value}, cleaning next page...")
        # Guard against a backend that keeps returning the same token.
        if token and token.value and token.value == next_token.value:
            self._logger.warning(f"Received the same continuation token value {next_token.value} again, stopping to avoid infinite loop.")
            break
        token = next_token

    return CleanEntityStorageResult(removed_total, released_total)

durabletask/entities/entity_metadata.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,22 @@ def __init__(self,
4444

4545
@staticmethod
def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool):
    """Build an EntityMetadata from a GetEntityResponse by delegating to
    from_entity_metadata with the response's embedded entity."""
    entity = entity_response.entity
    return EntityMetadata.from_entity_metadata(entity, includes_state)
48+
49+
@staticmethod
def from_entity_metadata(entity: pb.EntityMetadata, includes_state: bool):
    """Build an EntityMetadata from a protobuf EntityMetadata message.

    Args:
        entity: The protobuf entity metadata to convert.
        includes_state: When True, copy the entity's serialized state into
            the result; otherwise the state is left as None.

    Raises:
        ValueError: If the entity's instance ID cannot be parsed.
    """
    try:
        entity_id = EntityInstanceId.parse(entity.instanceId)
    except ValueError as e:
        # Chain the original parse error so the offending instance ID
        # detail is preserved in the traceback.
        raise ValueError("Invalid entity instance ID in entity response.") from e
    entity_state = None
    if includes_state:
        entity_state = entity.serializedState.value
    return EntityMetadata(
        id=entity_id,
        last_modified=entity.lastModifiedTime.ToDatetime(timezone.utc),
        backlog_queue_size=entity.backlogQueueSize,
        locked_by=entity.lockedBy.value,
        includes_state=includes_state,
        state=entity_state
    )

durabletask/internal/helpers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,13 @@ def get_string_value(val: Optional[str]) -> Optional[wrappers_pb2.StringValue]:
184184
return wrappers_pb2.StringValue(value=val)
185185

186186

187+
def get_int_value(val: Optional[int]) -> Optional[wrappers_pb2.Int32Value]:
    """Wrap an optional int in a protobuf Int32Value, passing None through."""
    if val is None:
        return None
    return wrappers_pb2.Int32Value(value=val)
192+
193+
187194
def get_string_value_or_empty(val: Optional[str]) -> wrappers_pb2.StringValue:
188195
if val is None:
189196
return wrappers_pb2.StringValue(value="")

0 commit comments

Comments
 (0)