Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/restate/ext/adk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .session import RestateSessionService
from .plugin import RestatePlugin
from .summarizer import RestateEventSummarizer
from restate import ObjectContext, Context
from restate.extensions import current_context

Expand All @@ -35,6 +36,7 @@ def restate_context() -> Context:
__all__ = [
"RestateSessionService",
"RestatePlugin",
"RestateEventSummarizer",
"restate_object_context",
"restate_context",
]
5 changes: 4 additions & 1 deletion python/restate/ext/adk/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,13 @@ async def append_event(self, session: Session, event: Event) -> Event:
"""Appends an event to a session object."""
if event.partial:
return event
# For now, we also store temp state
event = self._trim_temp_delta_state(event)
self._update_session_state(session, event)
session.events.append(event)
# Compaction runs after after_run_callback (which flushes the session),
# so compaction events must be flushed explicitly here.
if event.actions and event.actions.compaction:
await self.flush_session_state(session)
return event

async def flush_session_state(self, session: Session):
Expand Down
95 changes: 95 additions & 0 deletions python/restate/ext/adk/summarizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""
Restate-aware event summarizer for ADK compaction.

Wraps the LlmEventSummarizer so the summarization call is journaled
through ctx.run, making it deterministic on replay.
"""

import restate

from datetime import timedelta
from typing import Optional

from google.adk.apps.base_events_summarizer import BaseEventsSummarizer
from google.adk.apps.llm_event_summarizer import LlmEventSummarizer
from google.adk.events.event import Event
from google.adk.models.base_llm import BaseLlm

from restate.extensions import current_context


class RestateEventSummarizer(BaseEventsSummarizer):
"""Event summarizer that journals the summarization call through Restate ctx.run.

Wraps any BaseEventsSummarizer in ctx.run_typed so the result is persisted
in the Restate journal and replayed deterministically.

Use the factory methods to create instances:
- ``RestateEventSummarizer.from_llm(llm)`` for the default LlmEventSummarizer
- ``RestateEventSummarizer.from_summarizer(summarizer)`` for a custom summarizer
"""

def __init__(
self,
inner: BaseEventsSummarizer,
max_retries: int = 10,
):
self._inner = inner
self._max_retries = max_retries

@staticmethod
def from_llm(
llm: BaseLlm,
prompt_template: Optional[str] = None,
max_retries: int = 10,
) -> "RestateEventSummarizer":
"""Create a RestateEventSummarizer using the default LlmEventSummarizer."""
return RestateEventSummarizer(
LlmEventSummarizer(llm=llm, prompt_template=prompt_template),
max_retries=max_retries,
)

@staticmethod
def from_summarizer(
summarizer: BaseEventsSummarizer,
max_retries: int = 10,
) -> "RestateEventSummarizer":
"""Create a RestateEventSummarizer wrapping a custom summarizer."""
return RestateEventSummarizer(summarizer, max_retries=max_retries)

async def maybe_summarize_events(
self, *, events: list[Event]
) -> Optional[Event]:
if not events:
return None

ctx = current_context()
if ctx is None:
raise RuntimeError(
"No Restate context found. "
"RestateEventSummarizer must be used from within a Restate handler."
)

inner = self._inner

async def call_inner() -> Optional[Event]:
return await inner.maybe_summarize_events(events=events)

return await ctx.run_typed(
"compaction LLM call",
call_inner,
restate.RunOptions(
max_attempts=self._max_retries,
initial_retry_interval=timedelta(seconds=1),
),
)
Loading