diff --git a/bionetgen/modelapi/model.py b/bionetgen/modelapi/model.py index afe15e00..2b08a018 100644 --- a/bionetgen/modelapi/model.py +++ b/bionetgen/modelapi/model.py @@ -175,11 +175,7 @@ def add_block(self, block): Adds the given block object to the model, uses the name of the block object to determine what block it is """ - bname = block.name.replace(" ", "_") - # TODO: fix this exception - if bname == "reaction_rules": - bname = "rules" - block_adder = getattr(self, "add_{}_block".format(bname)) + block_adder = self._resolve_block_adder(block.name) block_adder(block) def add_empty_block(self, block_name): @@ -187,13 +183,39 @@ def add_empty_block(self, block_name): Makes an empty block object from a given block name and adds it to the model object. """ - bname = block_name.replace(" ", "_") - # TODO: fix this exception - if bname == "reaction_rules": - bname = "rules" - block_adder = getattr(self, "add_{}_block".format(bname)) + block_adder = self._resolve_block_adder(block_name) block_adder() + def _resolve_block_adder(self, block_name): + """ + Resolve supported block names to block adders. + + Block names are normalized by replacing spaces with underscores, and + the historical ``reaction_rules`` alias continues to map to ``rules``. + """ + normalized_name = block_name.replace(" ", "_") + block_adders = { + "parameters": self.add_parameters_block, + "compartments": self.add_compartments_block, + "molecule_types": self.add_molecule_types_block, + "species": self.add_species_block, + "observables": self.add_observables_block, + "functions": self.add_functions_block, + "energy_patterns": self.add_energy_patterns_block, + "population_maps": self.add_population_maps_block, + "rules": self.add_rules_block, + "reaction_rules": self.add_rules_block, + "protocol": self.add_protocol_block, + "actions": self.add_actions_block, + } + if normalized_name not in block_adders: + supported_names = ", ".join(block_adders) + raise ValueError( + f"Unsupported block name '{block_name}'. " + f"Supported block names: {supported_names}" + ) + return block_adders[normalized_name] + def add_parameters_block(self, block=None): """ Adds a parameters block to the model object. diff --git a/bionetgen/network/network.py b/bionetgen/network/network.py index 74f375b4..616fd6eb 100644 --- a/bionetgen/network/network.py +++ b/bionetgen/network/network.py @@ -109,17 +109,35 @@ def __iter__(self): return active_ordered_blocks.__iter__() def add_block(self, block): - bname = block.name.replace(" ", "_") - # TODO: fix this exception - block_adder = getattr(self, "add_{}_block".format(bname)) + block_adder = self._resolve_block_adder(block.name) block_adder(block) def add_empty_block(self, block_name): - bname = block_name.replace(" ", "_") - # TODO: fix this exception - block_adder = getattr(self, "add_{}_block".format(bname)) + block_adder = self._resolve_block_adder(block_name) block_adder() + def _resolve_block_adder(self, block_name): + """ + Resolve supported block names to block adders. + + Block names are normalized by replacing spaces with underscores before + dispatch so callers can use parser-style or attribute-style names. + """ + normalized_name = block_name.replace(" ", "_") + block_adders = { + "parameters": self.add_parameters_block, + "species": self.add_species_block, + "reactions": self.add_reactions_block, + "groups": self.add_groups_block, + } + if normalized_name not in block_adders: + supported_names = ", ".join(block_adders) + raise ValueError( + f"Unsupported block name '{block_name}'. " + f"Supported block names: {supported_names}" + ) + return block_adders[normalized_name] + def add_parameters_block(self, block=None): if block is not None: # TODO: Transition to BNGErrors and logging diff --git a/tests/test_block_dispatch_validation.py b/tests/test_block_dispatch_validation.py new file mode 100644 index 00000000..2e1984c0 --- /dev/null +++ b/tests/test_block_dispatch_validation.py @@ -0,0 +1,182 @@ +"""Focused tests for model and network block dispatch validation.""" + +import pytest + +from bionetgen.modelapi.blocks import ( + ActionBlock, + CompartmentBlock, + EnergyPatternBlock, + FunctionBlock, + MoleculeTypeBlock, + ObservableBlock, + ParameterBlock, + PopulationMapBlock, + ProtocolBlock, + RuleBlock, + SpeciesBlock, +) +from bionetgen.network.blocks import ( + NetworkGroupBlock, + NetworkParameterBlock, + NetworkReactionBlock, + NetworkSpeciesBlock, +) + + +def _make_model_bypass_init(): + from bionetgen.modelapi.model import bngmodel + + model = object.__new__(bngmodel) + model.active_blocks = [] + model._block_order = [ + "parameters", + "compartments", + "molecule_types", + "species", + "observables", + "functions", + "energy_patterns", + "population_maps", + "rules", + "protocol", + "actions", + ] + model.model_name = "test_model" + model.model_path = "/fake/test.bngl" + model.parameters = ParameterBlock() + model.compartments = CompartmentBlock() + model.molecule_types = MoleculeTypeBlock() + model.species = SpeciesBlock() + model.observables = ObservableBlock() + model.functions = FunctionBlock() + model.energy_patterns = EnergyPatternBlock() + model.population_maps = PopulationMapBlock() + model.rules = RuleBlock() + model.protocol = ProtocolBlock() + model.actions = ActionBlock() + return model + + +def _make_network_bypass_init(): + from bionetgen.network.network import Network + + net = object.__new__(Network) + net.active_blocks = [] + net.block_order = ["parameters", "species", "reactions", "groups"] + net.network_name = "test" + net.parameters = NetworkParameterBlock() + net.species = NetworkSpeciesBlock() + net.reactions = NetworkReactionBlock() + net.groups = NetworkGroupBlock() + return net + + +@pytest.mark.parametrize( + ("block_cls", "attr_name"), + [ + (ParameterBlock, "parameters"), + (RuleBlock, "rules"), + (ProtocolBlock, "protocol"), + ], +) +def test_model_add_block_dispatches_supported_block(block_cls, attr_name): + model = _make_model_bypass_init() + block = block_cls() + + model.add_block(block) + + assert getattr(model, attr_name) is block + assert attr_name in model.active_blocks + + +@pytest.mark.parametrize( + ("block_name", "attr_name", "block_cls"), + [ + ("observables", "observables", ObservableBlock), + ("reaction_rules", "rules", RuleBlock), + ("protocol", "protocol", ProtocolBlock), + ], +) +def test_model_add_empty_block_dispatches_supported_name( + block_name, attr_name, block_cls +): + model = _make_model_bypass_init() + delattr(model, attr_name) + + model.add_empty_block(block_name) + + assert isinstance(getattr(model, attr_name), block_cls) + + +def test_model_add_block_invalid_name_raises_value_error(): + model = _make_model_bypass_init() + + class FakeBlock: + name = "not a block" + + with pytest.raises(ValueError, match="Unsupported block name 'not a block'"): + model.add_block(FakeBlock()) + + assert "not_a_block" not in model.active_blocks + assert not hasattr(model, "not_a_block") + + +def test_model_add_empty_block_invalid_name_raises_value_error(): + model = _make_model_bypass_init() + + with pytest.raises(ValueError, match="Unsupported block name 'not a block'"): + model.add_empty_block("not a block") + + assert "not_a_block" not in model.active_blocks + assert not hasattr(model, "not_a_block") + + +@pytest.mark.parametrize( + ("block_cls", "attr_name"), + [ + (NetworkParameterBlock, "parameters"), + (NetworkSpeciesBlock, "species"), + (NetworkReactionBlock, "reactions"), + (NetworkGroupBlock, "groups"), + ], +) +def test_network_add_block_dispatches_supported_block(block_cls, attr_name): + net = _make_network_bypass_init() + block = block_cls() + + net.add_block(block) + + assert getattr(net, attr_name) is block + assert attr_name in net.active_blocks + + +def test_network_add_empty_block_dispatches_supported_name(): + net = _make_network_bypass_init() + delattr(net, "groups") + + net.add_empty_block("groups") + + assert isinstance(net.groups, NetworkGroupBlock) + + +def test_network_add_block_invalid_name_raises_value_error(): + net = _make_network_bypass_init() + + class FakeBlock: + name = "not a block" + + with pytest.raises(ValueError, match="Unsupported block name 'not a block'"): + net.add_block(FakeBlock()) + + assert "not_a_block" not in net.active_blocks + assert not hasattr(net, "not_a_block") + + +def test_network_add_empty_block_invalid_name_raises_value_error(): + net = _make_network_bypass_init() + + with pytest.raises(ValueError, match="Unsupported block name 'not a block'"): + net.add_empty_block("not a block") + + assert "not_a_block" not in net.active_blocks + assert not hasattr(net, "not_a_block")