Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions python/restate/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
"""This module contains the restate exceptions"""

# pylint: disable=C0301
from typing import Optional

from datetime import timedelta


class TerminalError(Exception):
Expand All @@ -22,6 +25,23 @@ def __init__(self, message: str, status_code: int = 500) -> None:
self.status_code = status_code


class RetryableError(Exception):
"""
This exception is thrown to indicate that Restate should retry with an explicit delay.

Args:
message: The error message.
status_code: The HTTP status code to return for this error (default: 500).
retry_after: The delay after which Restate should retry the invocation.
"""

def __init__(self, message: str, status_code: int = 500, retry_after: Optional[timedelta] = None) -> None:
super().__init__(message)
self.message = message
self.status_code = status_code
self.retry_after = retry_after


class SdkInternalBaseException(BaseException):
"""This exception is internal, and you should not catch it.
If you need to distinguish with other exceptions, use is_internal_exception."""
Expand Down
24 changes: 23 additions & 1 deletion python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@
RunOptions,
P,
)
from restate.exceptions import TerminalError, SdkInternalBaseException, SdkInternalException, SuspendedException
from restate.exceptions import (
TerminalError,
SdkInternalBaseException,
SdkInternalException,
SuspendedException,
RetryableError,
)
from restate.handler import Handler, handler_from_callable, invoke_handler
from restate.serde import BytesSerde, DefaultSerde, Serde
from restate.server_types import ReceiveChannel, Send
Expand Down Expand Up @@ -404,6 +410,10 @@ async def enter(self):
restate_context_is_replaying.set(False)
self.vm.sys_write_output_failure(failure)
self.vm.sys_end()
except RetryableError as r:
stacktrace = "".join(traceback.format_exception(r))
restate_context_is_replaying.set(False)
self.vm.notify_error(r.message, stacktrace, r.retry_after)
# pylint: disable=W0718
except asyncio.CancelledError:
pass
Expand Down Expand Up @@ -674,6 +684,18 @@ async def create_run_coroutine(
except TerminalError as t:
failure = Failure(code=t.status_code, message=t.message)
self.vm.propose_run_completion_failure(handle, failure)
except RetryableError as r:
failure = Failure(code=r.status_code, message=r.message)
end = time.time()
attempt_duration = int((end - start) * 1000)
self.vm.propose_run_completion_transient_with_delay_override(
handle,
failure,
attempt_duration_ms=attempt_duration,
delay_override=r.retry_after,
max_retry_attempts_override=max_attempts,
max_retry_duration_override=max_duration,
)
except asyncio.CancelledError as e:
raise e from None
except SdkInternalBaseException as e:
Expand Down
31 changes: 29 additions & 2 deletions python/restate/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
"""
wrap the restate._internal.PyVM class
"""

# pylint: disable=E1101,R0917
# pylint: disable=too-many-arguments
# pylint: disable=too-few-public-methods
from typing import Optional
from datetime import timedelta

from dataclasses import dataclass
import typing
Expand Down Expand Up @@ -173,9 +176,11 @@ def notify_input_closed(self):
"""Notify the virtual machine that the input has been closed."""
self.vm.notify_input_closed()

def notify_error(self, error: str, stacktrace: str):
def notify_error(self, error: str, stacktrace: str, delay_override: Optional[timedelta] = None):
"""Notify the virtual machine of an error."""
self.vm.notify_error(error, stacktrace)
self.vm.notify_error(
error, stacktrace, int(delay_override.total_seconds() * 1000) if delay_override is not None else None
)

def take_output(self) -> typing.Optional[bytes]:
"""Take the output from the virtual machine."""
Expand Down Expand Up @@ -444,6 +449,28 @@ def propose_run_completion_transient(
)
self.vm.propose_run_completion_failure_transient(handle, py_failure, attempt_duration_ms, py_config)

def propose_run_completion_transient_with_delay_override(
self,
handle: int,
failure: Failure,
attempt_duration_ms: int,
delay_override: timedelta | None,
max_retry_attempts_override: int | None,
max_retry_duration_override: timedelta | None,
):
"""
Exit a side effect with a transient Error and override the retry policy with explicit parameters.
"""
py_failure = PyFailure(failure.code, failure.message, failure.stacktrace)
self.vm.propose_run_completion_failure_transient_with_delay_override(
handle,
py_failure,
attempt_duration_ms,
int(delay_override.total_seconds() * 1000) if delay_override else None,
max_retry_attempts_override,
int(max_retry_duration_override.total_seconds() * 1000) if max_retry_duration_override else None,
)

def sys_end(self):
"""
This method is responsible for ending the system.
Expand Down
38 changes: 36 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,15 @@ impl PyVM {
self_.vm.notify_input_closed();
}

#[pyo3(signature = (error, stacktrace=None))]
fn notify_error(mut self_: PyRefMut<'_, Self>, error: String, stacktrace: Option<String>) {
#[pyo3(signature = (error, stacktrace=None, delay_override_ms=None))]
fn notify_error(mut self_: PyRefMut<'_, Self>, error: String, stacktrace: Option<String>, delay_override_ms: Option<u64>) {
let mut error = Error::new(restate_sdk_shared_core::error::codes::INTERNAL, error);
if let Some(desc) = stacktrace {
error = error.with_stacktrace(desc);
}
if let Some(delay) = delay_override_ms {
error = error.with_next_retry_delay_override(Duration::from_millis(delay));
}
CoreVM::notify_error(&mut self_.vm, error, None);
}

Expand Down Expand Up @@ -721,6 +724,37 @@ impl PyVM {
.map_err(Into::into)
}

fn propose_run_completion_failure_transient_with_delay_override(
mut self_: PyRefMut<'_, Self>,
handle: PyNotificationHandle,
value: PyFailure,
attempt_duration: u64,
delay_override_ms: Option<u64>,
max_retry_attempts_override: Option<u32>,
max_retry_duration_override_ms: Option<u64>,
) -> Result<(), PyVMError> {
let retry_policy = if delay_override_ms.is_some() || max_retry_attempts_override.is_some() || max_retry_duration_override_ms.is_some() {
RetryPolicy::FixedDelay {
interval: delay_override_ms.map(Duration::from_millis),
max_attempts: max_retry_attempts_override,
max_duration: max_retry_duration_override_ms.map(Duration::from_millis),
}
} else {
RetryPolicy::Infinite
};
self_
.vm
.propose_run_completion(
handle.into(),
RunExitResult::RetryableFailure {
attempt_duration: Duration::from_millis(attempt_duration),
error: value.into(),
},
retry_policy,
)
.map_err(Into::into)
}

fn sys_write_output_success(
mut self_: PyRefMut<'_, Self>,
buffer: &Bound<'_, PyBytes>,
Expand Down
34 changes: 34 additions & 0 deletions tests/servercontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
from datetime import timedelta
from restate.exceptions import RetryableError

from contextlib import asynccontextmanager
import restate
from restate import (
Context,
HttpError,
InvocationRetryPolicy,
RunOptions,
Service,
TerminalError,
Expand Down Expand Up @@ -78,6 +82,36 @@ async def greet(ctx: Context, name: str) -> str:
await client.service_call(greet, arg="bob")


async def test_retryable_exception():
greeter = Service("greeter")
attempts = 0

@greeter.handler(
invocation_retry_policy=InvocationRetryPolicy(
max_attempts=3,
# Something really long to trigger a test timeout.
# Default httpx client timeout is 5 seconds.
initial_interval=timedelta(hours=1),
),
)
async def greet(ctx: Context, name: str) -> str:
nonlocal attempts
print(f"Attempt {attempts}")
try:
if attempts == 0:
raise RetryableError("Simulated retryable error", retry_after=timedelta(milliseconds=100))
else:
raise TerminalError("Simulated terminal error")
finally:
attempts += 1

async with simple_harness(greeter) as client:
with pytest.raises(HttpError): # Should be some sort of client error (not a timeout).
await client.service_call(greet, arg="bob")

assert attempts == 2


async def test_promise_default_serde():
workflow = Workflow("test_workflow")

Expand Down
Loading