From a7aef086a736faa7359ecbc44ae64b1ac8ab1376 Mon Sep 17 00:00:00 2001 From: Mateusz Poliwczak Date: Mon, 4 May 2026 09:57:30 +0200 Subject: [PATCH] Move limit middlewares from splunklib.ai.hooks to splunklib.ai.limits Also changed the TokenLimitExceededException to accept an int, instead of a float. --- splunklib/ai/README.md | 2 +- splunklib/ai/base_agent.py | 2 +- splunklib/ai/hooks.py | 159 --------------- splunklib/ai/limits.py | 184 ++++++++++++++++++ tests/integration/ai/test_hooks.py | 10 +- .../integration/ai/test_structured_output.py | 2 +- tests/unit/ai/test_default_limits.py | 2 +- 7 files changed, 194 insertions(+), 167 deletions(-) create mode 100644 splunklib/ai/limits.py diff --git a/splunklib/ai/README.md b/splunklib/ai/README.md index 3e6cc909..651f2586 100644 --- a/splunklib/ai/README.md +++ b/splunklib/ai/README.md @@ -958,7 +958,7 @@ class. The default for that limit is suppressed automatically - the other defaul remain active: ```py -from splunklib.ai.hooks import ( +from splunklib.ai.limits import ( TokenLimitMiddleware, StepLimitMiddleware, TimeoutLimitMiddleware, diff --git a/splunklib/ai/base_agent.py b/splunklib/ai/base_agent.py index 76731973..3e9de535 100644 --- a/splunklib/ai/base_agent.py +++ b/splunklib/ai/base_agent.py @@ -21,7 +21,7 @@ from pydantic import BaseModel from splunklib.ai.conversation_store import ConversationStore -from splunklib.ai.hooks import ( +from splunklib.ai.limits import ( DEFAULT_STEP_LIMIT, DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT, DEFAULT_TIMEOUT_SECONDS, diff --git a/splunklib/ai/hooks.py b/splunklib/ai/hooks.py index a2f78158..f21849b4 100644 --- a/splunklib/ai/hooks.py +++ b/splunklib/ai/hooks.py @@ -1,6 +1,5 @@ import inspect from collections.abc import Awaitable, Callable -from time import monotonic from typing import Any, override from splunklib.ai.messages import AgentResponse @@ -12,44 +11,6 @@ ModelRequest, ModelResponse, ) -from splunklib.ai.structured_output import StructuredOutputGenerationException - -DEFAULT_TIMEOUT_SECONDS: float = 600.0 -DEFAULT_STEP_LIMIT: int = 100 -DEFAULT_TOKEN_LIMIT: int = 200_000 -DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT: int = 3 - - -class AgentStopException(Exception): - """Custom exception to indicate conversation stopping conditions.""" - - -class TokenLimitExceededException(AgentStopException): - """Raised by `Agent.invoke`, when token limit exceeds""" - - def __init__(self, token_limit: float) -> None: - super().__init__(f"Token limit of {token_limit} exceeded.") - - -class StepsLimitExceededException(AgentStopException): - """Raised by `Agent.invoke`, when steps limit exceeds""" - - def __init__(self, steps_limit: int) -> None: - super().__init__(f"Steps limit of {steps_limit} exceeded.") - - -class TimeoutExceededException(AgentStopException): - """Raised by `Agent.invoke`, when timeout exceeds""" - - def __init__(self, timeout_seconds: float) -> None: - super().__init__(f"Timed out after {timeout_seconds} seconds.") - - -class StructuredOutputRetryLimitExceededException(AgentStopException): - """Raised by `Agent.invoke`, when structured output retry limit exceeds""" - - def __init__(self, retry_count: int) -> None: - super().__init__(f"Structured output retry limit of {retry_count} exceeded") def before_model( @@ -132,123 +93,3 @@ async def agent_middleware( return handler_response return _Middleware() - - -class TokenLimitMiddleware(AgentMiddleware): - """Stops agent execution when the token count of messages passed to the model exceeds the given limit.""" - - _limit: int - - def __init__(self, limit: int) -> None: - self._limit = limit - - @override - async def model_middleware( - self, - request: ModelRequest, - handler: ModelMiddlewareHandler, - ) -> ModelResponse: - if request.state.token_count >= self._limit: - raise TokenLimitExceededException(token_limit=self._limit) - return await handler(request) - - -class StepLimitMiddleware(AgentMiddleware): - """Stops agent execution when the number of steps taken reaches the given limit.""" - - _limit: int - - def __init__(self, limit: int) -> None: - self._limit = limit - - @override - async def model_middleware( - self, - request: ModelRequest, - handler: ModelMiddlewareHandler, - ) -> ModelResponse: - if request.state.total_steps >= self._limit: - raise StepsLimitExceededException(steps_limit=self._limit) - return await handler(request) - - -class TimeoutLimitMiddleware(AgentMiddleware): - """Stops agent execution when wall-clock time within an invoke exceeds the given seconds. - - The deadline resets on every invoke call - it measures time from the start of - each invocation, not from agent construction. - - Do not share instances between agents. - """ - - _seconds: float - _deadline_per_thread_id: dict[str, float] - - def __init__(self, seconds: float) -> None: - self._seconds = seconds - self._deadline_per_thread_id = {} - - @override - async def agent_middleware( - self, - request: AgentRequest, - handler: AgentMiddlewareHandler, - ) -> AgentResponse[Any | None]: - try: - # Agent loop starting. - self._deadline_per_thread_id[request.thread_id] = ( - monotonic() + self._seconds - ) - return await handler(request) - finally: - del self._deadline_per_thread_id[request.thread_id] # don't leak memory - - @override - async def model_middleware( - self, - request: ModelRequest, - handler: ModelMiddlewareHandler, - ) -> ModelResponse: - if monotonic() >= self._deadline_per_thread_id[request.state.thread_id]: - raise TimeoutExceededException(timeout_seconds=self._seconds) - return await handler(request) - - -class StructuredOutputRetryLimitMiddleware(AgentMiddleware): - """Stops agent execution when the agent exceeds structured output - retry limit during a single agent loop invocation. Pass 0 to disable retries. - """ - - _limit: int - _retries_per_thread_id: dict[str, int] - - def __init__(self, limit: int) -> None: - self._limit = limit - self._retries_per_thread_id = {} - - @override - async def agent_middleware( - self, - request: AgentRequest, - handler: AgentMiddlewareHandler, - ) -> AgentResponse[Any | None]: - try: - # Agent loop starting. - self._retries_per_thread_id[request.thread_id] = 0 - return await handler(request) - finally: - del self._retries_per_thread_id[request.thread_id] # don't leak memory - - @override - async def model_middleware( - self, - request: ModelRequest, - handler: ModelMiddlewareHandler, - ) -> ModelResponse: - try: - return await handler(request) - except StructuredOutputGenerationException: - self._retries_per_thread_id[request.state.thread_id] += 1 - if self._retries_per_thread_id[request.state.thread_id] > self._limit: - raise StructuredOutputRetryLimitExceededException(self._limit) - raise # re-raise, to retry structured output generation diff --git a/splunklib/ai/limits.py b/splunklib/ai/limits.py new file mode 100644 index 00000000..fff534b3 --- /dev/null +++ b/splunklib/ai/limits.py @@ -0,0 +1,184 @@ +# Copyright © 2011-2026 Splunk, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"): you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from time import monotonic +from typing import Any, override + +from splunklib.ai.messages import AgentResponse +from splunklib.ai.middleware import ( + AgentMiddleware, + AgentMiddlewareHandler, + AgentRequest, + ModelMiddlewareHandler, + ModelRequest, + ModelResponse, +) +from splunklib.ai.structured_output import StructuredOutputGenerationException + +DEFAULT_TIMEOUT_SECONDS: float = 600.0 +DEFAULT_STEP_LIMIT: int = 100 +DEFAULT_TOKEN_LIMIT: int = 200_000 +DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT: int = 3 + + +class AgentStopException(Exception): + """Custom exception to indicate conversation stopping conditions.""" + + +class TokenLimitExceededException(AgentStopException): + """Raised by `Agent.invoke`, when token limit exceeds""" + + def __init__(self, token_limit: int) -> None: + super().__init__(f"Token limit of {token_limit} exceeded.") + + +class StepsLimitExceededException(AgentStopException): + """Raised by `Agent.invoke`, when steps limit exceeds""" + + def __init__(self, steps_limit: int) -> None: + super().__init__(f"Steps limit of {steps_limit} exceeded.") + + +class TimeoutExceededException(AgentStopException): + """Raised by `Agent.invoke`, when timeout exceeds""" + + def __init__(self, timeout_seconds: float) -> None: + super().__init__(f"Timed out after {timeout_seconds} seconds.") + + +class StructuredOutputRetryLimitExceededException(AgentStopException): + """Raised by `Agent.invoke`, when structured output retry limit exceeds""" + + def __init__(self, retry_count: int) -> None: + super().__init__(f"Structured output retry limit of {retry_count} exceeded") + + +class TokenLimitMiddleware(AgentMiddleware): + """Stops agent execution when the token count of messages passed to the model exceeds the given limit.""" + + _limit: int + + def __init__(self, limit: int) -> None: + self._limit = limit + + @override + async def model_middleware( + self, + request: ModelRequest, + handler: ModelMiddlewareHandler, + ) -> ModelResponse: + if request.state.token_count >= self._limit: + raise TokenLimitExceededException(token_limit=self._limit) + return await handler(request) + + +class StepLimitMiddleware(AgentMiddleware): + """Stops agent execution when the number of steps taken reaches the given limit.""" + + _limit: int + + def __init__(self, limit: int) -> None: + self._limit = limit + + @override + async def model_middleware( + self, + request: ModelRequest, + handler: ModelMiddlewareHandler, + ) -> ModelResponse: + if request.state.total_steps >= self._limit: + raise StepsLimitExceededException(steps_limit=self._limit) + return await handler(request) + + +class TimeoutLimitMiddleware(AgentMiddleware): + """Stops agent execution when wall-clock time within an invoke exceeds the given seconds. + + The deadline resets on every invoke call - it measures time from the start of + each invocation, not from agent construction. + + Do not share instances between agents. + """ + + _seconds: float + _deadline_per_thread_id: dict[str, float] + + def __init__(self, seconds: float) -> None: + self._seconds = seconds + self._deadline_per_thread_id = {} + + @override + async def agent_middleware( + self, + request: AgentRequest, + handler: AgentMiddlewareHandler, + ) -> AgentResponse[Any | None]: + try: + # Agent loop starting. + self._deadline_per_thread_id[request.thread_id] = ( + monotonic() + self._seconds + ) + return await handler(request) + finally: + del self._deadline_per_thread_id[request.thread_id] # don't leak memory + + @override + async def model_middleware( + self, + request: ModelRequest, + handler: ModelMiddlewareHandler, + ) -> ModelResponse: + if monotonic() >= self._deadline_per_thread_id[request.state.thread_id]: + raise TimeoutExceededException(timeout_seconds=self._seconds) + return await handler(request) + + +class StructuredOutputRetryLimitMiddleware(AgentMiddleware): + """Stops agent execution when the agent exceeds structured output + retry limit during a single agent loop invocation. Pass 0 to disable retires. + """ + + _limit: int + _retries_per_thread_id: dict[str, int] + + def __init__(self, limit: int) -> None: + self._limit = limit + self._retries_per_thread_id = {} + + @override + async def agent_middleware( + self, + request: AgentRequest, + handler: AgentMiddlewareHandler, + ) -> AgentResponse[Any | None]: + try: + # Agent loop starting. + self._retries_per_thread_id[request.thread_id] = 0 + return await handler(request) + finally: + del self._retries_per_thread_id[request.thread_id] # don't leak memory + + @override + async def model_middleware( + self, + request: ModelRequest, + handler: ModelMiddlewareHandler, + ) -> ModelResponse: + try: + return await handler(request) + except StructuredOutputGenerationException: + self._retries_per_thread_id[request.state.thread_id] += 1 + if self._retries_per_thread_id[request.state.thread_id] > self._limit: + raise StructuredOutputRetryLimitExceededException(self._limit) + raise # re-raise, to retry structured output generation diff --git a/tests/integration/ai/test_hooks.py b/tests/integration/ai/test_hooks.py index 7c63dfad..3da9274b 100644 --- a/tests/integration/ai/test_hooks.py +++ b/tests/integration/ai/test_hooks.py @@ -18,16 +18,18 @@ from splunklib.ai import Agent from splunklib.ai.conversation_store import InMemoryStore from splunklib.ai.hooks import ( + after_agent, + after_model, + before_agent, + before_model, +) +from splunklib.ai.limits import ( StepLimitMiddleware, StepsLimitExceededException, TimeoutExceededException, TimeoutLimitMiddleware, TokenLimitExceededException, TokenLimitMiddleware, - after_agent, - after_model, - before_agent, - before_model, ) from splunklib.ai.messages import AgentResponse, AIMessage, HumanMessage from splunklib.ai.middleware import ( diff --git a/tests/integration/ai/test_structured_output.py b/tests/integration/ai/test_structured_output.py index db3386b4..7bf91be4 100644 --- a/tests/integration/ai/test_structured_output.py +++ b/tests/integration/ai/test_structured_output.py @@ -21,7 +21,7 @@ from pydantic.dataclasses import dataclass from splunklib.ai import Agent -from splunklib.ai.hooks import ( +from splunklib.ai.limits import ( StructuredOutputRetryLimitExceededException, StructuredOutputRetryLimitMiddleware, ) diff --git a/tests/unit/ai/test_default_limits.py b/tests/unit/ai/test_default_limits.py index 89ecccee..66957f8d 100644 --- a/tests/unit/ai/test_default_limits.py +++ b/tests/unit/ai/test_default_limits.py @@ -16,7 +16,7 @@ from time import monotonic from splunklib.ai.agent import Agent -from splunklib.ai.hooks import ( +from splunklib.ai.limits import ( DEFAULT_STEP_LIMIT, DEFAULT_TIMEOUT_SECONDS, DEFAULT_TOKEN_LIMIT,