Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions custom_components/pyscript/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,12 @@ async def _call(self, data: DispatchData) -> None:
# Store HASS Context for this Task
Function.store_hass_context(data.hass_context)

result = await data.call_ast_ctx.call_func(self.eval_func, None, **data.func_args)
for result_handler_dec in result_handlers:
await result_handler_dec.handle_call_result(data, result)
try:
result = await data.call_ast_ctx.call_func(self.eval_func, None, **data.func_args)
for result_handler_dec in result_handlers:
await result_handler_dec.handle_call_result(data, result)
except Exception as e:
await self.handle_exception(e)

async def dispatch(self, data: DispatchData) -> None:
"""Handle a trigger dispatch: run guards, create a context, and invoke the function."""
Expand Down
60 changes: 44 additions & 16 deletions custom_components/pyscript/decorators/webhook.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Webhook decorator."""

from __future__ import annotations

import logging
from typing import ClassVar

from aiohttp import hdrs
import voluptuous as vol
Expand Down Expand Up @@ -36,6 +39,8 @@ class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsD
local_only: bool
methods: set[str]

webhook_id2triggers: ClassVar[dict[str, set[WebhookTriggerDecorator]]] = {}

async def validate(self):
"""Validate the webhook trigger configuration."""
await super().validate()
Expand All @@ -44,7 +49,8 @@ async def validate(self):
if len(self.args) == 2:
self.create_expression(self.args[1])

async def _handler(self, hass, webhook_id, request):
@staticmethod
async def _handler(_hass, webhook_id, request):
func_args = {
"trigger_type": "webhook",
"webhook_id": webhook_id,
Expand All @@ -57,28 +63,50 @@ async def _handler(self, hass, webhook_id, request):
payload_multidict = await request.post()
func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()}

if self.has_expression():
if not await self.check_expression_vars(func_args):
return

await self.dispatch(DispatchData(func_args))
for trigger in WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id, set()).copy():
trigger_args = func_args.copy()
if trigger.has_expression():
if not await trigger.check_expression_vars(trigger_args):
continue
await trigger.dispatch(DispatchData(trigger_args))

@staticmethod
def _add_trigger(trigger: WebhookTriggerDecorator) -> None:
webhook_id = trigger.webhook_id
if webhook_id not in WebhookTriggerDecorator.webhook_id2triggers:
webhook.async_register(
trigger.dm.hass,
"pyscript", # DOMAIN
"pyscript", # NAME
webhook_id,
WebhookTriggerDecorator._handler,
local_only=trigger.local_only,
allowed_methods=trigger.methods,
)
WebhookTriggerDecorator.webhook_id2triggers[webhook_id] = set()

WebhookTriggerDecorator.webhook_id2triggers[webhook_id].add(trigger)

@staticmethod
def _remove_trigger(trigger: WebhookTriggerDecorator) -> None:
webhook_id = trigger.webhook_id
triggers = WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id)
if not triggers:
return

triggers.discard(trigger)
if len(triggers) == 0:
webhook.async_unregister(trigger.dm.hass, webhook_id)
del WebhookTriggerDecorator.webhook_id2triggers[webhook_id]

async def start(self):
"""Start the webhook trigger."""
await super().start()
webhook.async_register(
self.dm.hass,
"pyscript", # DOMAIN
"pyscript", # NAME
self.webhook_id,
self._handler,
local_only=self.local_only,
allowed_methods=self.methods,
)
self._add_trigger(self)

_LOGGER.debug("webhook trigger %s listening on id %s", self.dm.name, self.webhook_id)

async def stop(self):
"""Stop the webhook trigger."""
await super().stop()
webhook.async_unregister(self.dm.hass, self.webhook_id)
self._remove_trigger(self)
25 changes: 23 additions & 2 deletions tests/test_decorator_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,14 +271,17 @@ def get_name(self) -> str:
class DummyCallAstCtx:
"""Minimal action AstEval stub for manager call tests."""

def __init__(self, result: object) -> None:
def __init__(self, result: object = None, exc: Exception | None = None) -> None:
"""Initialize the dummy action context."""
self.result = result
self.exc = exc
self.calls: list[tuple[object, object, dict]] = []

async def call_func(self, func: object, func_name: object, **kwargs: object) -> object:
"""Record the function call and return the configured result."""
"""Record the function call and return or raise the configured result."""
self.calls.append((func, func_name, kwargs))
if self.exc is not None:
raise self.exc
return self.result


Expand Down Expand Up @@ -578,6 +581,24 @@ def event_listener(event):
store_hass_context.assert_called_once_with(hass_context)


@pytest.mark.asyncio
async def test_function_decorator_manager_logs_call_exception(hass):
"""Failed decorated function calls should be routed through the manager."""
DecoratorManager.hass = hass
ast_ctx = DummyAstCtx()
manager = FunctionDecoratorManager(ast_ctx, DummyEvalFuncVar())
call_ast_ctx = DummyCallAstCtx(exc=RuntimeError("decorated call failed"))

await call_function_manager(
manager,
make_dispatch_data({"arg1": 1}, call_ast_ctx=call_ast_ctx, hass_context=Context(id="call-parent")),
)

assert call_ast_ctx.calls == [(manager.eval_func, None, {"arg1": 1})]
assert len(ast_ctx.logged_exceptions) == 1
assert str(ast_ctx.logged_exceptions[0]) == "decorated call failed"


def test_decorator_registry_register_requires_name():
"""Registry should reject decorators without a declared name."""

Expand Down
Loading