66from dataclasses import dataclass
77from datetime import datetime , timezone
88from enum import Enum
9- from typing import Any , Optional , Sequence , TypeVar , Union
9+ from typing import Any , List , Optional , Sequence , TypeVar , Union
1010
1111import grpc
12- from google .protobuf import wrappers_pb2
1312
1413from durabletask .entities import EntityInstanceId
1514from durabletask .entities .entity_metadata import EntityMetadata
@@ -57,6 +56,39 @@ def raise_if_failed(self):
5756 self .failure_details )
5857
5958
@dataclass
class OrchestrationQuery:
    """Filter criteria for querying orchestration instances.

    Every field is optional; the defaults match all instances. Used by
    ``TaskHubGrpcClient.get_all_orchestration_states``.
    """
    # Only instances created at or after this time are returned.
    created_time_from: Optional[datetime] = None
    # Only instances created at or before this time are returned.
    created_time_to: Optional[datetime] = None
    # Restrict results to these runtime statuses; None means any status.
    runtime_status: Optional[list[OrchestrationStatus]] = None
    # Some backends don't respond well with max_instance_count = None, so we use the integer limit for non-paginated
    # results instead.
    max_instance_count: Optional[int] = (1 << 31) - 1
    # When True, the serialized input/output payloads are included in each result.
    fetch_inputs_and_outputs: bool = False
69+
@dataclass
class EntityQuery:
    """Filter criteria for querying durable entities via ``get_all_entities``."""
    # Only match entities whose instance id starts with this prefix; None matches all.
    instance_id_starts_with: Optional[str] = None
    # Only match entities last modified at or after this time.
    last_modified_from: Optional[datetime] = None
    # Only match entities last modified at or before this time.
    last_modified_to: Optional[datetime] = None
    # When True, each result includes the entity's serialized state.
    include_state: bool = True
    # When True, transient (e.g. empty/locked-only) entities are included as well.
    include_transient: bool = False
    # Server-side page size for the query; None lets the backend choose.
    page_size: Optional[int] = None
79+
@dataclass
class PurgeInstancesResult:
    """Outcome of a purge operation."""
    # Number of orchestration instances whose state was deleted.
    deleted_instance_count: int
    # True when the purge finished; False when more data remains to purge.
    is_complete: bool
85+
@dataclass
class CleanEntityStorageResult:
    """Aggregate outcome of ``clean_entity_storage`` across all pages."""
    # Number of entities with no state that were removed.
    empty_entities_removed: int
    # Number of orphaned entity locks that were released.
    orphaned_locks_released: int
91+
6092class OrchestrationFailedError (Exception ):
6193 def __init__ (self , message : str , failure_details : task .FailureDetails ):
6294 super ().__init__ (message )
@@ -73,6 +105,12 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
73105
74106 state = res .orchestrationState
75107
108+ new_state = parse_orchestration_state (state )
109+ new_state .instance_id = instance_id # Override instance_id with the one from the request, to match old behavior
110+ return new_state
111+
112+
113+ def parse_orchestration_state (state : pb .OrchestrationState ) -> OrchestrationState :
76114 failure_details = None
77115 if state .failureDetails .errorMessage != '' or state .failureDetails .errorType != '' :
78116 failure_details = task .FailureDetails (
@@ -81,7 +119,7 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
81119 state .failureDetails .stackTrace .value if not helpers .is_empty (state .failureDetails .stackTrace ) else None )
82120
83121 return OrchestrationState (
84- instance_id ,
122+ state . instanceId ,
85123 state .name ,
86124 OrchestrationStatus (state .orchestrationStatus ),
87125 state .createdTimestamp .ToDatetime (),
@@ -93,7 +131,6 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
93131
94132
95133class TaskHubGrpcClient :
96-
97134 def __init__ (self , * ,
98135 host_address : Optional [str ] = None ,
99136 metadata : Optional [list [tuple [str , str ]]] = None ,
@@ -136,7 +173,7 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
136173 req = pb .CreateInstanceRequest (
137174 name = name ,
138175 instanceId = instance_id if instance_id else uuid .uuid4 ().hex ,
139- input = wrappers_pb2 . StringValue ( value = shared .to_json (input )) if input is not None else None ,
176+ input = helpers . get_string_value ( shared .to_json (input ) if input is not None else None ) ,
140177 scheduledStartTimestamp = helpers .new_timestamp (start_at ) if start_at else None ,
141178 version = helpers .get_string_value (version if version else self .default_version ),
142179 orchestrationIdReusePolicy = reuse_id_policy ,
@@ -152,6 +189,42 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr
152189 res : pb .GetInstanceResponse = self ._stub .GetInstance (req )
153190 return new_orchestration_state (req .instanceId , res )
154191
192+ def get_all_orchestration_states (self ,
193+ orchestration_query : Optional [OrchestrationQuery ] = None
194+ ) -> List [OrchestrationState ]:
195+ if orchestration_query is None :
196+ orchestration_query = OrchestrationQuery ()
197+ _continuation_token = None
198+
199+ self ._logger .info (f"Querying orchestration instances with query: { orchestration_query } " )
200+
201+ states = []
202+
203+ while True :
204+ req = pb .QueryInstancesRequest (
205+ query = pb .InstanceQuery (
206+ runtimeStatus = [status .value for status in orchestration_query .runtime_status ] if orchestration_query .runtime_status else None ,
207+ createdTimeFrom = helpers .new_timestamp (orchestration_query .created_time_from ) if orchestration_query .created_time_from else None ,
208+ createdTimeTo = helpers .new_timestamp (orchestration_query .created_time_to ) if orchestration_query .created_time_to else None ,
209+ maxInstanceCount = orchestration_query .max_instance_count ,
210+ fetchInputsAndOutputs = orchestration_query .fetch_inputs_and_outputs ,
211+ continuationToken = _continuation_token
212+ )
213+ )
214+ resp : pb .QueryInstancesResponse = self ._stub .QueryInstances (req )
215+ states += [parse_orchestration_state (res ) for res in resp .orchestrationState ]
216+ # Check the value for continuationToken - none or "0" indicates that there are no more results.
217+ if resp .continuationToken and resp .continuationToken .value and resp .continuationToken .value != "0" :
218+ self ._logger .info (f"Received continuation token with value { resp .continuationToken .value } , fetching next list of instances..." )
219+ if _continuation_token and _continuation_token .value and _continuation_token .value == resp .continuationToken .value :
220+ self ._logger .warning (f"Received the same continuation token value { resp .continuationToken .value } again, stopping to avoid infinite loop." )
221+ break
222+ _continuation_token = resp .continuationToken
223+ else :
224+ break
225+
226+ return states
227+
155228 def wait_for_orchestration_start (self , instance_id : str , * ,
156229 fetch_payloads : bool = False ,
157230 timeout : int = 60 ) -> Optional [OrchestrationState ]:
@@ -199,7 +272,8 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *,
199272 req = pb .RaiseEventRequest (
200273 instanceId = instance_id ,
201274 name = event_name ,
202- input = wrappers_pb2 .StringValue (value = shared .to_json (data )) if data else None )
275+ input = helpers .get_string_value (shared .to_json (data ) if data is not None else None )
276+ )
203277
204278 self ._logger .info (f"Raising event '{ event_name } ' for instance '{ instance_id } '." )
205279 self ._stub .RaiseEvent (req )
@@ -209,7 +283,7 @@ def terminate_orchestration(self, instance_id: str, *,
209283 recursive : bool = True ):
210284 req = pb .TerminateRequest (
211285 instanceId = instance_id ,
212- output = wrappers_pb2 . StringValue ( value = shared .to_json (output )) if output else None ,
286+ output = helpers . get_string_value ( shared .to_json (output ) if output is not None else None ) ,
213287 recursive = recursive )
214288
215289 self ._logger .info (f"Terminating instance '{ instance_id } '." )
@@ -225,10 +299,31 @@ def resume_orchestration(self, instance_id: str):
225299 self ._logger .info (f"Resuming instance '{ instance_id } '." )
226300 self ._stub .ResumeInstance (req )
227301
228- def purge_orchestration (self , instance_id : str , recursive : bool = True ):
302+ def purge_orchestration (self , instance_id : str , recursive : bool = True ) -> PurgeInstancesResult :
229303 req = pb .PurgeInstancesRequest (instanceId = instance_id , recursive = recursive )
230304 self ._logger .info (f"Purging instance '{ instance_id } '." )
231- self ._stub .PurgeInstances (req )
305+ resp : pb .PurgeInstancesResponse = self ._stub .PurgeInstances (req )
306+ return PurgeInstancesResult (resp .deletedInstanceCount , resp .isComplete .value )
307+
308+ def purge_orchestrations_by (self ,
309+ created_time_from : Optional [datetime ] = None ,
310+ created_time_to : Optional [datetime ] = None ,
311+ runtime_status : Optional [List [OrchestrationStatus ]] = None ,
312+ recursive : bool = False ) -> PurgeInstancesResult :
313+ self ._logger .info ("Purging orchestrations by filter: "
314+ f"created_time_from={ created_time_from } , "
315+ f"created_time_to={ created_time_to } , "
316+ f"runtime_status={ [str (status ) for status in runtime_status ] if runtime_status else None } , "
317+ f"recursive={ recursive } " )
318+ resp : pb .PurgeInstancesResponse = self ._stub .PurgeInstances (pb .PurgeInstancesRequest (
319+ purgeInstanceFilter = pb .PurgeInstanceFilter (
320+ createdTimeFrom = helpers .new_timestamp (created_time_from ) if created_time_from else None ,
321+ createdTimeTo = helpers .new_timestamp (created_time_to ) if created_time_to else None ,
322+ runtimeStatus = [status .value for status in runtime_status ] if runtime_status else None
323+ ),
324+ recursive = recursive
325+ ))
326+ return PurgeInstancesResult (resp .deletedInstanceCount , resp .isComplete .value )
232327
233328 def signal_entity (self ,
234329 entity_instance_id : EntityInstanceId ,
@@ -237,7 +332,7 @@ def signal_entity(self,
237332 req = pb .SignalEntityRequest (
238333 instanceId = str (entity_instance_id ),
239334 name = operation_name ,
240- input = wrappers_pb2 . StringValue ( value = shared .to_json (input )) if input else None ,
335+ input = helpers . get_string_value ( shared .to_json (input ) if input is not None else None ) ,
241336 requestId = str (uuid .uuid4 ()),
242337 scheduledTime = None ,
243338 parentTraceContext = None ,
@@ -256,4 +351,69 @@ def get_entity(self,
256351 if not res .exists :
257352 return None
258353
259- return EntityMetadata .from_entity_response (res , include_state )
354+ return EntityMetadata .from_entity_metadata (res .entity , include_state )
355+
356+ def get_all_entities (self ,
357+ entity_query : Optional [EntityQuery ] = None ) -> List [EntityMetadata ]:
358+ if entity_query is None :
359+ entity_query = EntityQuery ()
360+ _continuation_token = None
361+
362+ self ._logger .info (f"Retrieving entities by filter: { entity_query } " )
363+
364+ entities = []
365+
366+ while True :
367+ query_request = pb .QueryEntitiesRequest (
368+ query = pb .EntityQuery (
369+ instanceIdStartsWith = helpers .get_string_value (entity_query .instance_id_starts_with ),
370+ lastModifiedFrom = helpers .new_timestamp (entity_query .last_modified_from ) if entity_query .last_modified_from else None ,
371+ lastModifiedTo = helpers .new_timestamp (entity_query .last_modified_to ) if entity_query .last_modified_to else None ,
372+ includeState = entity_query .include_state ,
373+ includeTransient = entity_query .include_transient ,
374+ pageSize = helpers .get_int_value (entity_query .page_size ),
375+ continuationToken = _continuation_token
376+ )
377+ )
378+ resp : pb .QueryEntitiesResponse = self ._stub .QueryEntities (query_request )
379+ entities += [EntityMetadata .from_entity_metadata (entity , query_request .query .includeState ) for entity in resp .entities ]
380+ if resp .continuationToken and resp .continuationToken .value and resp .continuationToken .value != "0" :
381+ self ._logger .info (f"Received continuation token with value { resp .continuationToken .value } , fetching next page of entities..." )
382+ if _continuation_token and _continuation_token .value and _continuation_token .value == resp .continuationToken .value :
383+ self ._logger .warning (f"Received the same continuation token value { resp .continuationToken .value } again, stopping to avoid infinite loop." )
384+ break
385+ _continuation_token = resp .continuationToken
386+ else :
387+ break
388+ return entities
389+
390+ def clean_entity_storage (self ,
391+ remove_empty_entities : bool = True ,
392+ release_orphaned_locks : bool = True
393+ ) -> CleanEntityStorageResult :
394+ self ._logger .info ("Cleaning entity storage" )
395+
396+ empty_entities_removed = 0
397+ orphaned_locks_released = 0
398+ _continuation_token = None
399+
400+ while True :
401+ req = pb .CleanEntityStorageRequest (
402+ removeEmptyEntities = remove_empty_entities ,
403+ releaseOrphanedLocks = release_orphaned_locks ,
404+ continuationToken = _continuation_token
405+ )
406+ resp : pb .CleanEntityStorageResponse = self ._stub .CleanEntityStorage (req )
407+ empty_entities_removed += resp .emptyEntitiesRemoved
408+ orphaned_locks_released += resp .orphanedLocksReleased
409+
410+ if resp .continuationToken and resp .continuationToken .value and resp .continuationToken .value != "0" :
411+ self ._logger .info (f"Received continuation token with value { resp .continuationToken .value } , cleaning next page..." )
412+ if _continuation_token and _continuation_token .value and _continuation_token .value == resp .continuationToken .value :
413+ self ._logger .warning (f"Received the same continuation token value { resp .continuationToken .value } again, stopping to avoid infinite loop." )
414+ break
415+ _continuation_token = resp .continuationToken
416+ else :
417+ break
418+
419+ return CleanEntityStorageResult (empty_entities_removed , orphaned_locks_released )
0 commit comments