diff --git a/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContext.java b/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContext.java index 274b31a79..433278f04 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContext.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContext.java @@ -44,8 +44,12 @@ public interface BaseContext extends AutoCloseable { /** Gets the context name for this context. Null for root context. */ String getContextName(); - /** Returns whether this context is currently in replay mode. */ - boolean isReplaying(); + /** + * Returns whether this context is currently replaying based on per-context tracking. Checks whether the next + * operation in this specific context already exists in checkpoint storage, providing accurate replay status even + * when multiple contexts run concurrently. + */ + boolean isReplayingContext(); /** Closes this context. */ void close(); diff --git a/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContextImpl.java b/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContextImpl.java index 9920366f4..ef357871f 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContextImpl.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContextImpl.java @@ -15,8 +15,6 @@ public abstract class BaseContextImpl implements AutoCloseable, BaseContext { private final String contextName; private final ThreadType threadType; - private boolean isReplaying; - /** * Creates a new BaseContext instance. * @@ -39,7 +37,6 @@ protected BaseContextImpl( this.lambdaContext = lambdaContext; this.contextId = contextId; this.contextName = contextName; - this.isReplaying = executionManager.hasOperationsForContext(contextId); this.threadType = threadType; } @@ -97,16 +94,12 @@ public ExecutionManager getExecutionManager() { return executionManager; } - /** Returns whether this context is currently in replay mode. */ - @Override - public boolean isReplaying() { - return isReplaying; - } - /** - * Transitions this context from replay to execution mode. Called when the first un-cached operation is encountered. + * Returns whether this context is currently in replay mode. The default implementation returns false. Subclasses + * that track per-context replay status (like DurableContextImpl) override this. */ - public void setExecutionMode() { - this.isReplaying = false; + @Override + public boolean isReplayingContext() { + return false; } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java index c08b6317b..b96884a74 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java @@ -63,6 +63,7 @@ public class DurableContextImpl extends BaseContextImpl implements DurableContex private final DurableContextImpl parentContext; private final boolean isVirtual; private volatile DurableLogger logger; + private boolean replayMode; /** Shared initialization — sets all fields. */ private DurableContextImpl( @@ -77,6 +78,8 @@ private DurableContextImpl( operationIdGenerator = new OperationIdGenerator(contextId); this.parentContext = parentContext; this.isVirtual = isVirtual; + // Initialize replay mode by checking if the next operation (first in this context) exists in storage + this.replayMode = executionManager.hasOperation(operationIdGenerator.peekNextOperationId()); } /** @@ -142,6 +145,7 @@ public DurableFuture stepAsync( config = config.toBuilder().serDes(getDurableConfig().getSerDes()).build(); } var operationId = nextOperationId(); + updateReplayStatus(); // Create and start step operation with TypeToken var operation = new StepOperation<>( @@ -158,6 +162,7 @@ public DurableFuture waitAsync(String name, Duration duration) { ParameterValidator.validateOperationName(name); var operationId = nextOperationId(); + updateReplayStatus(); // Create and start wait operation var operation = @@ -183,6 +188,7 @@ public DurableFuture invokeAsync( .build(); } var operationId = nextOperationId(); + updateReplayStatus(); // Create and start invoke operation var operation = new InvokeOperation<>( @@ -204,6 +210,7 @@ public DurableCallbackFuture createCallback(String name, TypeToken res config = config.toBuilder().serDes(getDurableConfig().getSerDes()).build(); } var operationId = nextOperationId(); + updateReplayStatus(); var operation = new CallbackOperation<>( OperationIdentifier.of(operationId, name, OperationSubType.CALLBACK), resultType, config, this); @@ -245,6 +252,7 @@ private DurableFuture runInChildContextAsync( } var operationId = nextOperationId(); + updateReplayStatus(); var operation = new ChildContextOperation<>( OperationIdentifier.of(operationId, name, subType), func, resultType, config, this); @@ -270,6 +278,7 @@ public DurableFuture> mapAsync( // Convert to List for deterministic index-based access var itemList = List.copyOf(items); var operationId = nextOperationId(); + updateReplayStatus(); var operation = new MapOperation<>( OperationIdentifier.of(operationId, name, OperationSubType.MAP), @@ -286,6 +295,7 @@ public DurableFuture> mapAsync( public ParallelDurableFuture parallel(String name, ParallelConfig config) { Objects.requireNonNull(config, "config cannot be null"); var operationId = nextOperationId(); + updateReplayStatus(); var parallelOp = new ParallelOperation( OperationIdentifier.of(operationId, name, OperationSubType.PARALLEL), @@ -357,6 +367,7 @@ public DurableFuture waitForConditionAsync( config = config.toBuilder().serDes(getDurableConfig().getSerDes()).build(); } var operationId = nextOperationId(); + updateReplayStatus(); var operation = new WaitForConditionOperation<>( OperationIdentifier.of(operationId, name, OperationSubType.WAIT_FOR_CONDITION), @@ -454,6 +465,28 @@ public void close() { } } + /** + * Returns whether this context is currently in replay mode based on per-context tracking. A context is replaying + * when its next operation already exists in checkpoint storage. + */ + @Override + public boolean isReplayingContext() { + return replayMode; + } + + /** + * Checks if the next operation exists in checkpoint storage and transitions out of replay mode if it does not. This + * is called before each operation to maintain accurate per-context replay status. + */ + public void updateReplayStatus() { + if (!replayMode) { + return; + } + if (!getExecutionManager().hasOperation(operationIdGenerator.peekNextOperationId())) { + replayMode = false; + } + } + /** * Get the next operationId. Returns a globally unique operation ID by hashing a sequential operation counter. For * root contexts, the counter value is hashed directly (e.g. "1", "2", "3"). For child contexts, the values are diff --git a/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java b/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java index 7e6d514d9..469f968ed 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java @@ -178,6 +178,16 @@ public boolean hasOperationsForContext(String parentId) { return operationStorage.values().stream().anyMatch(op -> Objects.equals(op.parentId(), parentId)); } + /** + * Checks whether an operation with the given ID exists in checkpoint storage. + * + * @param operationId the operation ID to check + * @return true if the operation exists + */ + public boolean hasOperation(String operationId) { + return operationStorage.containsKey(operationId); + } + // ===== Thread Coordination ===== /** Sets the current thread's ThreadContext (threadId and threadType). Called when a user thread is started. */ public void setCurrentThreadContext(ThreadContext threadContext) { diff --git a/sdk/src/main/java/software/amazon/lambda/durable/execution/OperationIdGenerator.java b/sdk/src/main/java/software/amazon/lambda/durable/execution/OperationIdGenerator.java index 08ea883db..274d0c8e2 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/execution/OperationIdGenerator.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/execution/OperationIdGenerator.java @@ -45,4 +45,13 @@ public String nextOperationId() { var counter = String.valueOf(operationCounter.incrementAndGet()); return hashOperationId(operationIdPrefix + counter); } + + /** + * Returns the operation ID that would be generated by the next call to {@link #nextOperationId()} without + * incrementing the counter. Used to check whether the next operation already exists in checkpoint storage. + */ + public String peekNextOperationId() { + var counter = String.valueOf(operationCounter.get() + 1); + return hashOperationId(operationIdPrefix + counter); + } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/logging/DurableLogger.java b/sdk/src/main/java/software/amazon/lambda/durable/logging/DurableLogger.java index 735e42867..83265c7aa 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/logging/DurableLogger.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/logging/DurableLogger.java @@ -93,8 +93,7 @@ public void error(String message, Throwable t) { } private boolean shouldSuppress() { - return context.getDurableConfig().getLoggerConfig().suppressReplayLogs() - && context.getExecutionManager().isReplaying(); + return context.getDurableConfig().getLoggerConfig().suppressReplayLogs() && context.isReplayingContext(); } private void log(Runnable logAction) { diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java index b3dd55b95..ce3b5c554 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java @@ -134,9 +134,6 @@ public void execute() { } replay(existing); } else { - if (durableContext.isReplaying()) { - this.durableContext.setExecutionMode(); - } start(); } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/logging/DurableLoggerTest.java b/sdk/src/test/java/software/amazon/lambda/durable/logging/DurableLoggerTest.java index ad13ac825..cafe646fb 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/logging/DurableLoggerTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/logging/DurableLoggerTest.java @@ -13,6 +13,7 @@ import software.amazon.lambda.durable.TestContext; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.ExecutionManager; +import software.amazon.lambda.durable.execution.OperationIdGenerator; class DurableLoggerTest { private static final String EXECUTION_NAME = "exec-123"; @@ -42,7 +43,7 @@ void setUp() { } private DurableLogger createLogger(Mode mode, Suppression suppression) { - when(mockExecutionManager.isReplaying()).thenReturn(mode == Mode.REPLAYING); + when(mockExecutionManager.hasOperation(anyString())).thenReturn(mode == Mode.REPLAYING); return new DurableLogger(mockLogger, createDurableContext(REQUEST_ID, suppression)); } @@ -104,7 +105,7 @@ void setsExecutionMdcInConstructor() { void setStepThreadPropertiesSetsMdc() { try (MockedStatic mdcMock = mockStatic(MDC.class)) { mdcMock.clearInvocations(); - when(mockExecutionManager.isReplaying()).thenReturn(false); + when(mockExecutionManager.hasOperation(anyString())).thenReturn(false); var logger = new DurableLogger( mockLogger, createDurableContext(REQUEST_ID, Suppression.ENABLED) @@ -130,13 +131,18 @@ void clearThreadPropertiesRemovesMdc() { @Test void replayModeTransitionAllowsSubsequentLogs() { - when(mockExecutionManager.isReplaying()).thenReturn(true, false); - var logger = new DurableLogger(mockLogger, createDurableContext(REQUEST_ID, Suppression.ENABLED)); + when(mockExecutionManager.hasOperation(anyString())).thenReturn(true); + var durableContext = createDurableContext(REQUEST_ID, Suppression.ENABLED); + var logger = new DurableLogger(mockLogger, durableContext); // During replay - suppressed logger.info("suppressed"); verify(mockLogger, never()).info(anyString(), any(Object[].class)); + // Simulate next operation not existing in storage — triggers transition out of replay + when(mockExecutionManager.hasOperation(anyString())).thenReturn(false); + durableContext.updateReplayStatus(); + // After transition to execution mode - logged logger.info("logged after transition"); verify(mockLogger).info(eq("logged after transition"), any(Object[].class)); @@ -163,10 +169,44 @@ void allLogLevelsDelegateCorrectly() { verify(mockLogger).error("error with exception", exception); } + @Test + void concurrentContextsHaveIndependentReplayState() { + var rootNextOp = OperationIdGenerator.hashOperationId("1"); + var childANextOp = OperationIdGenerator.hashOperationId("child-a-1"); + var childBNextOp = OperationIdGenerator.hashOperationId("child-b-1"); + + when(mockExecutionManager.hasOperation(rootNextOp)).thenReturn(true); + when(mockExecutionManager.hasOperation(childANextOp)).thenReturn(false); + when(mockExecutionManager.hasOperation(childBNextOp)).thenReturn(true); + + var rootContext = createDurableContext(REQUEST_ID, Suppression.ENABLED); + var childA = rootContext.createChildContext("child-a", "branch-a", false); + var childB = rootContext.createChildContext("child-b", "branch-b", false); + + var loggerForA = mock(Logger.class); + var loggerForB = mock(Logger.class); + var durableLoggerA = new DurableLogger(loggerForA, childA); + var durableLoggerB = new DurableLogger(loggerForB, childB); + + // Child A is in execution mode — logs should pass through + durableLoggerA.info("from branch A"); + verify(loggerForA).info(eq("from branch A"), any(Object[].class)); + + // Child B is still replaying — logs should be suppressed + durableLoggerB.info("from branch B"); + verify(loggerForB, never()).info(anyString(), any(Object[].class)); + + // After child B transitions, its logs should pass through + when(mockExecutionManager.hasOperation(childBNextOp)).thenReturn(false); + childB.updateReplayStatus(); + durableLoggerB.info("branch B after transition"); + verify(loggerForB).info(eq("branch B after transition"), any(Object[].class)); + } + @Test void handlesNullRequestId() { try (MockedStatic mdcMock = mockStatic(MDC.class)) { - when(mockExecutionManager.isReplaying()).thenReturn(false); + when(mockExecutionManager.hasOperation(anyString())).thenReturn(false); new DurableLogger(mockLogger, createDurableContext(null, Suppression.DISABLED)); mdcMock.verify(() -> MDC.put(DurableLogger.MDC_EXECUTION_ARN, EXECUTION_ARN));