Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 107 additions & 49 deletions src/fastcs/controllers/base_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,21 @@
from collections import Counter
from collections.abc import Sequence
from copy import deepcopy
from typing import _GenericAlias, get_args, get_origin, get_type_hints # type: ignore
from typing import ( # type: ignore
TypeVar,
_GenericAlias, # type: ignore
get_args,
get_origin,
get_type_hints,
)

from fastcs.attributes import AnyAttributeIO, Attribute, AttrR, AttrW, HintedAttribute
from fastcs.logging import logger
from fastcs.methods import Command, Scan, UnboundCommand, UnboundScan
from fastcs.methods import Command, Method, Scan, UnboundCommand, UnboundScan
from fastcs.tracer import Tracer

T = TypeVar("T")


class BaseController(Tracer):
"""Base class for controllers
Expand Down Expand Up @@ -49,6 +57,7 @@ def __init__(
self.__scan_methods: dict[str, Scan] = {}

self.__hinted_attributes: dict[str, HintedAttribute] = {}
self.__hinted_methods: dict[str, type[Method]] = {}
self.__hinted_sub_controllers: dict[str, type[BaseController]] = {}
self._find_type_hints()

Expand Down Expand Up @@ -85,6 +94,9 @@ def _find_type_hints(self):
elif isinstance(hint, type) and issubclass(hint, BaseController):
self.__hinted_sub_controllers[name] = hint

elif isinstance(hint, type) and issubclass(hint, Method):
self.__hinted_methods[name] = hint

def _bind_attrs(self) -> None:
"""Search for Attributes and Methods to bind them to this instance.

Expand Down Expand Up @@ -166,47 +178,70 @@ def post_initialise(self):
self._connect_attribute_ios()

def _validate_type_hints(self):
"""Validate all `Attribute` and `Controller` type-hints were introspected"""
"""Validate all type-hints were introspected"""
for name in self.__hinted_attributes:
self._validate_hinted_attribute(name)

for name in self.__hinted_sub_controllers:
self._validate_hinted_controller(name)

for name in self.__hinted_methods:
self._validate_hinted_method(name)

for subcontroller in self.sub_controllers.values():
subcontroller._validate_type_hints() # noqa: SLF001

def _validate_hinted_member(self, name: str, expected_type: type[T]) -> T:
"""Validate that a hinted member exists on the controller"""
member = getattr(self, name, None)
if member is None or not isinstance(member, expected_type):
raise RuntimeError()
return member

def _validate_hinted_method(self, name: str):
"""Check that a `Method` with the given name exists on the controller"""
try:
method = self._validate_hinted_member(name, Method)
except RuntimeError:
Comment on lines +201 to +205
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate against the declared hinted method type, not base Method.

Line 204 validates hinted methods with Method, which can let a plain Method instance pass for hints like Command/Scan if assigned directly. Use self.__hinted_methods[name] so post-initialise validation enforces the actual hinted subtype.

🔧 Proposed fix
     def _validate_hinted_method(self, name: str):
         """Check that a `Method` with the given name exists on the controller"""
         try:
-            method = self._validate_hinted_member(name, Method)
+            expected_type = self.__hinted_methods[name]
+            method = self._validate_hinted_member(name, expected_type)
         except RuntimeError:
             raise RuntimeError(
                 f"Controller `{self.__class__.__name__}` failed to introspect "
                 f"hinted method `{name}` during initialisation"
             ) from None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/fastcs/controllers/base_controller.py` around lines 201 - 205, The
validation currently uses the base Method type in _validate_hinted_method which
lets a plain Method satisfy hints for subtypes like Command/Scan; change it to
look up the expected hinted type from the controller's mapping and validate
against that actual type: fetch hinted_type = self.__hinted_methods[name] (or
equivalent access to the stored hint), then call
self._validate_hinted_member(name, hinted_type) so post-initialise checks
enforce the declared subtype (e.g. Command/Scan) rather than the base Method.

raise RuntimeError(
f"Controller `{self.__class__.__name__}` failed to introspect "
f"hinted method `{name}` during initialisation"
) from None

logger.debug(
"Validated hinted method", name=name, controller=self, method=method
)

def _validate_hinted_attribute(self, name: str):
"""Check that an `Attribute` with the given name exists on the controller"""
attr = getattr(self, name, None)
if attr is None or not isinstance(attr, Attribute):
try:
attr = self._validate_hinted_member(name, Attribute)
except RuntimeError:
raise RuntimeError(
f"Controller `{self.__class__.__name__}` failed to introspect "
f"hinted attribute `{name}` during initialisation"
)
else:
logger.debug(
"Validated hinted attribute",
name=name,
controller=self,
attribute=attr,
)
) from None

logger.debug(
"Validated hinted attribute", name=name, controller=self, attribute=attr
)

def _validate_hinted_controller(self, name: str):
"""Check that a sub controller with the given name exists on the controller"""
controller = getattr(self, name, None)
if controller is None or not isinstance(controller, BaseController):
try:
controller = self._validate_hinted_member(name, BaseController)
except RuntimeError:
raise RuntimeError(
f"Controller `{self.__class__.__name__}` failed to introspect "
f"hinted controller `{name}` during initialisation"
)
else:
logger.debug(
"Validated hinted sub controller",
name=name,
controller=self,
sub_controller=controller,
)
) from None

logger.debug(
"Validated hinted sub controller",
name=name,
controller=self,
sub_controller=controller,
)

def _connect_attribute_ios(self) -> None:
"""Connect ``Attribute`` callbacks to ``AttributeIO``s"""
Expand Down Expand Up @@ -243,14 +278,27 @@ def set_path(self, path: list[str]):
for attribute in self.__attributes.values():
attribute.set_path(path)

def _check_for_name_clash(self, name: str):
namespaces = {
"attribute": self.__attributes,
"sub controller": self.__sub_controllers,
"scan method": self.__scan_methods,
"command method": self.__command_methods,
}

for kind, namespace in namespaces.items():
if name in namespace:
raise ValueError(
f"Controller {self} has existing {kind} {name}: {namespace[name]}"
)

def add_attribute(self, name, attr: Attribute):
if name in self.__attributes:
raise ValueError(
f"Cannot add attribute {attr}. "
f"Controller {self} has has existing attribute {name}: "
f"{self.__attributes[name]}"
)
elif name in self.__hinted_attributes:
try:
self._check_for_name_clash(name)
except ValueError as exc:
raise ValueError(f"Cannot add attribute {attr}.") from exc

if name in self.__hinted_attributes:
hint = self.__hinted_attributes[name]
if not isinstance(attr, hint.attr_type):
raise RuntimeError(
Expand All @@ -265,12 +313,6 @@ def add_attribute(self, name, attr: Attribute):
f"Expected '{hint.dtype.__name__}', "
f"got '{attr.datatype.dtype.__name__}'."
)
elif name in self.__sub_controllers.keys():
raise ValueError(
f"Cannot add attribute {attr}. "
f"Controller {self} has existing sub controller {name}: "
f"{self.__sub_controllers[name]}"
)

attr.set_name(name)
attr.set_path(self.path)
Expand All @@ -282,13 +324,12 @@ def attributes(self) -> dict[str, Attribute]:
return self.__attributes

def add_sub_controller(self, name: str, sub_controller: BaseController):
if name in self.__sub_controllers.keys():
raise ValueError(
f"Cannot add sub controller {sub_controller}. "
f"Controller {self} has existing sub controller {name}: "
f"{self.__sub_controllers[name]}"
)
elif name in self.__hinted_sub_controllers:
try:
self._check_for_name_clash(name)
except ValueError as exc:
raise ValueError(f"Cannot add sub controller {sub_controller}.") from exc

if name in self.__hinted_sub_controllers:
hint = self.__hinted_sub_controllers[name]
if not isinstance(sub_controller, hint):
raise RuntimeError(
Expand All @@ -297,12 +338,6 @@ def add_sub_controller(self, name: str, sub_controller: BaseController):
f"Expected '{hint.__name__}' got "
f"'{sub_controller.__class__.__name__}'."
)
elif name in self.__attributes:
raise ValueError(
f"Cannot add sub controller {sub_controller}. "
f"Controller {self} has existing attribute {name}: "
f"{self.__attributes[name]}"
)

sub_controller.set_path(self.path + [name])
self.__sub_controllers[name] = sub_controller
Expand All @@ -315,7 +350,24 @@ def add_sub_controller(self, name: str, sub_controller: BaseController):
def sub_controllers(self) -> dict[str, BaseController]:
return self.__sub_controllers

def _validate_method(self, name: str, method: Method):
if name in self.__hinted_methods:
hint = self.__hinted_methods[name]
if not isinstance(method, hint):
raise RuntimeError(
f"Controller '{self.__class__.__name__}' introspection of "
f"hinted method '{name}' does not match defined type. "
f"Expected '{hint.__name__}' got "
f"'{method.__class__.__name__}'."
)

def add_command(self, name: str, command: Command):
try:
self._check_for_name_clash(name)
self._validate_method(name, command)
except (ValueError, RuntimeError) as exc:
raise exc.__class__(f"Cannot add command method {command}.") from exc

self.__command_methods[name] = command
super().__setattr__(name, command)

Expand All @@ -324,6 +376,12 @@ def command_methods(self) -> dict[str, Command]:
return self.__command_methods

def add_scan(self, name: str, scan: Scan):
try:
self._check_for_name_clash(name)
self._validate_method(name, scan)
except (ValueError, RuntimeError) as exc:
raise exc.__class__(f"Cannot add scan method {scan}.") from exc

self.__scan_methods[name] = scan
super().__setattr__(name, scan)

Expand Down
1 change: 1 addition & 0 deletions src/fastcs/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .command import CommandCallback as CommandCallback
from .command import UnboundCommand as UnboundCommand
from .command import command as command
from .method import Method as Method
from .scan import Scan as Scan
from .scan import ScanCallback as ScanCallback
from .scan import UnboundScan as UnboundScan
Expand Down
63 changes: 44 additions & 19 deletions tests/test_controllers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fastcs.attributes import AttrR, AttrRW
from fastcs.controllers import Controller, ControllerVector
from fastcs.datatypes import Enum, Float, Int
from fastcs.methods import Command, Scan


def test_controller_nesting():
Expand All @@ -20,7 +21,7 @@ def test_controller_nesting():
assert controller.sub_controllers == {"a": sub_controller}
assert sub_controller.sub_controllers == {"b": sub_sub_controller}

with pytest.raises(ValueError, match=r"existing sub controller"):
with pytest.raises(ValueError, match=r"Cannot add sub controller"):
controller.a = Controller()

with pytest.raises(ValueError, match=r"already registered"):
Expand Down Expand Up @@ -76,33 +77,39 @@ def test_attribute_parsing():
}


def test_conflicting_attributes_and_controllers():
async def noop() -> None:
pass


@pytest.mark.parametrize(
"member_name, member_value, expected_error",
[
("attr", AttrR(Float()), r"Cannot add attribute"),
("attr", Controller(), r"Cannot add sub controller"),
("attr", Command(noop), r"Cannot add command"),
("sub_controller", AttrR(Int()), r"Cannot add attribute"),
("sub_controller", Controller(), r"Cannot add sub controller"),
("sub_controller", Command(noop), r"Cannot add command"),
("cmd", AttrR(Int()), r"Cannot add attribute"),
("cmd", Controller(), r"Cannot add sub controller"),
("cmd", Command(noop), r"Cannot add command"),
],
)
def test_conflicting_attributes_and_controllers_and_commands(
member_name, member_value, expected_error
):
class ConflictingController(Controller):
attr = AttrR(Int())
cmd = Command(noop)

def __init__(self):
super().__init__()
self.sub_controller = Controller()

controller = ConflictingController()

with pytest.raises(ValueError, match=r"Cannot add attribute .* existing attribute"):
controller.attr = AttrR(Float()) # pyright: ignore[reportAttributeAccessIssue]

with pytest.raises(
ValueError, match=r"Cannot add sub controller .* existing attribute"
):
controller.attr = Controller() # pyright: ignore[reportAttributeAccessIssue]

with pytest.raises(
ValueError, match=r"Cannot add sub controller .* existing sub controller"
):
controller.sub_controller = Controller()

with pytest.raises(
ValueError, match=r"Cannot add attribute .* existing sub controller"
):
controller.sub_controller = AttrR(Int()) # pyright: ignore[reportAttributeAccessIssue]
with pytest.raises(ValueError, match=expected_error):
setattr(controller, member_name, member_value)


def test_controller_raises_error_if_passed_numeric_sub_controller_name():
Expand Down Expand Up @@ -203,3 +210,21 @@ class HintedController(Controller):

controller.add_sub_controller("child", SomeSubController())
controller._validate_type_hints()


@pytest.mark.asyncio
async def test_method_hint_validation():
class HintedController(Controller):
method: Scan

controller = HintedController()

with pytest.raises(RuntimeError, match="failed to introspect hinted method"):
controller._validate_type_hints()

with pytest.raises(RuntimeError, match="Cannot add command method"):
controller.add_command("method", Command(noop))

controller.add_scan("method", Scan(fn=noop, period=0.1))

controller._validate_type_hints()