diff --git a/custom_components/pyscript/decorator.py b/custom_components/pyscript/decorator.py index 4451889..6c92c34 100644 --- a/custom_components/pyscript/decorator.py +++ b/custom_components/pyscript/decorator.py @@ -246,9 +246,12 @@ async def _call(self, data: DispatchData) -> None: # Store HASS Context for this Task Function.store_hass_context(data.hass_context) - result = await data.call_ast_ctx.call_func(self.eval_func, None, **data.func_args) - for result_handler_dec in result_handlers: - await result_handler_dec.handle_call_result(data, result) + try: + result = await data.call_ast_ctx.call_func(self.eval_func, None, **data.func_args) + for result_handler_dec in result_handlers: + await result_handler_dec.handle_call_result(data, result) + except Exception as e: + await self.handle_exception(e) async def dispatch(self, data: DispatchData) -> None: """Handle a trigger dispatch: run guards, create a context, and invoke the function.""" diff --git a/custom_components/pyscript/decorators/webhook.py b/custom_components/pyscript/decorators/webhook.py index d0d449c..612476c 100644 --- a/custom_components/pyscript/decorators/webhook.py +++ b/custom_components/pyscript/decorators/webhook.py @@ -1,6 +1,9 @@ """Webhook decorator.""" +from __future__ import annotations + import logging +from typing import ClassVar from aiohttp import hdrs import voluptuous as vol @@ -36,6 +39,8 @@ class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsD local_only: bool methods: set[str] + webhook_id2triggers: ClassVar[dict[str, set[WebhookTriggerDecorator]]] = {} + async def validate(self): """Validate the webhook trigger configuration.""" await super().validate() @@ -44,7 +49,8 @@ async def validate(self): if len(self.args) == 2: self.create_expression(self.args[1]) - async def _handler(self, hass, webhook_id, request): + @staticmethod + async def _handler(_hass, webhook_id, request): func_args = { "trigger_type": "webhook", "webhook_id": webhook_id, @@ -57,28 +63,50 @@ async def _handler(self, hass, webhook_id, request): payload_multidict = await request.post() func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()} - if self.has_expression(): - if not await self.check_expression_vars(func_args): - return - - await self.dispatch(DispatchData(func_args)) + for trigger in WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id, set()).copy(): + trigger_args = func_args.copy() + if trigger.has_expression(): + if not await trigger.check_expression_vars(trigger_args): + continue + await trigger.dispatch(DispatchData(trigger_args)) + + @staticmethod + def _add_trigger(trigger: WebhookTriggerDecorator) -> None: + webhook_id = trigger.webhook_id + if webhook_id not in WebhookTriggerDecorator.webhook_id2triggers: + webhook.async_register( + trigger.dm.hass, + "pyscript", # DOMAIN + "pyscript", # NAME + webhook_id, + WebhookTriggerDecorator._handler, + local_only=trigger.local_only, + allowed_methods=trigger.methods, + ) + WebhookTriggerDecorator.webhook_id2triggers[webhook_id] = set() + + WebhookTriggerDecorator.webhook_id2triggers[webhook_id].add(trigger) + + @staticmethod + def _remove_trigger(trigger: WebhookTriggerDecorator) -> None: + webhook_id = trigger.webhook_id + triggers = WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id) + if not triggers: + return + + triggers.discard(trigger) + if len(triggers) == 0: + webhook.async_unregister(trigger.dm.hass, webhook_id) + del WebhookTriggerDecorator.webhook_id2triggers[webhook_id] async def start(self): """Start the webhook trigger.""" await super().start() - webhook.async_register( - self.dm.hass, - "pyscript", # DOMAIN - "pyscript", # NAME - self.webhook_id, - self._handler, - local_only=self.local_only, - allowed_methods=self.methods, - ) + self._add_trigger(self) _LOGGER.debug("webhook trigger %s listening on id %s", self.dm.name, self.webhook_id) async def stop(self): """Stop the webhook trigger.""" await super().stop() - webhook.async_unregister(self.dm.hass, self.webhook_id) + self._remove_trigger(self) diff --git a/tests/test_decorator_manager.py b/tests/test_decorator_manager.py index 483d3a1..45c6f89 100644 --- a/tests/test_decorator_manager.py +++ b/tests/test_decorator_manager.py @@ -271,14 +271,17 @@ def get_name(self) -> str: class DummyCallAstCtx: """Minimal action AstEval stub for manager call tests.""" - def __init__(self, result: object) -> None: + def __init__(self, result: object = None, exc: Exception | None = None) -> None: """Initialize the dummy action context.""" self.result = result + self.exc = exc self.calls: list[tuple[object, object, dict]] = [] async def call_func(self, func: object, func_name: object, **kwargs: object) -> object: - """Record the function call and return the configured result.""" + """Record the function call and return or raise the configured result.""" self.calls.append((func, func_name, kwargs)) + if self.exc is not None: + raise self.exc return self.result @@ -578,6 +581,24 @@ def event_listener(event): store_hass_context.assert_called_once_with(hass_context) +@pytest.mark.asyncio +async def test_function_decorator_manager_logs_call_exception(hass): + """Failed decorated function calls should be routed through the manager.""" + DecoratorManager.hass = hass + ast_ctx = DummyAstCtx() + manager = FunctionDecoratorManager(ast_ctx, DummyEvalFuncVar()) + call_ast_ctx = DummyCallAstCtx(exc=RuntimeError("decorated call failed")) + + await call_function_manager( + manager, + make_dispatch_data({"arg1": 1}, call_ast_ctx=call_ast_ctx, hass_context=Context(id="call-parent")), + ) + + assert call_ast_ctx.calls == [(manager.eval_func, None, {"arg1": 1})] + assert len(ast_ctx.logged_exceptions) == 1 + assert str(ast_ctx.logged_exceptions[0]) == "decorated call failed" + + def test_decorator_registry_register_requires_name(): """Registry should reject decorators without a declared name."""