diff --git a/src/aws_durable_execution_sdk_python/concurrency/executor.py b/src/aws_durable_execution_sdk_python/concurrency/executor.py index 3a7ab136..1ebb3cc3 100644 --- a/src/aws_durable_execution_sdk_python/concurrency/executor.py +++ b/src/aws_durable_execution_sdk_python/concurrency/executor.py @@ -415,11 +415,13 @@ def _execute_item_in_child_context( and execution-order invariant. """ - operation_id: str = executor_context._create_step_id_for_logical_step( # noqa: SLF001 - executable.index + is_virtual: bool = self.nesting_type is NestingType.FLAT + operation_id: str = ( + executor_context._operation_id_generator.create_step_id_for_logical_step( # noqa: SLF001 + executable.index, is_virtual=is_virtual + ) ) name: str = self.get_iteration_name(executable.index) - is_virtual: bool = self.nesting_type is NestingType.FLAT child_context: DurableContext = executor_context.create_child_context( operation_id, is_virtual=is_virtual @@ -447,7 +449,6 @@ def run_in_child_handler() -> ResultType: is_virtual=is_virtual, ), ) - child_context.state.track_replay(operation_id=operation_id) return result def replay(self, execution_state: ExecutionState, executor_context: DurableContext): @@ -458,10 +459,11 @@ def replay(self, execution_state: ExecutionState, executor_context: DurableConte This will pre-generate all the operation ids for the children and collect the checkpointed results. 
""" + is_virtual: bool = self.nesting_type is NestingType.FLAT items: list[BatchItem[ResultType]] = [] for executable in self.executables: - operation_id = executor_context._create_step_id_for_logical_step( # noqa: SLF001 - executable.index + operation_id = executor_context._operation_id_generator.create_step_id_for_logical_step( # noqa: SLF001 + executable.index, is_virtual=is_virtual ) checkpoint = execution_state.get_checkpoint_result(operation_id) diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index 6691f2ab..8faed408 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -43,7 +43,7 @@ SerDes, deserialize, ) -from aws_durable_execution_sdk_python.state import ExecutionState # noqa: TCH001 +from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus # noqa: TCH001 from aws_durable_execution_sdk_python.threading import OrderedCounter from aws_durable_execution_sdk_python.types import Callback as CallbackProtocol from aws_durable_execution_sdk_python.types import ( @@ -277,6 +277,42 @@ def result(self) -> T | None: raise SuspendExecution(msg) +class OperationIdGenerator: + def __init__(self, step_id_prefix: str | None, parent_id: str | None) -> None: + self._operation_counter: OrderedCounter = OrderedCounter() + self._virtual_operation_counter: OrderedCounter = OrderedCounter() + # child operations use this to generate deterministic step ids. + # differs from `parent_id` only for virtual contexts. + self._step_id_prefix: str | None = ( + step_id_prefix if step_id_prefix is not None else parent_id + ) + + def peek_next_step_id(self): + next_step = self._operation_counter.get_current() + 1 + return self.create_step_id_for_logical_step(next_step, is_virtual=False) + + def create_step_id(self, is_virtual: bool = False) -> str: + """Generate a thread-safe step id, incrementing in order of invocation. 
+ + This method is an internal implementation detail. Do not rely on the exact format of + the id generated by this method. It is subject to change without notice. + """ + new_counter: int = ( + self._virtual_operation_counter if is_virtual else self._operation_counter + ).increment() + return self.create_step_id_for_logical_step(new_counter, is_virtual=is_virtual) + + def create_step_id_for_logical_step(self, step: int, is_virtual: bool) -> str: + """ + Generate a step_id based on the given logical step (step 0 is valid and kept). + This allows us to recover operation ids or even look + forward without changing the internal state of this context. + """ + parts = [self._step_id_prefix or None, "v" if is_virtual else None, step] + step_id: str = "-".join([str(part) for part in parts if part is not None]) + return hashlib.blake2b(step_id.encode()).hexdigest()[:64] + + class DurableContext(DurableContextProtocol): def __init__( self, @@ -286,29 +322,30 @@ def __init__( parent_id: str | None = None, logger: Logger | None = None, step_id_prefix: str | None = None, + replay_status: ReplayStatus = ReplayStatus.REPLAY, ) -> None: self.state: ExecutionState = state self.execution_context: ExecutionContext = execution_context self.lambda_context = lambda_context # operations inside this context use this id as their parent self._parent_id: str | None = parent_id - # child operations use this to generate deterministic step ids. - # differs from `parent_id` only for virtual contexts. - self._step_id_prefix: str | None = ( - step_id_prefix if step_id_prefix is not None else parent_id + self._is_virtual: bool = ( + step_id_prefix is not None and parent_id != step_id_prefix ) - # cached at construction to make invariant even if parent/prefix mutates. 
- self._is_virtual: bool = self._parent_id != self._step_id_prefix - self._step_counter: OrderedCounter = OrderedCounter() + self._operation_id_generator: OperationIdGenerator = OperationIdGenerator( + step_id_prefix, parent_id + ) + self._replay_status: ReplayStatus = replay_status + self._track_replay() log_info = LogInfo( - execution_state=state, parent_id=parent_id, ) self._log_info = log_info self.logger: Logger = logger or Logger.from_log_info( logger=logging.getLogger(), info=log_info, + context=self, ) @property @@ -323,6 +360,11 @@ def is_virtual(self) -> bool: """ return self._is_virtual + @property + def is_replaying(self) -> bool: + """True if this context is in replay mode""" + return self._replay_status is ReplayStatus.REPLAY + # region factories @staticmethod def from_lambda_context( @@ -371,9 +413,9 @@ def create_child_context( lambda_context=self.lambda_context, parent_id=child_parent_id, step_id_prefix=operation_id, + replay_status=self._replay_status, logger=self.logger.with_log_info( LogInfo( - execution_state=self.state, parent_id=child_parent_id, ) ), @@ -396,26 +438,20 @@ def set_logger(self, new_logger: LoggerInterface): self.logger = Logger.from_log_info( logger=new_logger, info=self._log_info, + context=self, ) - def _create_step_id_for_logical_step(self, step: int) -> str: - """ - Generate a step_id based on the given logical step. - This allows us to recover operation ids or even look - forward without changing the internal state of this context. - """ - prefix: str | None = self._step_id_prefix - step_id: str = f"{prefix}-{step}" if prefix else str(step) - return hashlib.blake2b(step_id.encode()).hexdigest()[:64] - - def _create_step_id(self) -> str: - """Generate a thread-safe step id, incrementing in order of invocation. - - This method is an internal implementation detail. Do not rely the exact format of - the id generated by this method. It is subject to change without notice. 
- """ - new_counter: int = self._step_counter.increment() - return self._create_step_id_for_logical_step(new_counter) + def _track_replay(self) -> None: + """Transition replay status to NEW if the next operation has not been checkpointed""" + if self._replay_status is ReplayStatus.NEW: + return + # check if next operation exists + next_step_id = self._operation_id_generator.peek_next_step_id() + if not self.state.get_checkpoint_result(next_step_id).is_existent(): + # update the context replay status to NEW + self._replay_status = ReplayStatus.NEW + # update the execution replay status to NEW + self.state.transition_replay_status() # region Operations @@ -438,7 +474,7 @@ def create_callback( """ if not config: config = CallbackConfig() - operation_id: str = self._create_step_id() + operation_id: str = self._operation_id_generator.create_step_id() executor: CallbackOperationExecutor = CallbackOperationExecutor( state=self.state, operation_identifier=OperationIdentifier( @@ -448,6 +484,7 @@ def create_callback( ), config=config, ) + self._track_replay() callback_id: str = executor.process() result: Callback = Callback( callback_id=callback_id, @@ -455,7 +492,6 @@ def create_callback( state=self.state, serdes=config.serdes, ) - self.state.track_replay(operation_id=operation_id) return result def invoke( @@ -478,7 +514,7 @@ def invoke( """ if not config: config = InvokeConfig[P, R]() - operation_id = self._create_step_id() + operation_id = self._operation_id_generator.create_step_id() executor: InvokeOperationExecutor[R] = InvokeOperationExecutor( function_name=function_name, payload=payload, @@ -490,8 +526,8 @@ def invoke( ), config=config, ) + self._track_replay() result: R = executor.process() - self.state.track_replay(operation_id=operation_id) return result def map( @@ -504,7 +540,7 @@ def map( """Execute a callable for each item in parallel.""" map_name: str | None = self._resolve_step_name(name, func) - operation_id = self._create_step_id() + operation_id = 
self._operation_id_generator.create_step_id() operation_identifier = OperationIdentifier( operation_id=operation_id, parent_id=self._parent_id, @@ -526,6 +562,7 @@ def map_in_child_context() -> BatchResult[R]: operation_identifier=operation_identifier, ) + self._track_replay() result: BatchResult[R] = child_handler( func=map_in_child_context, state=self.state, @@ -539,7 +576,6 @@ def map_in_child_context() -> BatchResult[R]: item_serdes=None, ), ) - self.state.track_replay(operation_id=operation_id) return result def parallel( @@ -550,7 +586,7 @@ def parallel( ) -> BatchResult[T]: """Execute multiple callables in parallel.""" # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id - operation_id = self._create_step_id() + operation_id = self._operation_id_generator.create_step_id() parallel_context = self.create_child_context(operation_id=operation_id) operation_identifier = OperationIdentifier( operation_id=operation_id, parent_id=self._parent_id, name=name @@ -569,6 +605,7 @@ def parallel_in_child_context() -> BatchResult[T]: operation_identifier=operation_identifier, ) + self._track_replay() result: BatchResult[T] = child_handler( func=parallel_in_child_context, state=self.state, @@ -582,7 +619,6 @@ def parallel_in_child_context() -> BatchResult[T]: item_serdes=None, ), ) - self.state.track_replay(operation_id=operation_id) return result def run_in_child_context( @@ -596,7 +632,7 @@ def run_in_child_context( Use this to nest and group operations. Args: - callable (Callable[[DurableContext], T]): Run this callable and pass the child context as the argument to it. + func (Callable[[DurableContext], T]): Run this callable and pass the child context as the argument to it. name (str | None): name for the operation. config (ChildConfig | None = None): c @@ -604,10 +640,11 @@ def run_in_child_context( T: The result of the callable. 
""" step_name: str | None = self._resolve_step_name(name, func) - # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id - operation_id = self._create_step_id() - is_virtual: bool = config.is_virtual if config else False + # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id + operation_id = self._operation_id_generator.create_step_id( + is_virtual=is_virtual + ) def callable_with_child_context(): return func( @@ -616,6 +653,7 @@ def callable_with_child_context(): ) ) + self._track_replay() result: T = child_handler( func=callable_with_child_context, state=self.state, @@ -626,7 +664,6 @@ def callable_with_child_context(): ), config=config, ) - self.state.track_replay(operation_id=operation_id) return result def step( @@ -639,7 +676,7 @@ def step( logger.debug("Step name: %s", step_name) if not config: config = StepConfig() - operation_id = self._create_step_id() + operation_id = self._operation_id_generator.create_step_id() executor: StepOperationExecutor[T] = StepOperationExecutor( func=func, config=config, @@ -651,8 +688,8 @@ def step( ), context_logger=self.logger, ) + self._track_replay() result: T = executor.process() - self.state.track_replay(operation_id=operation_id) return result def wait(self, duration: Duration, name: str | None = None) -> None: @@ -666,7 +703,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None: if seconds < 1: msg = "duration must be at least 1 second" raise ValidationError(msg) - operation_id = self._create_step_id() + operation_id = self._operation_id_generator.create_step_id() wait_seconds = duration.seconds executor: WaitOperationExecutor = WaitOperationExecutor( seconds=wait_seconds, @@ -677,8 +714,8 @@ def wait(self, duration: Duration, name: str | None = None) -> None: name=name, ), ) + self._track_replay() executor.process() - self.state.track_replay(operation_id=operation_id) def wait_for_callback( self, @@ -720,7 +757,7 @@ 
def wait_for_condition( msg = "`config` is required for wait_for_condition" raise ValidationError(msg) - operation_id = self._create_step_id() + operation_id = self._operation_id_generator.create_step_id() executor: WaitForConditionOperationExecutor[T] = ( WaitForConditionOperationExecutor( check=check, @@ -734,8 +771,8 @@ def wait_for_condition( context_logger=self.logger, ) ) + self._track_replay() result: T = executor.process() - self.state.track_replay(operation_id=operation_id) return result diff --git a/src/aws_durable_execution_sdk_python/logger.py b/src/aws_durable_execution_sdk_python/logger.py index c2a2be71..e02359db 100644 --- a/src/aws_durable_execution_sdk_python/logger.py +++ b/src/aws_durable_execution_sdk_python/logger.py @@ -10,13 +10,12 @@ if TYPE_CHECKING: from collections.abc import Callable, Mapping, MutableMapping - from aws_durable_execution_sdk_python.context import ExecutionState + from aws_durable_execution_sdk_python import DurableContext from aws_durable_execution_sdk_python.identifier import OperationIdentifier @dataclass(frozen=True) class LogInfo: - execution_state: ExecutionState parent_id: str | None = None operation_id: str | None = None name: str | None = None @@ -25,13 +24,11 @@ class LogInfo: @classmethod def from_operation_identifier( cls, - execution_state: ExecutionState, op_id: OperationIdentifier, attempt: int | None = None, ) -> LogInfo: """Create new log info from an execution arn, OperationIdentifier and attempt.""" return cls( - execution_state=execution_state, parent_id=op_id.parent_id, operation_id=op_id.operation_id, name=op_id.name, @@ -41,7 +38,6 @@ def from_operation_identifier( def with_parent_id(self, parent_id: str) -> LogInfo: """Clone the log info with a new parent id.""" return LogInfo( - execution_state=self.execution_state, parent_id=parent_id, operation_id=self.operation_id, name=self.name, @@ -54,17 +50,19 @@ def __init__( self, logger: LoggerInterface, default_extra: Mapping[str, object], - 
execution_state: ExecutionState, + context: DurableContext, ) -> None: self._logger = logger self._default_extra = default_extra - self._execution_state = execution_state + self._context = context @classmethod - def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger: + def from_log_info( + cls, logger: LoggerInterface, info: LogInfo, context: DurableContext + ) -> Logger: """Create a new logger with the given LogInfo.""" extra: MutableMapping[str, object] = { - "executionArn": info.execution_state.durable_execution_arn + "executionArn": context.state.durable_execution_arn } if info.parent_id: extra["parentId"] = info.parent_id @@ -75,15 +73,14 @@ def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger: extra["attempt"] = info.attempt if info.operation_id: extra["operationId"] = info.operation_id - return cls( - logger=logger, default_extra=extra, execution_state=info.execution_state - ) + return cls(logger=logger, default_extra=extra, context=context) def with_log_info(self, info: LogInfo) -> Logger: """Clone the existing logger with new LogInfo.""" return Logger.from_log_info( logger=self._logger, info=info, + context=self._context, ) def get_logger(self) -> LoggerInterface: @@ -128,4 +125,4 @@ def _log( log_func(msg, *args, extra=merged_extra) def _should_log(self) -> bool: - return not self._execution_state.is_replaying() + return not self._context.is_replaying diff --git a/src/aws_durable_execution_sdk_python/operation/step.py b/src/aws_durable_execution_sdk_python/operation/step.py index 8a418fb3..6517c514 100644 --- a/src/aws_durable_execution_sdk_python/operation/step.py +++ b/src/aws_durable_execution_sdk_python/operation/step.py @@ -210,7 +210,6 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: step_context: StepContext = StepContext( logger=self.context_logger.with_log_info( LogInfo.from_operation_identifier( - execution_state=self.state, op_id=self.operation_identifier, attempt=attempt, ) diff --git 
a/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py b/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py index 5c4f1c4c..3f4eaeb8 100644 --- a/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py +++ b/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py @@ -181,7 +181,6 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: check_context = WaitForConditionCheckContext( logger=self.context_logger.with_log_info( LogInfo.from_operation_identifier( - execution_state=self.state, op_id=self.operation_identifier, attempt=attempt, ) diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 83175503..fd639886 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -343,39 +343,12 @@ def get_execution_operation(self) -> Operation | None: return candidate - def track_replay(self, operation_id: str) -> None: - """Check if operation exists with completed status; if not, transition to NEW status. - - This method is called before each operation (step, wait, invoke, etc.) to determine - if we've reached the replay boundary. Once we encounter an operation that doesn't - exist or isn't completed, we transition from REPLAY to NEW status, which enables - logging for all subsequent code. 
- - Args: - operation_id: The operation ID to check - """ - with self._replay_status_lock: - if self._replay_status == ReplayStatus.REPLAY: - self._visited_operations.add(operation_id) - completed_ops = { - op_id - for op_id, op in self.operations.items() - if op.operation_type != OperationType.EXECUTION - and op.status - in { - OperationStatus.SUCCEEDED, - OperationStatus.FAILED, - OperationStatus.CANCELLED, - OperationStatus.STOPPED, - OperationStatus.TIMED_OUT, - } - } - if completed_ops.issubset(self._visited_operations): - logger.debug( - "Transitioning from REPLAY to NEW status at operation %s", - operation_id, - ) - self._replay_status = ReplayStatus.NEW + def transition_replay_status(self) -> None: + """Transition to NEW status""" + if self._replay_status is ReplayStatus.REPLAY: + with self._replay_status_lock: + logger.debug("Transitioning from REPLAY to NEW status") + self._replay_status = ReplayStatus.NEW def is_replaying(self) -> bool: """Check if execution is currently in replay mode. 
diff --git a/tests/concurrency_test.py b/tests/concurrency_test.py index 3d9ae270..748f89b4 100644 --- a/tests/concurrency_test.py +++ b/tests/concurrency_test.py @@ -2556,10 +2556,12 @@ def patched_child_handler( executor_context = Mock() executor_context._parent_id = "parent_123" # noqa SLF001 - def create_step_id(index): + def create_step_id(index, is_virtual): return f"step_{index}" - executor_context._create_step_id_for_logical_step = create_step_id # noqa SLF001 + executor_context._operation_id_generator.create_step_id_for_logical_step = ( + create_step_id # noqa SLF001 + ) def create_child_context(operation_id, *, is_virtual=False): child_ctx = Mock() @@ -2620,12 +2622,12 @@ def mock_get_checkpoint_result(operation_id): mock_execution_state.get_checkpoint_result = mock_get_checkpoint_result - def mock_create_step_id_for_logical_step(step): + def mock_create_step_id_for_logical_step(step, is_virtual): return f"op_{step}" # Mock executor context mock_executor_context = Mock() - mock_executor_context._create_step_id_for_logical_step = ( # noqa + mock_executor_context._operation_id_generator.create_step_id_for_logical_step = ( # noqa mock_create_step_id_for_logical_step ) @@ -3421,7 +3423,7 @@ def execute_item(self, child_context, executable): assert branch_ctx.is_virtual is True assert branch_ctx._parent_id == map_op_id # noqa: SLF001 # The step-id prefix is the branch's own operation id (stable replay id). - assert branch_ctx._step_id_prefix != map_op_id # noqa: SLF001 + assert branch_ctx._operation_id_generator._step_id_prefix != map_op_id # noqa: SLF001 def test_nested_mode_stamps_branch_op_as_inner_op_parent_id(): @@ -3470,7 +3472,7 @@ def execute_item(self, child_context, executable): # its own operation id, not the grandparent. 
branch_ctx = executor.last_child_context assert branch_ctx.is_virtual is False - assert branch_ctx._parent_id == branch_ctx._step_id_prefix # noqa: SLF001 + assert branch_ctx._parent_id == branch_ctx._operation_id_generator._step_id_prefix # noqa: SLF001 assert branch_ctx._parent_id != map_op_id # noqa: SLF001 diff --git a/tests/context_test.py b/tests/context_test.py index 0e2cf0e2..a7561e7c 100644 --- a/tests/context_test.py +++ b/tests/context_test.py @@ -4,6 +4,7 @@ import json import random from itertools import islice +from unittest import mock from unittest.mock import ANY, MagicMock, Mock, patch import pytest @@ -311,7 +312,9 @@ def test_create_callback_with_name_and_config(mock_executor_class): operation_ids = operation_id_sequence() [next(operation_ids) for _ in range(5)] # Skip 5 IDs expected_operation_id = next(operation_ids) # Get the 6th ID - [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 + [ + context._operation_id_generator.create_step_id() for _ in range(5) + ] # Set counter to 5 # noqa: SLF001 callback = context.create_callback(config=config) @@ -344,7 +347,9 @@ def test_create_callback_with_parent_id(mock_executor_class): operation_ids = operation_id_sequence("parent123") [next(operation_ids) for _ in range(2)] # Skip 2 IDs expected_operation_id = next(operation_ids) # Get the 3rd ID - [context._create_step_id() for _ in range(2)] # Set counter to 2 # noqa: SLF001 + [ + context._operation_id_generator.create_step_id() for _ in range(2) + ] # Set counter to 2 # noqa: SLF001 callback = context.create_callback() @@ -371,7 +376,9 @@ def test_create_callback_increments_counter(mock_executor_class): ) context = create_test_context(state=mock_state) - [context._create_step_id() for _ in range(10)] # Set counter to 10 # noqa: SLF001 + [ + context._operation_id_generator.create_step_id() for _ in range(10) + ] # Set counter to 10 # noqa: SLF001 callback1 = context.create_callback() callback2 = context.create_callback() @@ 
-384,7 +391,7 @@ def test_create_callback_increments_counter(mock_executor_class): assert callback1.operation_id == expected_id1 assert callback2.operation_id == expected_id2 - assert context._step_counter.get_current() == 12 # noqa: SLF001 + assert context._operation_id_generator._operation_counter.get_current() == 12 # noqa: SLF001 # endregion create_callback @@ -444,7 +451,9 @@ def test_step_with_name_and_config(mock_executor_class): config = StepConfig() context = create_test_context(state=mock_state) - [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 + [ + context._operation_id_generator.create_step_id() for _ in range(5) + ] # Set counter to 5 # noqa: SLF001 result = context.step(mock_callable, config=config) @@ -482,7 +491,9 @@ def test_step_with_parent_id(mock_executor_class): ) # Ensure _original_name doesn't exist context = create_test_context(state=mock_state, parent_id="parent123") - [context._create_step_id() for _ in range(2)] # Set counter to 2 # noqa: SLF001 + [ + context._operation_id_generator.create_step_id() for _ in range(2) + ] # Set counter to 2 # noqa: SLF001 context.step(mock_callable) @@ -519,7 +530,9 @@ def test_step_increments_counter(mock_executor_class): ) # Ensure _original_name doesn't exist context = create_test_context(state=mock_state) - [context._create_step_id() for _ in range(10)] # Set counter to 10 # noqa: SLF001 + [ + context._operation_id_generator.create_step_id() for _ in range(10) + ] # Set counter to 10 # noqa: SLF001 context.step(mock_callable) context.step(mock_callable) @@ -530,7 +543,7 @@ def test_step_increments_counter(mock_executor_class): expected_id1 = next(seq) # 11th expected_id2 = next(seq) # 12th - assert context._step_counter.get_current() == 12 # noqa: SLF001 + assert context._operation_id_generator._operation_counter.get_current() == 12 # noqa: SLF001 assert mock_executor_class.call_args_list[0][1][ "operation_identifier" ] == OperationIdentifier(expected_id1, None, None) 
@@ -622,7 +635,9 @@ def test_invoke_with_name_and_config(mock_executor_class): config = InvokeConfig[str, str](timeout=Duration.from_seconds(30)) context = create_test_context(state=mock_state) - [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 + [ + context._operation_id_generator.create_step_id() for _ in range(5) + ] # Set counter to 5 # noqa: SLF001 result = context.invoke( "test_function", {"key": "value"}, name="named_invoke", config=config @@ -658,7 +673,9 @@ def test_invoke_with_parent_id(mock_executor_class): ) context = create_test_context(state=mock_state, parent_id="parent123") - [context._create_step_id() for _ in range(2)] # Set counter to 2 # noqa: SLF001 + [ + context._operation_id_generator.create_step_id() for _ in range(2) + ] # Set counter to 2 # noqa: SLF001 context.invoke("test_function", None) @@ -690,7 +707,9 @@ def test_invoke_increments_counter(mock_executor_class): ) context = create_test_context(state=mock_state) - [context._create_step_id() for _ in range(10)] # Set counter to 10 # noqa: SLF001 + [ + context._operation_id_generator.create_step_id() for _ in range(10) + ] # Set counter to 10 # noqa: SLF001 context.invoke("function1", "payload1") context.invoke("function2", "payload2") @@ -700,7 +719,7 @@ def test_invoke_increments_counter(mock_executor_class): expected_id1 = next(seq) expected_id2 = next(seq) - assert context._step_counter.get_current() == 12 # noqa: SLF001 + assert context._operation_id_generator._operation_counter.get_current() == 12 # noqa: SLF001 assert mock_executor_class.call_args_list[0][1][ "operation_identifier" ] == OperationIdentifier(expected_id1, None, None) @@ -830,7 +849,9 @@ def test_wait_with_name(mock_executor_class): ) context = create_test_context(state=mock_state) - [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 + [ + context._operation_id_generator.create_step_id() for _ in range(5) + ] # Set counter to 5 # noqa: SLF001 
context.wait(Duration.from_minutes(1), name="test_wait") @@ -859,7 +880,9 @@ def test_wait_with_parent_id(mock_executor_class): ) context = create_test_context(state=mock_state, parent_id="parent123") - [context._create_step_id() for _ in range(2)] # Set counter to 2 # noqa: SLF001 + [ + context._operation_id_generator.create_step_id() for _ in range(2) + ] # Set counter to 2 # noqa: SLF001 context.wait(Duration.from_seconds(45)) @@ -888,7 +911,9 @@ def test_wait_increments_counter(mock_executor_class): ) context = create_test_context(state=mock_state) - [context._create_step_id() for _ in range(10)] # Set counter to 10 # noqa: SLF001 + [ + context._operation_id_generator.create_step_id() for _ in range(10) + ] # Set counter to 10 # noqa: SLF001 context.wait(Duration.from_seconds(15)) context.wait(Duration.from_seconds(25)) @@ -898,7 +923,7 @@ def test_wait_increments_counter(mock_executor_class): expected_id1 = next(seq) expected_id2 = next(seq) - assert context._step_counter.get_current() == 12 # noqa: SLF001 + assert context._operation_id_generator._operation_counter.get_current() == 12 # noqa: SLF001 assert mock_executor_class.call_args_list[0][1][ "operation_identifier" ] == OperationIdentifier(expected_id1, None, None) @@ -993,7 +1018,9 @@ def test_run_in_child_context_with_name_and_config(mock_handler): config = ChildConfig() context = create_test_context(state=mock_state) - [context._create_step_id() for _ in range(3)] # Set counter to 3 # noqa: SLF001 + [ + context._operation_id_generator.create_step_id() for _ in range(3) + ] # Set counter to 3 # noqa: SLF001 result = context.run_in_child_context(mock_callable, config=config) @@ -1027,7 +1054,9 @@ def test_run_in_child_context_with_parent_id(mock_executor_class): ) # Ensure Mock doesn't have _original_name context = create_test_context(state=mock_state, parent_id="parent456") - [context._create_step_id() for _ in range(1)] # Set counter to 1 # noqa: SLF001 + [ + 
context._operation_id_generator.create_step_id() for _ in range(1) + ] # Set counter to 1 # noqa: SLF001 context.run_in_child_context(mock_callable) @@ -1088,7 +1117,9 @@ def test_run_in_child_context_increments_counter(mock_executor_class): ) # Ensure _original_name doesn't exist context = create_test_context(state=mock_state) - [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 + [ + context._operation_id_generator.create_step_id() for _ in range(5) + ] # Set counter to 5 # noqa: SLF001 context.run_in_child_context(mock_callable) context.run_in_child_context(mock_callable) @@ -1098,7 +1129,7 @@ def test_run_in_child_context_increments_counter(mock_executor_class): expected_id1 = next(seq) expected_id2 = next(seq) - assert context._step_counter.get_current() == 7 # noqa: SLF001 + assert context._operation_id_generator._operation_counter.get_current() == 7 # noqa: SLF001 assert mock_executor_class.call_args_list[0][1][ "operation_identifier" ] == OperationIdentifier(expected_id1, None, None) @@ -1932,6 +1963,7 @@ def test_execution_context_propagates_to_child_context(): assert child_context.execution_context.durable_execution_arn == parent_arn # Should be the same instance (not a copy) assert child_context.execution_context is parent_context.execution_context + assert child_context.is_replaying def test_from_lambda_context_creates_execution_context(): @@ -1959,6 +1991,7 @@ def test_execution_context_type(): context = create_test_context(state=mock_state) assert isinstance(context.execution_context, ExecutionContext) + assert context.is_replaying # endregion ExecutionContext tests @@ -1983,7 +2016,7 @@ def test_should_default_step_id_prefix_to_parent_id_when_not_specified(): ) assert ctx._parent_id == "parent-op-1" # noqa: SLF001 - assert ctx._step_id_prefix == "parent-op-1" # noqa: SLF001 + assert ctx._operation_id_generator._step_id_prefix == "parent-op-1" # noqa: SLF001 assert ctx.is_virtual is False @@ -2005,7 +2038,7 @@ def 
test_should_mark_context_virtual_when_parent_id_differs_from_step_prefix(): ) assert ctx._parent_id == "grandparent-op" # noqa: SLF001 - assert ctx._step_id_prefix == "branch-op" # noqa: SLF001 + assert ctx._operation_id_generator._step_id_prefix == "branch-op" # noqa: SLF001 assert ctx.is_virtual is True @@ -2032,7 +2065,10 @@ def test_should_use_step_id_prefix_when_generating_step_ids(): ) expected_prefixed = hashlib.blake2b(b"branch-op-1").hexdigest()[:64] - assert virtual._create_step_id_for_logical_step(1) == expected_prefixed # noqa: SLF001 + assert ( + virtual._operation_id_generator.create_step_id_for_logical_step(1, False) + == expected_prefixed + ) # noqa: SLF001 def test_should_use_parent_id_as_step_prefix_when_non_virtual(): @@ -2059,7 +2095,10 @@ def test_should_use_parent_id_as_step_prefix_when_non_virtual(): ) expected = hashlib.blake2b(b"parent-op-1").hexdigest()[:64] - assert non_virtual._create_step_id_for_logical_step(1) == expected # noqa: SLF001 + assert ( + non_virtual._operation_id_generator.create_step_id_for_logical_step(1, False) + == expected + ) # noqa: SLF001 assert non_virtual.is_virtual is False @@ -2074,7 +2113,7 @@ def test_should_create_non_virtual_child_when_is_virtual_false(): child = parent.create_child_context("child-op") assert child._parent_id == "child-op" # noqa: SLF001 - assert child._step_id_prefix == "child-op" # noqa: SLF001 + assert child._operation_id_generator._step_id_prefix == "child-op" # noqa: SLF001 assert child.is_virtual is False @@ -2089,7 +2128,7 @@ def test_should_create_virtual_child_that_propagates_grandparent_id(): child = parent.create_child_context("child-op", is_virtual=True) assert child._parent_id == "grandparent-op" # noqa: SLF001 - assert child._step_id_prefix == "child-op" # noqa: SLF001 + assert child._operation_id_generator._step_id_prefix == "child-op" # noqa: SLF001 assert child.is_virtual is True @@ -2109,11 +2148,14 @@ def 
test_should_create_virtual_child_with_none_parent_when_parent_is_root(): child = root_parent.create_child_context("child-op", is_virtual=True) assert child._parent_id is None # noqa: SLF001 - assert child._step_id_prefix == "child-op" # noqa: SLF001 + assert child._operation_id_generator._step_id_prefix == "child-op" # noqa: SLF001 assert child.is_virtual is True - expected = hashlib.blake2b(b"child-op-1").hexdigest()[:64] - assert child._create_step_id_for_logical_step(1) == expected # noqa: SLF001 + expected = hashlib.blake2b(b"child-op-v-1").hexdigest()[:64] + assert ( + child._operation_id_generator.create_step_id_for_logical_step(1, True) + == expected + ) # noqa: SLF001 def test_should_propagate_outer_parent_id_when_virtual_is_nested_in_virtual(): @@ -2140,8 +2182,9 @@ def test_should_propagate_outer_parent_id_when_virtual_is_nested_in_virtual(): # First virtual layer: outer parallel is FLAT, so its branch is virtual. outer_branch = outer.create_child_context("outer-branch-op", is_virtual=True) assert outer_branch._parent_id == "outer-parallel-op" # noqa: SLF001 - assert outer_branch._step_id_prefix == "outer-branch-op" # noqa: SLF001 + assert outer_branch._operation_id_generator._step_id_prefix == "outer-branch-op" # noqa: SLF001 assert outer_branch.is_virtual is True + assert outer_branch.is_replaying # Second virtual layer: an inner FLAT map inside the outer branch, # whose per-item branch is also virtual. @@ -2151,14 +2194,61 @@ def test_should_propagate_outer_parent_id_when_virtual_is_nested_in_virtual(): # inner operations would report to a logical layer that does not # appear in the execution history, breaking the hierarchy. 
assert inner_branch._parent_id == "outer-parallel-op" # noqa: SLF001 - assert inner_branch._step_id_prefix == "inner-branch-op" # noqa: SLF001 + assert inner_branch._operation_id_generator._step_id_prefix == "inner-branch-op" # noqa: SLF001 assert inner_branch.is_virtual is True + assert inner_branch.is_replaying # Step ids inside the inner branch still prefix on the inner branch's # own operation id; they must not leak the outer ancestor into the # step-id namespace. - expected = hashlib.blake2b(b"inner-branch-op-1").hexdigest()[:64] - assert inner_branch._create_step_id_for_logical_step(1) == expected # noqa: SLF001 + expected = hashlib.blake2b(b"inner-branch-op-v-1").hexdigest()[:64] + assert ( + inner_branch._operation_id_generator.create_step_id_for_logical_step(1, True) + == expected + ) # noqa: SLF001 + + +def test_context_created_with_new_status_when_check_result_returns_nonexistent(): + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + mock_state.get_checkpoint_result.return_value = ( + CheckpointedResult.create_not_found() + ) + context = create_test_context(state=mock_state) + assert not context.is_replaying + + # op id for 1 + mock_state.get_checkpoint_result.assert_called_once_with( + "1ced8f5be2db23a6513eba4d819c73806424748a7bc6fa0d792cc1c7d1775a97" + ) + + +def test_transition_replay_status_when_check_result_returns_nonexistent(): + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + mock_checkpoint_result = Mock(spec=CheckpointedResult) + mock_state.get_checkpoint_result.return_value = mock_checkpoint_result + mock_checkpoint_result.is_existent.return_value = True + context = create_test_context(state=mock_state) + assert context.is_replaying + + context._track_replay() + assert context.is_replaying + + mock_state.get_checkpoint_result.return_value = ( + 
CheckpointedResult.create_not_found() + ) + context._track_replay() + assert not context.is_replaying + + # op id for 1 + mock_state.get_checkpoint_result.assert_called_with( + "1ced8f5be2db23a6513eba4d819c73806424748a7bc6fa0d792cc1c7d1775a97" + ) # endregion Virtual-context identity tests diff --git a/tests/execution_test.py b/tests/execution_test.py index db13b5a9..343fa462 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -2690,8 +2690,9 @@ def _make_lambda_context(): def test_durable_execution_replays_when_paginated_state_has_prior_operations(): """Test paginated execution state starts in replay mode when prior operations exist.""" mock_client = Mock(spec=DurableServiceClient) + # step_operation with operation_id = hashed(1) step_operation = Operation( - operation_id="step1", + operation_id="1ced8f5be2db23a6513eba4d819c73806424748a7bc6fa0d792cc1c7d1775a97", operation_type=OperationType.STEP, status=OperationStatus.SUCCEEDED, ) @@ -2704,7 +2705,7 @@ def test_durable_execution_replays_when_paginated_state_has_prior_operations(): @durable_execution def test_handler(event: Any, context: DurableContext) -> dict: - return {"is_replaying": context.state.is_replaying()} + return {"is_replaying": context.is_replaying} result = test_handler(invocation_input, _make_lambda_context()) diff --git a/tests/logger_test.py b/tests/logger_test.py index b6017fa6..d758431e 100644 --- a/tests/logger_test.py +++ b/tests/logger_test.py @@ -4,6 +4,8 @@ from collections.abc import Mapping from unittest.mock import Mock +from aws_durable_execution_sdk_python import DurableContext +from aws_durable_execution_sdk_python.context import ExecutionContext from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import ( Operation, @@ -84,6 +86,11 @@ def exception( operations={}, service_client=Mock(), ) +EXECUTION_CONTEXT = ExecutionContext("arn:aws:test") +DURABLE_CONTEXT = DurableContext( + 
state=EXECUTION_STATE, + execution_context=EXECUTION_CONTEXT, +) def test_powertools_logger_compatibility(): @@ -102,8 +109,8 @@ def accepts_logger_interface(logger: LoggerInterface) -> None: accepts_logger_interface(powertools_logger) # Test that our Logger can wrap the PowertoolsLoggerStub - log_info = LogInfo(EXECUTION_STATE) - wrapped_logger = Logger.from_log_info(powertools_logger, log_info) + log_info = LogInfo() + wrapped_logger = Logger.from_log_info(powertools_logger, log_info, DURABLE_CONTEXT) # Test all methods work wrapped_logger.debug("debug message") @@ -115,8 +122,7 @@ def accepts_logger_interface(logger: LoggerInterface) -> None: def test_log_info_creation(): """Test LogInfo creation with all parameters.""" - log_info = LogInfo(EXECUTION_STATE, "parent123", "operation123", "test_name", 5) - assert log_info.execution_state.durable_execution_arn == "arn:aws:test" + log_info = LogInfo("parent123", "operation123", "test_name", 5) assert log_info.parent_id == "parent123" assert log_info.operation_id == "operation123" assert log_info.name == "test_name" @@ -125,8 +131,7 @@ def test_log_info_creation(): def test_log_info_creation_minimal(): """Test LogInfo creation with minimal parameters.""" - log_info = LogInfo(EXECUTION_STATE) - assert log_info.execution_state.durable_execution_arn == "arn:aws:test" + log_info = LogInfo() assert log_info.parent_id is None assert log_info.operation_id is None assert log_info.name is None @@ -136,8 +141,7 @@ def test_log_info_creation_minimal(): def test_log_info_from_operation_identifier(): """Test LogInfo.from_operation_identifier.""" op_id = OperationIdentifier("op123", "parent456", "op_name") - log_info = LogInfo.from_operation_identifier(EXECUTION_STATE, op_id, 3) - assert log_info.execution_state.durable_execution_arn == "arn:aws:test" + log_info = LogInfo.from_operation_identifier(op_id, 3) assert log_info.parent_id == "parent456" assert log_info.operation_id == "op123" assert log_info.name == "op_name" @@ -147,8 
+151,7 @@ def test_log_info_from_operation_identifier(): def test_log_info_from_operation_identifier_no_attempt(): """Test LogInfo.from_operation_identifier without attempt.""" op_id = OperationIdentifier("op123", "parent456", "op_name") - log_info = LogInfo.from_operation_identifier(EXECUTION_STATE, op_id) - assert log_info.execution_state.durable_execution_arn == "arn:aws:test" + log_info = LogInfo.from_operation_identifier(op_id) assert log_info.parent_id == "parent456" assert log_info.operation_id == "op123" assert log_info.name == "op_name" @@ -157,9 +160,8 @@ def test_log_info_from_operation_identifier_no_attempt(): def test_log_info_with_parent_id(): """Test LogInfo.with_parent_id.""" - original = LogInfo(EXECUTION_STATE, "old_parent", "op123", "test_name", 2) + original = LogInfo("old_parent", "op123", "test_name", 2) new_log_info = original.with_parent_id("new_parent") - assert new_log_info.execution_state.durable_execution_arn == "arn:aws:test" assert new_log_info.parent_id == "new_parent" assert new_log_info.operation_id == "op123" assert new_log_info.name == "test_name" @@ -169,8 +171,8 @@ def test_log_info_with_parent_id(): def test_logger_from_log_info_full(): """Test Logger.from_log_info with all LogInfo fields.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE, "parent123", "op123", "test_name", 5) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo("parent123", "op123", "test_name", 5) + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) expected_extra = { "executionArn": "arn:aws:test", @@ -188,20 +190,20 @@ def test_logger_from_log_info_partial_fields(): mock_logger = Mock() # Test with parent_id but no name or attempt - log_info = LogInfo(EXECUTION_STATE, "parent123") - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo("parent123") + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) expected_extra = {"executionArn": "arn:aws:test", "parentId": 
"parent123"} assert logger._default_extra == expected_extra # noqa: SLF001 # Test with name but no parent_id or attempt - log_info = LogInfo(EXECUTION_STATE, None, None, "test_name") - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo(None, None, "test_name") + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) expected_extra = {"executionArn": "arn:aws:test", "operationName": "test_name"} assert logger._default_extra == expected_extra # noqa: SLF001 # Test with attempt but no parent_id or name - log_info = LogInfo(EXECUTION_STATE, None, None, None, 5) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo(None, None, None, 5) + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) expected_extra = {"executionArn": "arn:aws:test", "attempt": 5} assert logger._default_extra == expected_extra # noqa: SLF001 @@ -209,8 +211,8 @@ def test_logger_from_log_info_partial_fields(): def test_logger_from_log_info_minimal(): """Test Logger.from_log_info with minimal LogInfo.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo() + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) expected_extra = {"executionArn": "arn:aws:test"} assert logger._default_extra == expected_extra # noqa: SLF001 @@ -219,8 +221,7 @@ def test_logger_from_log_info_minimal(): def test_logger_with_log_info(): """Test Logger.with_log_info.""" mock_logger = Mock() - original_info = LogInfo(EXECUTION_STATE, "parent1") - logger = Logger.from_log_info(mock_logger, original_info) + original_info = LogInfo("parent1") execution_state_new = ExecutionState( durable_execution_arn="arn:aws:new", @@ -228,7 +229,9 @@ def test_logger_with_log_info(): operations={}, service_client=Mock(), ) - new_info = LogInfo(execution_state_new, "parent2", "op123", "new_name") + durable_context = DurableContext(execution_state_new, EXECUTION_CONTEXT) + 
logger = Logger.from_log_info(mock_logger, original_info, durable_context) + new_info = LogInfo("parent2", "op123", "new_name") new_logger = logger.with_log_info(new_info) expected_extra = { @@ -244,16 +247,16 @@ def test_logger_with_log_info(): def test_logger_get_logger(): """Test Logger.get_logger.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo() + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) assert logger.get_logger() is mock_logger def test_logger_debug(): """Test Logger.debug method.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE, "parent123") - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo("parent123") + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) logger.debug("test %s message", "arg1", extra={"custom": "value"}) @@ -270,8 +273,8 @@ def test_logger_debug(): def test_logger_info(): """Test Logger.info method.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo() + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) logger.info("info message") @@ -282,8 +285,8 @@ def test_logger_info(): def test_logger_warning(): """Test Logger.warning method.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo() + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) logger.warning("warning %s %s message", "arg1", "arg2") @@ -296,8 +299,8 @@ def test_logger_warning(): def test_logger_error(): """Test Logger.error method.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo() + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) logger.error("error message", extra={"error_code": 500}) @@ 
-308,8 +311,8 @@ def test_logger_error(): def test_logger_exception(): """Test Logger.exception method.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo() + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) logger.exception("exception message") @@ -322,8 +325,8 @@ def test_logger_exception(): def test_logger_methods_with_none_extra(): """Test logger methods handle None extra parameter.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE) - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo() + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) logger.debug("debug", extra=None) logger.info("info", extra=None) @@ -342,8 +345,8 @@ def test_logger_methods_with_none_extra(): def test_logger_extra_override(): """Test that custom extra overrides default extra.""" mock_logger = Mock() - log_info = LogInfo(EXECUTION_STATE, "parent123") - logger = Logger.from_log_info(mock_logger, log_info) + log_info = LogInfo("parent123") + logger = Logger.from_log_info(mock_logger, log_info, DURABLE_CONTEXT) logger.info("test", extra={"executionArn": "overridden", "newField": "value"}) @@ -357,8 +360,8 @@ def test_logger_extra_override(): def test_logger_without_mocked_logger(): """Test Logger methods without mocking the underlying logger.""" - log_info = LogInfo(EXECUTION_STATE, "parent123", "test_name", 5) - logger = Logger.from_log_info(logging.getLogger(), log_info) + log_info = LogInfo("parent123", "op1", "test_name", 5) + logger = Logger.from_log_info(logging.getLogger(), log_info, DURABLE_CONTEXT) logger.info("test", extra={"execution_arn": "overridden", "new_field": "value"}) logger.warning("test", extra={"execution_arn": "overridden", "new_field": "value"}) @@ -378,12 +381,13 @@ def test_logger_replay_no_logging(): service_client=Mock(), replay_status=ReplayStatus.REPLAY, ) - log_info = LogInfo(replay_execution_state, 
"parent123", "test_name", 5) + durable_context = Mock(DurableContext) + durable_context.is_replaying = True + durable_context.state = replay_execution_state + log_info = LogInfo("parent123", "op1", "test_name", 5) mock_logger = Mock() - logger = Logger.from_log_info(mock_logger, log_info) + logger = Logger.from_log_info(mock_logger, log_info, durable_context) logger.info("logging info") - replay_execution_state.track_replay(operation_id="op1") - mock_logger.info.assert_not_called() @@ -405,14 +409,16 @@ def test_logger_replay_then_new_logging(): service_client=Mock(), replay_status=ReplayStatus.REPLAY, ) - log_info = LogInfo(execution_state, "parent123", "test_name", 5) + durable_context = Mock(DurableContext) + durable_context.is_replaying = True + durable_context.state = execution_state + log_info = LogInfo("parent123", "op1", "test_name", 5) mock_logger = Mock() - logger = Logger.from_log_info(mock_logger, log_info) - execution_state.track_replay(operation_id="op1") + logger = Logger.from_log_info(mock_logger, log_info, durable_context) logger.info("logging info") mock_logger.info.assert_not_called() - execution_state.track_replay(operation_id="op2") + durable_context.is_replaying = False logger.info("logging info") mock_logger.info.assert_called_once() diff --git a/tests/operation/map_test.py b/tests/operation/map_test.py index c7a653f6..cb5ff6a3 100644 --- a/tests/operation/map_test.py +++ b/tests/operation/map_test.py @@ -2,6 +2,7 @@ import importlib import json +from collections import defaultdict from unittest.mock import Mock, patch import pytest @@ -20,13 +21,17 @@ MapConfig, NestingType, ) -from aws_durable_execution_sdk_python.context import DurableContext, ExecutionContext +from aws_durable_execution_sdk_python.context import ( + DurableContext, + ExecutionContext, + OperationIdGenerator, +) from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import OperationSubType from 
aws_durable_execution_sdk_python.operation import child # PLC0415 from aws_durable_execution_sdk_python.operation.map import MapExecutor, map_handler from aws_durable_execution_sdk_python.serdes import serialize -from aws_durable_execution_sdk_python.state import ExecutionState +from aws_durable_execution_sdk_python.state import ExecutionState, CheckpointedResult from tests.serdes_test import CustomStrSerDes @@ -306,7 +311,9 @@ def callable_func(ctx, item, idx, items): ) executor_context = Mock() - executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context._operation_id_generator.create_step_id_for_logical_step = ( + lambda *args: "1" + ) # noqa SLF001 executor_context.create_child_context = lambda *args, **kwargs: Mock() with patch.object( @@ -357,7 +364,9 @@ def callable_func(ctx, item, idx, items): mock_from_items.return_value = mock_executor executor_context = Mock() - executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context._operation_id_generator.create_step_id_for_logical_step = ( + lambda *args: "1" + ) # noqa SLF001 executor_context.create_child_context = lambda *args, **kwargs: Mock() class MockExecutionState: @@ -403,7 +412,9 @@ def callable_func(ctx, item, idx, items): return f"RESULT_{item.upper()}" executor_context = Mock() - executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context._operation_id_generator.create_step_id_for_logical_step = ( + lambda *args, **kwargs: "1" + ) # noqa SLF001 executor_context.create_child_context = lambda *args, **kwargs: Mock() class MockExecutionState: @@ -442,7 +453,9 @@ def mock_summary_generator(result): config = MapConfig(summary_generator=mock_summary_generator) executor_context = Mock() - executor_context._create_step_id_for_logical_step = Mock(side_effect=["1", "2"]) # noqa SLF001 + executor_context._operation_id_generator.create_step_id_for_logical_step = Mock( + 
side_effect=["1", "2"] + ) # noqa SLF001 executor_context.create_child_context = Mock(return_value=Mock()) class MockExecutionState: @@ -468,8 +481,11 @@ def get_checkpoint_result(self, operation_id): assert executor_context.create_child_context.call_count == 2 # Verify that _create_step_id_for_logical_step was called twice with unique values - assert executor_context._create_step_id_for_logical_step.call_count == 2 # noqa SLF001 - calls = executor_context._create_step_id_for_logical_step.call_args_list # noqa SLF001 + assert ( + executor_context._operation_id_generator.create_step_id_for_logical_step.call_count + == 2 + ) # noqa SLF001 + calls = executor_context._operation_id_generator.create_step_id_for_logical_step.call_args_list # noqa SLF001 # Verify unique values were passed assert calls[0] != calls[1] @@ -500,7 +516,9 @@ def callable_func(ctx, item, idx, items): return f"result_{item}" executor_context = Mock() - executor_context._create_step_id_for_logical_step = Mock(return_value="1") # noqa SLF001 + executor_context._operation_id_generator.create_step_id_for_logical_step = Mock( + return_value="1" + ) # noqa SLF001 executor_context.create_child_context = Mock(return_value=Mock()) # SLF001 class MockExecutionState: @@ -526,7 +544,10 @@ def get_checkpoint_result(self, operation_id): assert executor_context.create_child_context.call_count == 1 # Verify that _create_step_id_for_logical_step was called once - assert executor_context._create_step_id_for_logical_step.call_count == 1 # noqa SLF001 + assert ( + executor_context._operation_id_generator.create_step_id_for_logical_step.call_count + == 1 + ) # noqa SLF001 def test_map_executor_init_with_summary_generator(): @@ -574,7 +595,7 @@ def get_checkpoint_result(self, operation_id): operation_identifier = OperationIdentifier("test_op", "parent", "test_map") executor_context = Mock() - executor_context._create_step_id_for_logical_step = Mock( # noqa: SLF001 + 
executor_context._operation_id_generator.create_step_id_for_logical_step = Mock( # noqa: SLF001 side_effect=["1", "2", "3"] ) executor_context.create_child_context = Mock(return_value=Mock()) @@ -846,20 +867,20 @@ def get_checkpoint(op_id): mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() - context_map = {} + context_map = defaultdict(set) - def create_id(self, i): + def create_id(self, i, is_virtual): ctx_id = id(self) - if ctx_id not in context_map: - context_map[ctx_id] = [] - context_map[ctx_id].append(i) + context_map[ctx_id].add(i) return ( "parent" if len(context_map) == 1 and len(context_map[ctx_id]) == 1 else f"child-{i}" ) - with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): + with patch.object( + OperationIdGenerator, "create_step_id_for_logical_step", create_id + ): context = create_test_context(state=mock_state) context.map( ["a", "b"], @@ -908,20 +929,20 @@ def get_checkpoint(op_id): mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() - context_map = {} + context_map = defaultdict(set) - def create_id(self, i): + def create_id(self, i, is_virtual): ctx_id = id(self) - if ctx_id not in context_map: - context_map[ctx_id] = [] - context_map[ctx_id].append(i) + context_map[ctx_id].add(i) return ( "parent" if len(context_map) == 1 and len(context_map[ctx_id]) == 1 else f"child-{i}" ) - with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): + with patch.object( + OperationIdGenerator, "create_step_id_for_logical_step", create_id + ): context = create_test_context(state=mock_state) context.map( ["a", "b"], @@ -1008,7 +1029,7 @@ def get_checkpoint(op_id): context_map = {} - def create_id(self, i): + def create_id(self, i, is_virtual): ctx_id = id(self) if ctx_id not in context_map: context_map[ctx_id] = [] @@ -1020,7 +1041,7 @@ def create_id(self, i): ) with patch.object( - DurableContext, 
"_create_step_id_for_logical_step", create_id + OperationIdGenerator, "create_step_id_for_logical_step", create_id ): context = create_test_context(state=mock_state) result = context.map(["a", "b"], lambda ctx, item, idx, items: item) @@ -1060,7 +1081,7 @@ def get_checkpoint(op_id): context_map = {} - def create_id(self, i): + def create_id(self, i, is_virtual): ctx_id = id(self) if ctx_id not in context_map: context_map[ctx_id] = [] @@ -1072,7 +1093,7 @@ def create_id(self, i): ) with patch.object( - DurableContext, "_create_step_id_for_logical_step", create_id + OperationIdGenerator, "create_step_id_for_logical_step", create_id ): context = create_test_context(state=mock_state) result = context.map(["a", "b"], lambda ctx, item, idx, items: item) @@ -1116,7 +1137,7 @@ def get_checkpoint(op_id): context_map = {} - def create_id(self, i): + def create_id(self, i, is_virtual): ctx_id = id(self) if ctx_id not in context_map: context_map[ctx_id] = [] @@ -1128,7 +1149,7 @@ def create_id(self, i): ) with patch.object( - DurableContext, "_create_step_id_for_logical_step", create_id + OperationIdGenerator, "create_step_id_for_logical_step", create_id ): context = create_test_context(state=mock_state) result = context.map( diff --git a/tests/operation/parallel_test.py b/tests/operation/parallel_test.py index cf7c7367..79afff39 100644 --- a/tests/operation/parallel_test.py +++ b/tests/operation/parallel_test.py @@ -2,6 +2,7 @@ import importlib import json +from collections import defaultdict from collections.abc import Mapping from typing import Any from unittest.mock import Mock, patch @@ -23,7 +24,11 @@ NestingType, ParallelConfig, ) -from aws_durable_execution_sdk_python.context import DurableContext, ExecutionContext +from aws_durable_execution_sdk_python.context import ( + DurableContext, + ExecutionContext, + OperationIdGenerator, +) from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import 
OperationSubType from aws_durable_execution_sdk_python.operation import child @@ -272,7 +277,9 @@ def get_checkpoint_result(self, operation_id): operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") executor_context = Mock() - executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context._operation_id_generator.create_step_id_for_logical_step = ( + lambda *args: "1" + ) # noqa SLF001 executor_context.create_child_context = lambda *args, **kwargs: Mock() with patch.object(ParallelExecutor, "from_callables") as mock_from_callables: @@ -310,7 +317,9 @@ def get_checkpoint_result(self, operation_id): operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") executor_context = Mock() - executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context._operation_id_generator.create_step_id_for_logical_step = ( + lambda *args: "1" + ) # noqa SLF001 executor_context.create_child_context = lambda *args, **kwargs: Mock() with patch.object(ParallelExecutor, "from_callables") as mock_from_callables: @@ -409,7 +418,9 @@ def get_checkpoint_result(self, operation_id): operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") executor_context = Mock() - executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context._operation_id_generator.create_step_id_for_logical_step = ( + lambda *args, **kwargs: "1" + ) # noqa SLF001 executor_context.create_child_context = lambda *args, **kwargs: Mock() result = parallel_handler( @@ -445,7 +456,9 @@ def get_checkpoint_result(self, operation_id): operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") executor_context = Mock() - executor_context._create_step_id_for_logical_step = Mock(return_value="1") # noqa SLF001 + executor_context._operation_id_generator.create_step_id_for_logical_step = Mock( + return_value="1" + ) # noqa SLF001 
executor_context.create_child_context = Mock(return_value=Mock()) # Call parallel_handler @@ -457,7 +470,10 @@ def get_checkpoint_result(self, operation_id): assert executor_context.create_child_context.call_count == 1 # Verify that _create_step_id_for_logical_step was called once with unique value - assert executor_context._create_step_id_for_logical_step.call_count == 1 # noqa SLF001 + assert ( + executor_context._operation_id_generator.create_step_id_for_logical_step.call_count + == 1 + ) # noqa SLF001 def test_parallel_executor_from_callables_with_summary_generator(): @@ -499,7 +515,9 @@ def get_checkpoint_result(self, operation_id): operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") executor_context = Mock() - executor_context._create_step_id_for_logical_step = Mock(side_effect=["1", "2"]) # noqa SLF001 + executor_context._operation_id_generator.create_step_id_for_logical_step = Mock( + side_effect=["1", "2"] + ) # noqa SLF001 executor_context.create_child_context = Mock(return_value=Mock()) # Call parallel_handler with None config (should use default) @@ -511,8 +529,11 @@ def get_checkpoint_result(self, operation_id): assert executor_context.create_child_context.call_count == 2 # Verify that _create_step_id_for_logical_step was called twice with unique values - assert executor_context._create_step_id_for_logical_step.call_count == 2 # noqa SLF001 - calls = executor_context._create_step_id_for_logical_step.call_args_list # noqa SLF001 + assert ( + executor_context._operation_id_generator.create_step_id_for_logical_step.call_count + == 2 + ) # noqa SLF001 + calls = executor_context._operation_id_generator.create_step_id_for_logical_step.call_args_list # noqa SLF001 # Verify unique values were passed assert calls[0] != calls[1] @@ -543,7 +564,7 @@ def get_checkpoint_result(self, operation_id): operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") executor_context = Mock() - 
executor_context._create_step_id_for_logical_step = Mock( # noqa: SLF001 + executor_context._operation_id_generator.create_step_id_for_logical_step = Mock( # noqa: SLF001 side_effect=["1", "2", "3"] ) executor_context.create_child_context = Mock(return_value=Mock()) @@ -811,20 +832,20 @@ def get_checkpoint(op_id): mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() - context_map = {} + context_map = defaultdict(set) - def create_id(self, i): + def create_id(self, i, is_virtual): ctx_id = id(self) - if ctx_id not in context_map: - context_map[ctx_id] = [] - context_map[ctx_id].append(i) + context_map[ctx_id].add(i) return ( "parent" if len(context_map) == 1 and len(context_map[ctx_id]) == 1 else f"child-{i}" ) - with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): + with patch.object( + OperationIdGenerator, "create_step_id_for_logical_step", create_id + ): context = create_test_context(state=mock_state) context.parallel( [lambda ctx: "a", lambda ctx: "b"], @@ -872,20 +893,20 @@ def get_checkpoint(op_id): mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() - context_map = {} + context_map = defaultdict(set) - def create_id(self, i): + def create_id(self, i, is_virtual): ctx_id = id(self) - if ctx_id not in context_map: - context_map[ctx_id] = [] - context_map[ctx_id].append(i) + context_map[ctx_id].add(i) return ( "parent" if len(context_map) == 1 and len(context_map[ctx_id]) == 1 else f"child-{i}" ) - with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): + with patch.object( + OperationIdGenerator, "create_step_id_for_logical_step", create_id + ): context = create_test_context(state=mock_state) context.parallel( [lambda ctx: "a", lambda ctx: "b"], @@ -985,7 +1006,7 @@ def get_checkpoint(op_id): context_map = {} - def create_id(self, i): + def create_id(self, i, is_virtual): ctx_id = id(self) if ctx_id 
not in context_map: context_map[ctx_id] = [] @@ -997,7 +1018,7 @@ def create_id(self, i): ) with patch.object( - DurableContext, "_create_step_id_for_logical_step", create_id + OperationIdGenerator, "create_step_id_for_logical_step", create_id ): context = create_test_context(state=mock_state) result = context.parallel([lambda ctx: "a", lambda ctx: "b"]) @@ -1036,7 +1057,7 @@ def get_checkpoint(op_id): context_map = {} - def create_id(self, i): + def create_id(self, i, is_virtual): ctx_id = id(self) if ctx_id not in context_map: context_map[ctx_id] = [] @@ -1048,7 +1069,7 @@ def create_id(self, i): ) with patch.object( - DurableContext, "_create_step_id_for_logical_step", create_id + OperationIdGenerator, "create_step_id_for_logical_step", create_id ): context = create_test_context(state=mock_state) result = context.parallel([lambda ctx: "a", lambda ctx: "b"]) @@ -1092,7 +1113,7 @@ def get_checkpoint(op_id): context_map = {} - def create_id(self, i): + def create_id(self, i, is_virtual): ctx_id = id(self) if ctx_id not in context_map: context_map[ctx_id] = [] @@ -1104,7 +1125,7 @@ def create_id(self, i): ) with patch.object( - DurableContext, "_create_step_id_for_logical_step", create_id + OperationIdGenerator, "create_step_id_for_logical_step", create_id ): context = create_test_context(state=mock_state) result = context.parallel( diff --git a/tests/state_test.py b/tests/state_test.py index 0152ca6c..343ed42d 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -3385,60 +3385,15 @@ def test_create_checkpoint_sync_always_synchronous(): def test_state_replay_mode(): - operation1 = Operation( - operation_id="op1", - operation_type=OperationType.STEP, - status=OperationStatus.SUCCEEDED, - ) - operation2 = Operation( - operation_id="op2", - operation_type=OperationType.STEP, - status=OperationStatus.SUCCEEDED, - ) - execution_state = ExecutionState( - durable_execution_arn="arn:aws:test", - initial_checkpoint_token="test_token", # noqa: S106 - 
operations={"op1": operation1, "op2": operation2}, - service_client=Mock(), - replay_status=ReplayStatus.REPLAY, - ) - assert execution_state.is_replaying() is True - execution_state.track_replay(operation_id="op1") - assert execution_state.is_replaying() is True - execution_state.track_replay(operation_id="op2") - assert execution_state.is_replaying() is False - - -def test_state_replay_mode_with_timed_out(): - """Test that TIMED_OUT operations are treated as terminal states for replay tracking. - - This test verifies that when an operation has TIMED_OUT status, it is correctly - recognized as a completed/terminal state, allowing the replay status to transition - from REPLAY to NEW once all completed operations have been visited. - - Regression test for: https://github.com/aws/aws-durable-execution-sdk-python/issues/262 - """ - operation1 = Operation( - operation_id="op1", - operation_type=OperationType.STEP, - status=OperationStatus.TIMED_OUT, - ) - operation2 = Operation( - operation_id="op2", - operation_type=OperationType.STEP, - status=OperationStatus.SUCCEEDED, - ) execution_state = ExecutionState( durable_execution_arn="arn:aws:test", initial_checkpoint_token="test_token", # noqa: S106 - operations={"op1": operation1, "op2": operation2}, + operations={}, service_client=Mock(), replay_status=ReplayStatus.REPLAY, ) assert execution_state.is_replaying() is True - execution_state.track_replay(operation_id="op1") - assert execution_state.is_replaying() is True - execution_state.track_replay(operation_id="op2") + execution_state.transition_replay_status() assert execution_state.is_replaying() is False diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 77611a34..ce90380b 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -17,4 +17,4 @@ def operation_id_sequence(parent_id: str | None = None): ) while True: - yield context._create_step_id() # noqa: SLF001 + yield context._operation_id_generator.create_step_id() # noqa: SLF001