From 9cea2347a804ed7ed6ff501976c92e4e5e721087 Mon Sep 17 00:00:00 2001 From: John Maguire Date: Tue, 17 Mar 2026 13:43:12 +0000 Subject: [PATCH] Introduce RetryableError Allow users to customize retry timing when needed, e.g., when receiving a `Retry-After` header. cf. https://github.com/restatedev/sdk-typescript/pull/569 --- python/restate/exceptions.py | 20 +++++++++++++++++ python/restate/server_context.py | 24 +++++++++++++++++++- python/restate/vm.py | 31 ++++++++++++++++++++++++-- src/lib.rs | 38 ++++++++++++++++++++++++++++++-- tests/servercontext.py | 34 ++++++++++++++++++++++++++++ 5 files changed, 142 insertions(+), 5 deletions(-) diff --git a/python/restate/exceptions.py b/python/restate/exceptions.py index 22b8498..3884b10 100644 --- a/python/restate/exceptions.py +++ b/python/restate/exceptions.py @@ -11,6 +11,9 @@ """This module contains the restate exceptions""" # pylint: disable=C0301 +from typing import Optional + +from datetime import timedelta class TerminalError(Exception): @@ -22,6 +25,23 @@ def __init__(self, message: str, status_code: int = 500) -> None: self.status_code = status_code +class RetryableError(Exception): + """ + This exception is thrown to indicate that Restate should retry with an explicit delay. + + Args: + message: The error message. + status_code: The HTTP status code to return for this error (default: 500). + retry_after: The delay after which Restate should retry the invocation. + """ + + def __init__(self, message: str, status_code: int = 500, retry_after: Optional[timedelta] = None) -> None: + super().__init__(message) + self.message = message + self.status_code = status_code + self.retry_after = retry_after + + class SdkInternalBaseException(BaseException): """This exception is internal, and you should not catch it. If you need to distinguish with other exceptions, use is_internal_exception.""" diff --git a/python/restate/server_context.py b/python/restate/server_context.py index 59823d8..caa46e2 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -46,7 +46,13 @@ RunOptions, P, ) -from restate.exceptions import TerminalError, SdkInternalBaseException, SdkInternalException, SuspendedException +from restate.exceptions import ( + TerminalError, + SdkInternalBaseException, + SdkInternalException, + SuspendedException, + RetryableError, +) from restate.handler import Handler, handler_from_callable, invoke_handler from restate.serde import BytesSerde, DefaultSerde, Serde from restate.server_types import ReceiveChannel, Send @@ -404,6 +410,10 @@ async def enter(self): restate_context_is_replaying.set(False) self.vm.sys_write_output_failure(failure) self.vm.sys_end() + except RetryableError as r: + stacktrace = "".join(traceback.format_exception(r)) + restate_context_is_replaying.set(False) + self.vm.notify_error(r.message, stacktrace, r.retry_after) # pylint: disable=W0718 except asyncio.CancelledError: pass @@ -674,6 +684,18 @@ async def create_run_coroutine( except TerminalError as t: failure = Failure(code=t.status_code, message=t.message) self.vm.propose_run_completion_failure(handle, failure) + except RetryableError as r: + failure = Failure(code=r.status_code, message=r.message) + end = time.time() + attempt_duration = int((end - start) * 1000) + self.vm.propose_run_completion_transient_with_delay_override( + handle, + failure, + attempt_duration_ms=attempt_duration, + delay_override=r.retry_after, + max_retry_attempts_override=max_attempts, + max_retry_duration_override=max_duration, + ) except asyncio.CancelledError as e: raise e from None except SdkInternalBaseException as e: diff --git a/python/restate/vm.py b/python/restate/vm.py index 908ed13..3250a0a 100644 --- a/python/restate/vm.py +++ b/python/restate/vm.py @@ -11,9 +11,12 @@ """ wrap the restate._internal.PyVM class """ + # pylint: disable=E1101,R0917 # pylint: disable=too-many-arguments # pylint: disable=too-few-public-methods +from typing import Optional +from datetime import timedelta from dataclasses import dataclass import typing @@ -173,9 +176,11 @@ def notify_input_closed(self): """Notify the virtual machine that the input has been closed.""" self.vm.notify_input_closed() - def notify_error(self, error: str, stacktrace: str): + def notify_error(self, error: str, stacktrace: str, delay_override: Optional[timedelta] = None): """Notify the virtual machine of an error.""" - self.vm.notify_error(error, stacktrace) + self.vm.notify_error( + error, stacktrace, int(delay_override.total_seconds() * 1000) if delay_override is not None else None + ) def take_output(self) -> typing.Optional[bytes]: """Take the output from the virtual machine.""" @@ -444,6 +449,28 @@ def propose_run_completion_transient( ) self.vm.propose_run_completion_failure_transient(handle, py_failure, attempt_duration_ms, py_config) + def propose_run_completion_transient_with_delay_override( + self, + handle: int, + failure: Failure, + attempt_duration_ms: int, + delay_override: timedelta | None, + max_retry_attempts_override: int | None, + max_retry_duration_override: timedelta | None, + ): + """ + Exit a side effect with a transient Error and override the retry policy with explicit parameters. + """ + py_failure = PyFailure(failure.code, failure.message, failure.stacktrace) + self.vm.propose_run_completion_failure_transient_with_delay_override( + handle, + py_failure, + attempt_duration_ms, + int(delay_override.total_seconds() * 1000) if delay_override else None, + max_retry_attempts_override, + int(max_retry_duration_override.total_seconds() * 1000) if max_retry_duration_override else None, + ) + def sys_end(self): """ This method is responsible for ending the system. diff --git a/src/lib.rs b/src/lib.rs index be3dda1..78b4dd4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -335,12 +335,15 @@ impl PyVM { self_.vm.notify_input_closed(); } - #[pyo3(signature = (error, stacktrace=None))] - fn notify_error(mut self_: PyRefMut<'_, Self>, error: String, stacktrace: Option) { + #[pyo3(signature = (error, stacktrace=None, delay_override_ms=None))] + fn notify_error(mut self_: PyRefMut<'_, Self>, error: String, stacktrace: Option, delay_override_ms: Option) { let mut error = Error::new(restate_sdk_shared_core::error::codes::INTERNAL, error); if let Some(desc) = stacktrace { error = error.with_stacktrace(desc); } + if let Some(delay) = delay_override_ms { + error = error.with_next_retry_delay_override(Duration::from_millis(delay)); + } CoreVM::notify_error(&mut self_.vm, error, None); } @@ -721,6 +724,37 @@ impl PyVM { .map_err(Into::into) } + fn propose_run_completion_failure_transient_with_delay_override( + mut self_: PyRefMut<'_, Self>, + handle: PyNotificationHandle, + value: PyFailure, + attempt_duration: u64, + delay_override_ms: Option, + max_retry_attempts_override: Option, + max_retry_duration_override_ms: Option, + ) -> Result<(), PyVMError> { + let retry_policy = if delay_override_ms.is_some() || max_retry_attempts_override.is_some() || max_retry_duration_override_ms.is_some() { + RetryPolicy::FixedDelay { + interval: delay_override_ms.map(Duration::from_millis), + max_attempts: max_retry_attempts_override, + max_duration: max_retry_duration_override_ms.map(Duration::from_millis), + } + } else { + RetryPolicy::Infinite + }; + self_ + .vm + .propose_run_completion( + handle.into(), + RunExitResult::RetryableFailure { + attempt_duration: Duration::from_millis(attempt_duration), + error: value.into(), + }, + retry_policy, + ) + .map_err(Into::into) + } + fn sys_write_output_success( mut self_: PyRefMut<'_, Self>, buffer: &Bound<'_, PyBytes>, diff --git a/tests/servercontext.py b/tests/servercontext.py index 308adf0..980d9e7 100644 --- a/tests/servercontext.py +++ b/tests/servercontext.py @@ -8,11 +8,15 @@ # directory of this repository or package, or at # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE # +from datetime import timedelta +from restate.exceptions import RetryableError from contextlib import asynccontextmanager import restate from restate import ( Context, + HttpError, + InvocationRetryPolicy, RunOptions, Service, TerminalError, @@ -78,6 +82,36 @@ async def greet(ctx: Context, name: str) -> str: await client.service_call(greet, arg="bob") +async def test_retryable_exception(): + greeter = Service("greeter") + attempts = 0 + + @greeter.handler( + invocation_retry_policy=InvocationRetryPolicy( + max_attempts=3, + # Something really long to trigger a test timeout. + # Default httpx client timeout is 5 seconds. + initial_interval=timedelta(hours=1), + ), + ) + async def greet(ctx: Context, name: str) -> str: + nonlocal attempts + print(f"Attempt {attempts}") + try: + if attempts == 0: + raise RetryableError("Simulated retryable error", retry_after=timedelta(milliseconds=100)) + else: + raise TerminalError("Simulated terminal error") + finally: + attempts += 1 + + async with simple_harness(greeter) as client: + with pytest.raises(HttpError): # Should be some sort of client error (not a timeout). + await client.service_call(greet, arg="bob") + + assert attempts == 2 + + async def test_promise_default_serde(): workflow = Workflow("test_workflow")