Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/aws_durable_execution_sdk_python/concurrency/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,11 +415,13 @@ def _execute_item_in_child_context(
and execution-order invariant.
"""

operation_id: str = executor_context._create_step_id_for_logical_step( # noqa: SLF001
executable.index
is_virtual: bool = self.nesting_type is NestingType.FLAT
operation_id: str = (
executor_context._operation_id_generator.create_step_id_for_logical_step( # noqa: SLF001
executable.index, is_virtual=is_virtual
)
)
name: str = self.get_iteration_name(executable.index)
is_virtual: bool = self.nesting_type is NestingType.FLAT

child_context: DurableContext = executor_context.create_child_context(
operation_id, is_virtual=is_virtual
Expand Down Expand Up @@ -447,7 +449,6 @@ def run_in_child_handler() -> ResultType:
is_virtual=is_virtual,
),
)
child_context.state.track_replay(operation_id=operation_id)
return result

def replay(self, execution_state: ExecutionState, executor_context: DurableContext):
Expand All @@ -458,10 +459,11 @@ def replay(self, execution_state: ExecutionState, executor_context: DurableConte
This will pre-generate all the operation ids for the children and collect the checkpointed
results.
"""
is_virtual: bool = self.nesting_type is NestingType.FLAT
items: list[BatchItem[ResultType]] = []
for executable in self.executables:
operation_id = executor_context._create_step_id_for_logical_step( # noqa: SLF001
executable.index
operation_id = executor_context._operation_id_generator.create_step_id_for_logical_step( # noqa: SLF001
executable.index, is_virtual=is_virtual
)
checkpoint = execution_state.get_checkpoint_result(operation_id)

Expand Down
131 changes: 84 additions & 47 deletions src/aws_durable_execution_sdk_python/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
SerDes,
deserialize,
)
from aws_durable_execution_sdk_python.state import ExecutionState # noqa: TCH001
from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus # noqa: TCH001
from aws_durable_execution_sdk_python.threading import OrderedCounter
from aws_durable_execution_sdk_python.types import Callback as CallbackProtocol
from aws_durable_execution_sdk_python.types import (
Expand Down Expand Up @@ -277,6 +277,42 @@ def result(self) -> T | None:
raise SuspendExecution(msg)


class OperationIdGenerator:
def __init__(self, step_id_prefix: str | None, parent_id: str | None) -> None:
self._operation_counter: OrderedCounter = OrderedCounter()
self._virtual_operation_counter: OrderedCounter = OrderedCounter()
# child operations use this to generate deterministic step ids.
# differs from `parent_id` only for virtual contexts.
self._step_id_prefix: str | None = (
step_id_prefix if step_id_prefix is not None else parent_id
)

def peek_next_step_id(self):
next_step = self._operation_counter.get_current() + 1
return self.create_step_id_for_logical_step(next_step, is_virtual=False)

def create_step_id(self, is_virtual: bool = False) -> str:
"""Generate a thread-safe step id, incrementing in order of invocation.

This method is an internal implementation detail. Do not rely the exact format of
the id generated by this method. It is subject to change without notice.
"""
new_counter: int = (
self._virtual_operation_counter if is_virtual else self._operation_counter
).increment()
return self.create_step_id_for_logical_step(new_counter, is_virtual=is_virtual)

def create_step_id_for_logical_step(self, step: int, is_virtual: bool) -> str:
"""
Generate a step_id based on the given logical step.
This allows us to recover operation ids or even look
forward without changing the internal state of this context.
"""
parts = [self._step_id_prefix, "v" if is_virtual else None, step]
step_id: str = "-".join([str(part) for part in parts if part])
return hashlib.blake2b(step_id.encode()).hexdigest()[:64]


class DurableContext(DurableContextProtocol):
def __init__(
self,
Expand All @@ -286,29 +322,30 @@ def __init__(
parent_id: str | None = None,
logger: Logger | None = None,
step_id_prefix: str | None = None,
replay_status: ReplayStatus = ReplayStatus.REPLAY,
) -> None:
self.state: ExecutionState = state
self.execution_context: ExecutionContext = execution_context
self.lambda_context = lambda_context
# operations inside this context use this id as their parent
self._parent_id: str | None = parent_id
# child operations use this to generate deterministic step ids.
# differs from `parent_id` only for virtual contexts.
self._step_id_prefix: str | None = (
step_id_prefix if step_id_prefix is not None else parent_id
self._is_virtual: bool = (
step_id_prefix is not None and parent_id != step_id_prefix
)
# cached at construction to make invariant even if parent/prefix mutates.
self._is_virtual: bool = self._parent_id != self._step_id_prefix
self._step_counter: OrderedCounter = OrderedCounter()
self._operation_id_generator: OperationIdGenerator = OperationIdGenerator(
step_id_prefix, parent_id
)
self._replay_status: ReplayStatus = replay_status
self._track_replay()

log_info = LogInfo(
execution_state=state,
parent_id=parent_id,
)
self._log_info = log_info
self.logger: Logger = logger or Logger.from_log_info(
logger=logging.getLogger(),
info=log_info,
context=self,
)

@property
Expand All @@ -323,6 +360,11 @@ def is_virtual(self) -> bool:
"""
return self._is_virtual

@property
def is_replaying(self) -> bool:
"""True if this context is in replay mode"""
return self._replay_status is ReplayStatus.REPLAY

# region factories
@staticmethod
def from_lambda_context(
Expand Down Expand Up @@ -371,9 +413,9 @@ def create_child_context(
lambda_context=self.lambda_context,
parent_id=child_parent_id,
step_id_prefix=operation_id,
replay_status=self._replay_status,
logger=self.logger.with_log_info(
LogInfo(
execution_state=self.state,
parent_id=child_parent_id,
)
),
Expand All @@ -396,26 +438,20 @@ def set_logger(self, new_logger: LoggerInterface):
self.logger = Logger.from_log_info(
logger=new_logger,
info=self._log_info,
context=self,
)

def _create_step_id_for_logical_step(self, step: int) -> str:
"""
Generate a step_id based on the given logical step.
This allows us to recover operation ids or even look
forward without changing the internal state of this context.
"""
prefix: str | None = self._step_id_prefix
step_id: str = f"{prefix}-{step}" if prefix else str(step)
return hashlib.blake2b(step_id.encode()).hexdigest()[:64]

def _create_step_id(self) -> str:
"""Generate a thread-safe step id, incrementing in order of invocation.

This method is an internal implementation detail. Do not rely the exact format of
the id generated by this method. It is subject to change without notice.
"""
new_counter: int = self._step_counter.increment()
return self._create_step_id_for_logical_step(new_counter)
def _track_replay(self) -> None:
"""Transition replay status to NEW if the next operation has not been checkpointed"""
if self._replay_status is ReplayStatus.NEW:
return
# check if next operation exists
next_step_id = self._operation_id_generator.peek_next_step_id()
if not self.state.get_checkpoint_result(next_step_id).is_existent():
Comment thread
zhongkechen marked this conversation as resolved.
# update the context replay status to NEW
self._replay_status = ReplayStatus.NEW
# update the execution replay status to NEW
self.state.transition_replay_status()
Comment thread
zhongkechen marked this conversation as resolved.

# region Operations

Expand All @@ -438,7 +474,7 @@ def create_callback(
"""
if not config:
config = CallbackConfig()
operation_id: str = self._create_step_id()
operation_id: str = self._operation_id_generator.create_step_id()
executor: CallbackOperationExecutor = CallbackOperationExecutor(
state=self.state,
operation_identifier=OperationIdentifier(
Expand All @@ -448,14 +484,14 @@ def create_callback(
),
config=config,
)
self._track_replay()
callback_id: str = executor.process()
result: Callback = Callback(
callback_id=callback_id,
operation_id=operation_id,
state=self.state,
serdes=config.serdes,
)
self.state.track_replay(operation_id=operation_id)
return result

def invoke(
Expand All @@ -478,7 +514,7 @@ def invoke(
"""
if not config:
config = InvokeConfig[P, R]()
operation_id = self._create_step_id()
operation_id = self._operation_id_generator.create_step_id()
executor: InvokeOperationExecutor[R] = InvokeOperationExecutor(
function_name=function_name,
payload=payload,
Expand All @@ -490,8 +526,8 @@ def invoke(
),
config=config,
)
self._track_replay()
result: R = executor.process()
self.state.track_replay(operation_id=operation_id)
return result

def map(
Expand All @@ -504,7 +540,7 @@ def map(
"""Execute a callable for each item in parallel."""
map_name: str | None = self._resolve_step_name(name, func)

operation_id = self._create_step_id()
operation_id = self._operation_id_generator.create_step_id()
operation_identifier = OperationIdentifier(
operation_id=operation_id,
parent_id=self._parent_id,
Expand All @@ -526,6 +562,7 @@ def map_in_child_context() -> BatchResult[R]:
operation_identifier=operation_identifier,
)

self._track_replay()
result: BatchResult[R] = child_handler(
func=map_in_child_context,
state=self.state,
Expand All @@ -539,7 +576,6 @@ def map_in_child_context() -> BatchResult[R]:
item_serdes=None,
),
)
self.state.track_replay(operation_id=operation_id)
return result

def parallel(
Expand All @@ -550,7 +586,7 @@ def parallel(
) -> BatchResult[T]:
"""Execute multiple callables in parallel."""
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
operation_id = self._create_step_id()
operation_id = self._operation_id_generator.create_step_id()
parallel_context = self.create_child_context(operation_id=operation_id)
operation_identifier = OperationIdentifier(
operation_id=operation_id, parent_id=self._parent_id, name=name
Expand All @@ -569,6 +605,7 @@ def parallel_in_child_context() -> BatchResult[T]:
operation_identifier=operation_identifier,
)

self._track_replay()
result: BatchResult[T] = child_handler(
func=parallel_in_child_context,
state=self.state,
Expand All @@ -582,7 +619,6 @@ def parallel_in_child_context() -> BatchResult[T]:
item_serdes=None,
),
)
self.state.track_replay(operation_id=operation_id)
return result

def run_in_child_context(
Expand All @@ -596,18 +632,19 @@ def run_in_child_context(
Use this to nest and group operations.

Args:
callable (Callable[[DurableContext], T]): Run this callable and pass the child context as the argument to it.
func (Callable[[DurableContext], T]): Run this callable and pass the child context as the argument to it.
name (str | None): name for the operation.
config (ChildConfig | None = None): c

Returns:
T: The result of the callable.
"""
step_name: str | None = self._resolve_step_name(name, func)
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
operation_id = self._create_step_id()

is_virtual: bool = config.is_virtual if config else False
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
operation_id = self._operation_id_generator.create_step_id(
is_virtual=is_virtual
)

def callable_with_child_context():
return func(
Expand All @@ -616,6 +653,7 @@ def callable_with_child_context():
)
)

self._track_replay()
result: T = child_handler(
func=callable_with_child_context,
state=self.state,
Expand All @@ -626,7 +664,6 @@ def callable_with_child_context():
),
config=config,
)
self.state.track_replay(operation_id=operation_id)
return result

def step(
Expand All @@ -639,7 +676,7 @@ def step(
logger.debug("Step name: %s", step_name)
if not config:
config = StepConfig()
operation_id = self._create_step_id()
operation_id = self._operation_id_generator.create_step_id()
executor: StepOperationExecutor[T] = StepOperationExecutor(
func=func,
config=config,
Expand All @@ -651,8 +688,8 @@ def step(
),
context_logger=self.logger,
)
self._track_replay()
result: T = executor.process()
self.state.track_replay(operation_id=operation_id)
return result

def wait(self, duration: Duration, name: str | None = None) -> None:
Expand All @@ -666,7 +703,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
if seconds < 1:
msg = "duration must be at least 1 second"
raise ValidationError(msg)
operation_id = self._create_step_id()
operation_id = self._operation_id_generator.create_step_id()
wait_seconds = duration.seconds
executor: WaitOperationExecutor = WaitOperationExecutor(
seconds=wait_seconds,
Expand All @@ -677,8 +714,8 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
name=name,
),
)
self._track_replay()
executor.process()
self.state.track_replay(operation_id=operation_id)

def wait_for_callback(
self,
Expand Down Expand Up @@ -720,7 +757,7 @@ def wait_for_condition(
msg = "`config` is required for wait_for_condition"
raise ValidationError(msg)

operation_id = self._create_step_id()
operation_id = self._operation_id_generator.create_step_id()
executor: WaitForConditionOperationExecutor[T] = (
WaitForConditionOperationExecutor(
check=check,
Expand All @@ -734,8 +771,8 @@ def wait_for_condition(
context_logger=self.logger,
)
)
self._track_replay()
result: T = executor.process()
self.state.track_replay(operation_id=operation_id)
return result


Expand Down
Loading