From 0056a205b7d7d347b1a7458af01bb63999062e4b Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Mon, 26 Jan 2026 11:25:35 +0100 Subject: [PATCH 1/3] Conditions refactoring (#758) --- pina/_src/condition/batch_manager.py | 43 ++ pina/_src/condition/condition.py | 22 +- pina/_src/condition/condition_base.py | 127 ++++++ pina/_src/condition/condition_interface.py | 99 +---- pina/_src/condition/data_condition.py | 127 +++--- pina/_src/condition/data_manager.py | 348 +++++++++++++++ .../condition/domain_equation_condition.py | 81 ++-- .../condition/input_equation_condition.py | 146 +++---- pina/_src/condition/input_target_condition.py | 218 +++------- pina/_src/problem/abstract_problem.py | 9 +- pina/condition/__init__.py | 30 +- tests/test_condition.py | 154 ------- tests/test_condition/test_data_condition.py | 332 ++++++++++++++ .../test_domain_equation_condition.py | 29 ++ .../test_input_equation_condition.py | 79 ++++ .../test_input_target_condition.py | 409 ++++++++++++++++++ tests/test_data_manager.py | 137 ++++++ .../test_ensemble_supervised_solver.py | 3 +- tests/test_solver/test_supervised_solver.py | 3 +- 19 files changed, 1778 insertions(+), 618 deletions(-) create mode 100644 pina/_src/condition/batch_manager.py create mode 100644 pina/_src/condition/condition_base.py create mode 100644 pina/_src/condition/data_manager.py delete mode 100644 tests/test_condition.py create mode 100644 tests/test_condition/test_data_condition.py create mode 100644 tests/test_condition/test_domain_equation_condition.py create mode 100644 tests/test_condition/test_input_equation_condition.py create mode 100644 tests/test_condition/test_input_target_condition.py create mode 100644 tests/test_data_manager.py diff --git a/pina/_src/condition/batch_manager.py b/pina/_src/condition/batch_manager.py new file mode 100644 index 000000000..105eec6eb --- /dev/null +++ b/pina/_src/condition/batch_manager.py @@ -0,0 +1,43 @@ +""" +Module for managing batches of data with device transfer 
capabilities. +""" + + +class _BatchManager(dict): + """ + A dictionary-based batch manager that supports dot-notation + and moving tensors to devices. + """ + + def to(self, device): + """ + Move all tensors in the batch to the specified device. + + :param device: The target device. + :type device: torch.device | str + :return: The updated batch manager. + :rtype: _BatchManager + """ + for key, value in self.items(): + if hasattr(value, "to"): + moved_value = value.to(device) + self[key] = moved_value # Updates both dict and attribute + return self + + def __getattribute__(self, name): + """ + Alias attribute access to dictionary keys. + + :param str name: The name of the attribute to retrieve. + :return: The value associated with the attribute name. + :rtype: Any + """ + try: + return super().__getattribute__(name) + except AttributeError: + try: + return self[name] + except KeyError: + raise AttributeError( + f"'BatchManager' object has no attribute '{name}'" + ) diff --git a/pina/_src/condition/condition.py b/pina/_src/condition/condition.py index db2a666d8..8b2c814ba 100644 --- a/pina/_src/condition/condition.py +++ b/pina/_src/condition/condition.py @@ -88,12 +88,12 @@ class Condition: """ # Combine all possible keyword arguments from the different Condition types - __slots__ = list( + available_kwargs = list( set( - InputTargetCondition.__slots__ - + InputEquationCondition.__slots__ - + DomainEquationCondition.__slots__ - + DataCondition.__slots__ + InputTargetCondition.__fields__ + + InputEquationCondition.__fields__ + + DomainEquationCondition.__fields__ + + DataCondition.__fields__ ) ) @@ -114,28 +114,28 @@ def __new__(cls, *args, **kwargs): if len(args) != 0: raise ValueError( "Condition takes only the following keyword " - f"arguments: {Condition.__slots__}." + f"arguments: {Condition.available_kwargs}." 
) # Class specialization based on keyword arguments sorted_keys = sorted(kwargs.keys()) # Input - Target Condition - if sorted_keys == sorted(InputTargetCondition.__slots__): + if sorted_keys == sorted(InputTargetCondition.__fields__): return InputTargetCondition(**kwargs) # Input - Equation Condition - if sorted_keys == sorted(InputEquationCondition.__slots__): + if sorted_keys == sorted(InputEquationCondition.__fields__): return InputEquationCondition(**kwargs) # Domain - Equation Condition - if sorted_keys == sorted(DomainEquationCondition.__slots__): + if sorted_keys == sorted(DomainEquationCondition.__fields__): return DomainEquationCondition(**kwargs) # Data Condition if ( - sorted_keys == sorted(DataCondition.__slots__) - or sorted_keys[0] == DataCondition.__slots__[0] + sorted_keys == sorted(DataCondition.__fields__) + or sorted_keys[0] == DataCondition.__fields__[0] ): return DataCondition(**kwargs) diff --git a/pina/_src/condition/condition_base.py b/pina/_src/condition/condition_base.py new file mode 100644 index 000000000..b8290d717 --- /dev/null +++ b/pina/_src/condition/condition_base.py @@ -0,0 +1,127 @@ +""" +Base class for conditions. +""" + +from functools import partial +import torch +from torch_geometric.data import Batch +from torch.utils.data import DataLoader +from pina._src.condition.condition_interface import ConditionInterface +from pina._src.core.graph import LabelBatch +from pina._src.core.label_tensor import LabelTensor + + +class ConditionBase(ConditionInterface): + """ + Base abstract class for all conditions in PINA. + This class provides common functionality for handling data storage, + batching, and interaction with the associated problem. + """ + + collate_fn_dict = { + "tensor": torch.stack, + "label_tensor": LabelTensor.stack, + "graph": LabelBatch.from_data_list, + "data": Batch.from_data_list, + } + + def __init__(self, **kwargs): + """ + Initialization of the :class:`ConditionBase` class. 
+ + :param kwargs: Keyword arguments representing the data to be stored. + """ + super().__init__() + self.data = self.store_data(**kwargs) + + @property + def problem(self): + """ + Return the problem associated with this condition. + + :return: Problem associated with this condition. + :rtype: ~pina.problem.abstract_problem.AbstractProblem + """ + return self._problem + + @problem.setter + def problem(self, value): + """ + Set the problem associated with this condition. + + :param pina.problem.abstract_problem.AbstractProblem value: The problem + to associate with this condition + """ + self._problem = value + + def __len__(self): + """ + Return the number of data points in the condition. + + :return: Number of data points. + :rtype: int + """ + return len(self.data) + + def __getitem__(self, idx): + """ + Return the data point(s) at the specified index. + + :param idx: Index(es) of the data point(s) to retrieve. + :type idx: int | list[int] + :return: Data point(s) at the specified index. + """ + return self.data[idx] + + @classmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function for automatic batching to be used in DataLoader. + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: dict + """ + if not batch: + return {} + instance_class = batch[0].__class__ + return instance_class.create_batch(batch) + + @staticmethod + def collate_fn(batch, condition): + """ + Collate function for custom batching to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :param condition: The condition instance. + :type condition: ConditionBase + :return: A collated batch. + :rtype: dict + """ + data = condition.data[batch].to_batch() + return data + + def create_dataloader( + self, dataset, batch_size, shuffle, automatic_batching + ): + """ + Create a DataLoader for the condition. + + :param int batch_size: The batch size for the DataLoader. 
+ :param bool shuffle: Whether to shuffle the data. Default is ``False``. + :return: The DataLoader for the condition. + :rtype: torch.utils.data.DataLoader + """ + if batch_size == len(dataset): + pass # will be updated in the near future + return DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + collate_fn=( + partial(self.collate_fn, condition=self) + if not automatic_batching + else self.automatic_batching_collate_fn + ), + ) diff --git a/pina/_src/condition/condition_interface.py b/pina/_src/condition/condition_interface.py index 509ac2fc3..229b9a025 100644 --- a/pina/_src/condition/condition_interface.py +++ b/pina/_src/condition/condition_interface.py @@ -1,9 +1,6 @@ """Module for the Condition interface.""" -from abc import ABCMeta -from torch_geometric.data import Data -from pina._src.core.label_tensor import LabelTensor -from pina._src.core.graph import Graph +from abc import ABCMeta, abstractmethod class ConditionInterface(metaclass=ABCMeta): @@ -15,13 +12,14 @@ class ConditionInterface(metaclass=ABCMeta): description of all available conditions and how to instantiate them. """ - def __init__(self): + @abstractmethod + def __init__(self, **kwargs): """ Initialization of the :class:`ConditionInterface` class. """ - self._problem = None @property + @abstractmethod def problem(self): """ Return the problem associated with this condition. @@ -29,9 +27,9 @@ def problem(self): :return: Problem associated with this condition. :rtype: ~pina.problem.abstract_problem.AbstractProblem """ - return self._problem @problem.setter + @abstractmethod def problem(self, value): """ Set the problem associated with this condition. 
@@ -39,88 +37,21 @@ def problem(self, value): :param pina.problem.abstract_problem.AbstractProblem value: The problem to associate with this condition """ - self._problem = value - @staticmethod - def _check_graph_list_consistency(data_list): + @abstractmethod + def __len__(self): """ - Check the consistency of the list of Data | Graph objects. - The following checks are performed: + Return the number of data points in the condition. - - All elements in the list must be of the same type (either - :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`). - - - All elements in the list must have the same keys. - - - The data type of each tensor must be consistent across all elements. - - - If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels - must also be consistent across all elements. - - :param data_list: The list of Data | Graph objects to check. - :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph] - :raises ValueError: If the input types are invalid. - :raises ValueError: If all elements in the list do not have the same - keys. - :raises ValueError: If the type of each tensor is not consistent across - all elements in the list. - :raises ValueError: If the labels of the LabelTensors are not consistent - across all elements in the list. + :return: Number of data points. + :rtype: int """ - # If the data is a Graph or Data object, perform no checks - if isinstance(data_list, (Graph, Data)): - return - - # Check all elements in the list are of the same type - if not all(isinstance(i, (Graph, Data)) for i in data_list): - raise ValueError( - "Invalid input. Please, provide either Data or Graph objects." 
- ) - - # Store the keys, data types and labels of the first element - data = data_list[0] - keys = sorted(list(data.keys())) - data_types = {name: tensor.__class__ for name, tensor in data.items()} - labels = { - name: tensor.labels - for name, tensor in data.items() - if isinstance(tensor, LabelTensor) - } - - # Iterate over the list of Data | Graph objects - for data in data_list[1:]: - - # Check that all elements in the list have the same keys - if sorted(list(data.keys())) != keys: - raise ValueError( - "All elements in the list must have the same keys." - ) - - # Iterate over the tensors in the current element - for name, tensor in data.items(): - # Check that the type of each tensor is consistent - if tensor.__class__ is not data_types[name]: - raise ValueError( - f"Data {name} must be a {data_types[name]}, got " - f"{tensor.__class__}" - ) - - # Check that the labels of each LabelTensor are consistent - if isinstance(tensor, LabelTensor): - if tensor.labels != labels[name]: - raise ValueError( - "LabelTensor must have the same labels" - ) - def __getattribute__(self, name): + @abstractmethod + def __getitem__(self, idx): """ - Get an attribute from the object. + Return the data point(s) at the specified index. - :param str name: The name of the attribute to get. - :return: The requested attribute. - :rtype: Any + :param int idx: Index of the data point(s) to retrieve. + :return: Data point(s) at the specified index. 
""" - to_return = super().__getattribute__(name) - if isinstance(to_return, (Graph, Data)): - to_return = [to_return] - return to_return diff --git a/pina/_src/condition/data_condition.py b/pina/_src/condition/data_condition.py index ec6da762c..f37b3dc31 100644 --- a/pina/_src/condition/data_condition.py +++ b/pina/_src/condition/data_condition.py @@ -2,12 +2,13 @@ import torch from torch_geometric.data import Data -from pina._src.condition.condition_interface import ConditionInterface +from pina._src.condition.condition_base import ConditionBase from pina._src.core.label_tensor import LabelTensor from pina._src.core.graph import Graph +from pina._src.condition.data_manager import _DataManager -class DataCondition(ConditionInterface): +class DataCondition(ConditionBase): """ The class :class:`DataCondition` defines an unsupervised condition based on ``input`` data. This condition is typically used in data-driven problems, @@ -16,17 +17,6 @@ class DataCondition(ConditionInterface): the provided data during training. Optional ``conditional_variables`` can be specified when the model depends on additional parameters. - The class automatically selects the appropriate implementation based on the - type of the ``input`` data. Depending on whether the ``input`` is a tensor - or graph-based data, one of the following specialized subclasses is - instantiated: - - - :class:`TensorDataCondition`: For cases where the ``input`` is either a - :class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object. - - - :class:`GraphDataCondition`: For cases where the ``input`` is either a - :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` object. 
- :Example: >>> from pina import Condition, LabelTensor @@ -38,14 +28,14 @@ class DataCondition(ConditionInterface): """ # Available input data types - __slots__ = ["input", "conditional_variables"] + __fields__ = ["input", "conditional_variables"] _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) _avail_conditional_variables_cls = (torch.Tensor, LabelTensor) def __new__(cls, input, conditional_variables=None): """ - Instantiate the appropriate subclass of :class:`DataCondition` based on - the type of the ``input``. + Check the types of ``input`` and ``conditional_variables`` and + instantiate a class of :class:`DataCondition` accordingly. :param input: The input data for the condition. :type input: torch.Tensor | LabelTensor | Graph | @@ -63,58 +53,71 @@ def __new__(cls, input, conditional_variables=None): if cls != DataCondition: return super().__new__(cls) - # If the input is a tensor - if isinstance(input, (torch.Tensor, LabelTensor)): - subclass = TensorDataCondition - return subclass.__new__(subclass, input, conditional_variables) - - # If the input is a graph - if isinstance(input, (Graph, Data, list, tuple)): - cls._check_graph_list_consistency(input) - subclass = GraphDataCondition - return subclass.__new__(subclass, input, conditional_variables) - - # If the input is not of the correct type raise an error - raise ValueError( - "Invalid input type. Expected one of the following: " - "torch.Tensor, LabelTensor, Graph, Data or " - "an iterable of the previous types." - ) - - def __init__(self, input, conditional_variables=None): + # Check input type + if not isinstance(input, cls._avail_input_cls): + raise ValueError( + "Invalid input type. Expected one of the following: " + "torch.Tensor, LabelTensor, Graph, Data or " + "an iterable of the previous types." 
+ ) + if isinstance(input, (list, tuple)): + for item in input: + if not isinstance(item, (Data, Graph)): + raise ValueError( + "if input is a list or tuple, all its elements must" + " be of type Graph or Data." + ) + + # Check conditional_variables type + if conditional_variables is not None: + if not isinstance( + conditional_variables, cls._avail_conditional_variables_cls + ): + raise ValueError( + "Invalid conditional_variables type. Expected one of the " + "following: torch.Tensor, LabelTensor." + ) + + return super().__new__(cls) + + def store_data(self, **kwargs): """ - Initialization of the :class:`DataCondition` class. + Store the input data and conditional variables in a dictionary. :param input: The input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] + :type input: torch.Tensor | LabelTensor | Graph | + Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] :param conditional_variables: The conditional variables for the - condition. Default is ``None``. + condition. :type conditional_variables: torch.Tensor | LabelTensor - - .. note:: - - If ``input`` is a list of :class:`~pina.graph.Graph` or - :class:`~torch_geometric.data.Data`, all elements in - the list must share the same structure, with matching keys and - consistent data types. + :return: A dictionary containing the stored data. + :rtype: dict """ - super().__init__() - self.input = input - self.conditional_variables = conditional_variables - + data_dict = {"input": kwargs.get("input")} + cond_vars = kwargs.get("conditional_variables", None) + if cond_vars is not None: + data_dict["conditional_variables"] = cond_vars + return _DataManager(**data_dict) + + @property + def conditional_variables(self): + """ + Return the conditional variables for the condition. 
-class TensorDataCondition(DataCondition): - """ - Specialization of the :class:`DataCondition` class for the case where - ``input`` is either a :class:`~pina.label_tensor.LabelTensor` object or a - :class:`torch.Tensor` object. - """ + :return: The conditional variables. + :rtype: torch.Tensor | LabelTensor | None + """ + if hasattr(self.data, "conditional_variables"): + return self.data.conditional_variables + return None + @property + def input(self): + """ + Return the input data for the condition. -class GraphDataCondition(DataCondition): - """ - Specialization of the :class:`DataCondition` class for the case where - ``input`` is either a :class:`~pina.graph.Graph` object or a - :class:`~torch_geometric.data.Data` object. - """ + :return: The input data. + :rtype: torch.Tensor | LabelTensor | Graph | Data | + list[Graph] | list[Data] | tuple[Graph] | tuple[Data] + """ + return self.data.input diff --git a/pina/_src/condition/data_manager.py b/pina/_src/condition/data_manager.py new file mode 100644 index 000000000..b390cb580 --- /dev/null +++ b/pina/_src/condition/data_manager.py @@ -0,0 +1,348 @@ +""" +Module for managing data in conditions. +""" + +import torch +from torch_geometric.data import Data +from torch_geometric.data.batch import Batch +from pina import LabelTensor +from pina._src.core.graph import Graph, LabelBatch +from ..equation.equation_interface import EquationInterface +from .batch_manager import _BatchManager + + +class _DataManager: + """ + Abstract base class for data managers. + + This class dynamically selects between :class:`_TensorDataManager` and + :class:`_GraphDataManager` based on the types of the input data. + """ + + def __new__(cls, **kwargs): + """ + Dynamically instantiate the appropriate subclass based on the types + of the input data. + - If all values in ``kwargs`` are instances of + :class:`torch.Tensor`, :class:`LabelTensor` then + :class:`_TensorDataManager` is instantiated. 
+ - Otherwise, :class:`_GraphDataManager` is instantiated. + + :param dict kwargs: The keyword arguments containing the data. + :return: An instance of :class:`_TensorDataManager` or + :class:`_GraphDataManager`. + :rtype: _TensorDataManager | _GraphDataManager + """ + # If not called directly, proceed with normal instantiation + if cls is not _DataManager: + return super().__new__(cls) + + # Does the data contain only tensors/LabelTensors/Equations? + is_tensor_only = all( + isinstance(v, (torch.Tensor, LabelTensor, EquationInterface)) + for v in kwargs.values() + ) + # Choose the appropriate subclass, GraphDataManager or TensorDataManager + subclass = _TensorDataManager if is_tensor_only else _GraphDataManager + return super().__new__(subclass) + + def __init__(self, **kwargs): + """ + Initialize the data manager with the provided keyword arguments. + + :param dict kwargs: The keyword arguments containing the data. + """ + self.keys = list(kwargs.keys()) + + +class _TensorDataManager(_DataManager): + """ + Data manager for tensor data. Handles data stored as `torch.Tensor` or + `LabelTensor`. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.data = kwargs + + for k, v in kwargs.items(): + setattr(self, k, v) + + def __len__(self): + """ + Return the number of samples in the tensor data manager. + + :return: Number of samples. + :rtype: int + """ + return self.data[self.keys[0]].shape[0] + + def __getitem__(self, idx): + """ + Return a data item or a subset of data items by index. + + :param idx: Index or indices of the data items to retrieve. + :type idx: int | slice | list[int] | torch.Tensor + :return: A new :class:`_TensorDataManager` instance containing the + selected data items. 
+ :rtype: _TensorDataManager + """ + # Mapping efficiente degli elementi + new_data = { + k: (self.data[k][idx] if k in self.keys else self.data[k]) + for k in self.keys + } + return _TensorDataManager(**new_data) + + @staticmethod + def create_batch(items): + """ + Create a batch from a list of :class:`_TensorDataManager` items. + + :param list items: List of :class:`_TensorDataManager` items to batch. + :return: A new :class:`_BatchManager` instance containing the batched + data. + :rtype: _BatchManager + """ + if not items: + return None + first = items[0] + batch_data = _BatchManager() + + for k in first.keys: + vals = [it.data[k] for it in items] + sample = vals[0] + + if isinstance(sample, (torch.Tensor, LabelTensor)): + batch_fn = ( + LabelTensor.stack + if isinstance(sample, LabelTensor) + else torch.stack + ) + batch_data[k] = batch_fn(vals, dim=0) + else: + batch_data[k] = sample + return batch_data + + def to_batch(self): + """ + Create a batch from the current tensor data manager. + + :return: A new :class:`_BatchManager` instance containing the batched + data. + :rtype: _BatchManager + """ + batch_data = _BatchManager() + for k in self.keys: + batch_data[k] = self.data[k] + return batch_data + + +class _GraphDataManager(_DataManager): + """ + Data manager for graph data. Handles data stored as :class:`Graph`, + :class:`Data`, or lists/tuples of these types. Moreover , it can also manage + associated tensors stored as :class:`torch.Tensor` or :class:`LabelTensor`. + """ + + def __init__(self, **kwargs): + """ + Initialize the graph data manager with the provided keyword arguments. + + :param dict kwargs: The keyword arguments containing the data. 
+ """ + super().__init__(**kwargs) + self.graph_key = next( + k + for k, v in kwargs.items() + if isinstance(v, (Graph, Data, list, tuple)) + ) + + self.keys = [ + k + for k in self.keys + if k != self.graph_key + and isinstance(kwargs[k], (torch.Tensor, LabelTensor)) + ] + + # Prepare graphs and assign tensors + self.data = self._prepare_graphs(kwargs) + + def _prepare_graphs(self, kwargs): + """ + Store tensors in the corresponding graphs. + + :param dict kwargs: The keyword arguments containing the graphs and + associated tensors. + :return: A list of graphs with tensors assigned. + :rtype: list[Graph] | list[Data] + """ + graphs = kwargs.pop(self.graph_key) + if not isinstance(graphs, (list, tuple)): + graphs = [graphs] + + n_graphs = len(graphs) + for name, tensor in kwargs.items(): + # Verify consistency between number of graphs and tensor samples + if n_graphs != tensor.shape[0]: + raise ValueError( + f"Number of graphs ({n_graphs}) does not match " + f"number of samples for key '{name}' " + f"({kwargs[name].shape[0]})." + ) + # Assign tensors to graphs + for i, g in enumerate(graphs): + setattr(g, name, tensor[i]) + + return graphs + + def __len__(self): + """ + Return the number of graphs in the graph data manager. + + :return: Number of graphs. + :rtype: int + """ + return len(self.data) + + def __getattr__(self, name): + """ + Override attribute access to retrieve tensors or graphs. If the graph + key is requested, return the list of graphs. If a tensor key is + requested, stack the tensors from all graphs and return the result. + + :param str name: The name of the attribute to retrieve. + :return: The requested tensor or graph. 
+ :rtype: torch.Tensor | LabelTensor | Graph | list[Graph] | Data | + """ + # If the requested attribute is a tensor key, stack the tensors from + # all graphs + if name in self.keys: + tensors = [getattr(g, name) for g in self.data] + batch_fn = ( + LabelTensor.stack + if isinstance(tensors[0], LabelTensor) + else torch.stack + ) + return batch_fn(tensors) + + # If the requested attribute is the graph key, return the graphs + if name == self.graph_key: + return self.data if len(self.data) > 1 else self.data[0] + + return super().__getattribute__(name) + + @classmethod + def _init_from_graphs_list(cls, graphs, graph_key, keys): + """ + Initialize a :class:`_GraphDataManager` instance from a list of graphs. + This is used internally to create subsets of the data manager, without + going through the full initialization process. + + :param list graphs: List of graphs to initialize the data manager with. + :param str graph_key: Key under which the graphs are stored. + :param list keys: List of tensor keys associated with the graphs. + :return: A new :class:`_GraphDataManager` instance. + :rtype: _GraphDataManager + """ + # Create a new instance without calling __init__ + obj = _GraphDataManager.__new__(_GraphDataManager) + obj.graph_key = graph_key + obj.keys = keys + obj.data = graphs + return obj + + def __getitem__(self, idx): + """ + Retrieve a graph or a subset of graphs by index. + + :param idx: Index or indices of the graphs to retrieve. + :type idx: int | slice | list[int] | torch.Tensor + :return: A new :class:`_GraphDataManager` instance containing the + selected graphs. 
+ :rtype: _GraphDataManager + """ + # Manage int and slice directly + if isinstance(idx, (int, slice)): + selected = self.data[idx] + # Manage list or tensor of indices + elif isinstance(idx, (list, torch.Tensor)): + selected = [self.data[i] for i in idx] + else: + raise TypeError(f"Invalid index type: {type(idx)}") + + # Ensure selected is a list + if not isinstance(selected, list): + selected = [selected] + + # Return a new _GraphDataManager instance with the selected graphs + return _GraphDataManager._init_from_graphs_list( + selected, + # tensor_keys=self._tensor_keys, + graph_key=self.graph_key, + keys=self.keys, + ) + + def to_batch(self): + """ + Create a batch from the current graph data manager. + + :return: A new :class:`_BatchManager` instance containing the batched + data. + :rtype: _BatchManager + """ + batching_fn = ( + LabelBatch.from_data_list + if isinstance(self.data[0], Graph) + else Batch.from_data_list + ) + + batched_graph = batching_fn(self.data) + batch_data = _BatchManager() + for k in self.keys: + if k == self.graph_key: + continue + batch_data[k] = getattr(batched_graph, k) + delattr(batched_graph, k) + batch_data[self.graph_key] = batched_graph + return batch_data + + @staticmethod + def create_batch(items): + """ + Optimized batch creation. + """ + if not items: + return None + + first = items[0] + graph_key = first.graph_key + # Determine batching function once + is_labeled = isinstance(first.data[0], Graph) + batching_fn = ( + LabelBatch.from_data_list if is_labeled else Batch.from_data_list + ) + + # Efficient list comprehension for extraction + # If to_batch() is called on self, self.data might be a list already. + # If _create_batch is called on multiple managers, we grab the first + # graph from each. 
class DomainEquationCondition(ConditionBase):
    """
    The class :class:`DomainEquationCondition` defines a condition based on a
    ``domain`` and an ``equation``. This condition is typically used in
    physics-informed problems, where the ``equation`` residual is evaluated
    on points sampled from the ``domain``.

    :Example:

        >>> from pina import Condition
        >>> condition = Condition(domain=domain, equation=equation)
    """

    # Available fields
    __fields__ = ["domain", "equation"]

    # Admissible types for the condition arguments
    _avail_domain_cls = (DomainInterface, str)
    _avail_equation_cls = EquationInterface

    def __new__(cls, domain, equation):
        """
        Check the types of ``domain`` and ``equation`` and instantiate an
        instance of :class:`DomainEquationCondition`.

        :param domain: The domain over which the equation is defined.
        :type domain: DomainInterface | str
        :param EquationInterface equation: The equation to be satisfied over
            the specified domain.
        :return: An instance of :class:`DomainEquationCondition`.
        :rtype: pina.condition.domain_equation_condition.DomainEquationCondition
        :raises ValueError: If ``domain`` is not of type
            :class:`DomainInterface` or ``equation`` is not of type
            :class:`EquationInterface`.
        """
        if not isinstance(domain, cls._avail_domain_cls):
            raise ValueError(
                "The domain must be an instance of DomainInterface."
            )

        if not isinstance(equation, cls._avail_equation_cls):
            raise ValueError(
                "The equation must be an instance of EquationInterface."
            )

        return super().__new__(cls)

    def __len__(self):
        """
        Raise NotImplementedError since the number of points is determined by
        the domain sampling strategy.

        :raises NotImplementedError: Always raised since the number of points
            is determined by the domain sampling strategy.
        """
        raise NotImplementedError(
            "`__len__` method is not implemented for "
            "`DomainEquationCondition` since the number of points is "
            "determined by the domain sampling strategy."
        )

    def __getitem__(self, idx):
        """
        Raise NotImplementedError since data retrieval is not applicable.

        :param int idx: Index of the data point(s) to retrieve.
        :raises NotImplementedError: Always raised since data retrieval is not
            applicable for this condition.
        """
        raise NotImplementedError(
            "`__getitem__` method is not implemented for "
            "`DomainEquationCondition`"
        )

    def store_data(self, **kwargs):
        """
        Store the ``domain`` and ``equation`` as attributes of the condition.
        No data manager is created for this condition.

        :param dict kwargs: The keyword arguments containing the domain and
            the equation.
        :return: ``None``, since no data is stored for this condition.
        :rtype: None
        """
        # Plain attribute assignment is equivalent to the previous
        # setattr(...) calls with constant names, and clearer.
        self.domain = kwargs.get("domain")
        self.equation = kwargs.get("equation")
        return None
class InputEquationCondition(ConditionBase):
    """
    The class :class:`InputEquationCondition` defines a condition based on
    ``input`` data and an ``equation``. This condition is typically used in
    physics-informed problems, where the model is trained to satisfy a given
    ``equation`` through the evaluation of the residual performed at the
    provided ``input``.

    :Example:

        >>> from pina import Condition, LabelTensor
        >>> condition = Condition(input=input, equation=equation)
    """

    # Available fields
    __fields__ = ["input", "equation"]
    # Admissible types for the condition arguments
    _avail_input_cls = (LabelTensor, Graph)
    _avail_equation_cls = EquationInterface

    def __new__(cls, input, equation):
        """
        Check the types of ``input`` and ``equation`` and instantiate a class
        of :class:`InputEquationCondition` accordingly.

        :param input: The input data for the condition.
        :type input: LabelTensor | Graph
        :param EquationInterface equation: The equation to be satisfied over
            the specified input points.
        :return: An instance of :class:`InputEquationCondition`.
        :rtype: pina.condition.input_equation_condition.InputEquationCondition
        :raises ValueError: If ``input`` is not of type
            :class:`~pina.graph.Graph` or
            :class:`~pina.label_tensor.LabelTensor`, or if ``equation`` is not
            of type :class:`EquationInterface`.
        """
        # Check input type
        if not isinstance(input, cls._avail_input_cls):
            raise ValueError(
                "The input data object must be a LabelTensor or a Graph object."
            )

        # Check equation type
        if not isinstance(equation, cls._avail_equation_cls):
            raise ValueError(
                "The equation must be an instance of EquationInterface."
            )

        return super().__new__(cls)

    def store_data(self, **kwargs):
        """
        Store the equation on the condition and wrap the remaining keyword
        arguments (the input data) in a :class:`_DataManager` object.

        :param dict kwargs: The keyword arguments containing the input data
            and the equation.
        :return: The data manager holding the input data.
        :rtype: _DataManager
        """
        # Goes through the `equation` property setter, which type-checks.
        self.equation = kwargs.pop("equation")
        return _DataManager(**kwargs)

    @property
    def input(self):
        """
        Return the input data for the condition.

        :return: The input data.
        :rtype: LabelTensor | Graph
        """
        return self.data.input

    @property
    def equation(self):
        """
        Return the equation associated with this condition.

        :return: Equation associated with this condition.
        :rtype: EquationInterface
        """
        return self._equation

    @equation.setter
    def equation(self, value):
        """
        Set the equation associated with this condition.

        :param EquationInterface value: The equation to associate with this
            condition.
        :raises TypeError: If ``value`` is not an instance of
            :class:`EquationInterface`.
        """
        if not isinstance(value, EquationInterface):
            raise TypeError(
                "The equation must be an instance of EquationInterface."
            )
        self._equation = value
class InputTargetCondition(ConditionBase):
    """
    The :class:`InputTargetCondition` class represents a supervised condition
    defined by both ``input`` and ``target`` data. The model is trained to
    produce the ``target`` given the ``input``. Supported data types include
    :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
    :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.

    :Example:

        >>> from pina import Condition, LabelTensor
        >>> condition = Condition(input=input, target=target)
    """

    # Available input and target data types
    __fields__ = ["input", "target"]
    _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)
    _avail_output_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)

    def __new__(cls, input, target):
        """
        Check the types of ``input`` and ``target`` data and instantiate the
        :class:`InputTargetCondition`.

        :param input: The input data for the condition.
        :type input: torch.Tensor | LabelTensor | Graph | Data |
            list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
        :param target: The target data for the condition.
        :type target: torch.Tensor | LabelTensor | Graph | Data |
            list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
        :return: An instance of :class:`InputTargetCondition`.
        :rtype: pina.condition.input_target_condition.InputTargetCondition
        :raises ValueError: If ``input`` or ``target`` are not of supported
            types.
        """
        # The same validation applies symmetrically to input and target.
        cls._validate_field(input, cls._avail_input_cls, "input")
        cls._validate_field(target, cls._avail_output_cls, "target")
        return super().__new__(cls)

    @staticmethod
    def _validate_field(value, admissible_types, name):
        """
        Validate a single condition field (``input`` or ``target``).

        :param value: The data to validate.
        :param tuple admissible_types: Admissible types for ``value``.
        :param str name: Field name, used in error messages.
        :raises ValueError: If ``value`` is not of an admissible type, or if
            it is a list/tuple containing elements that are not
            :class:`Graph`/:class:`Data`.
        """
        if not isinstance(value, admissible_types):
            raise ValueError(
                f"Invalid {name} type. Expected one of the following: "
                "torch.Tensor, LabelTensor, Graph, Data or "
                "list/tuple of Graph/Data objects."
            )
        if isinstance(value, (list, tuple)):
            for item in value:
                if not isinstance(item, (Graph, Data)):
                    # Bug fix: the original message always said "target",
                    # even when validating the input argument.
                    raise ValueError(
                        f"If {name} is a list or tuple, all its elements "
                        "must be of type Graph or Data."
                    )

    def store_data(self, **kwargs):
        """
        Store the input and target data in a :class:`_DataManager` object.

        :param dict kwargs: The keyword arguments containing the input and
            target data.
        :return: The data manager holding the input and target data.
        :rtype: _DataManager
        """
        return _DataManager(**kwargs)

    @property
    def input(self):
        """
        Return the input data for the condition.

        :return: The input data.
        :rtype: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
            list[Data] | tuple[Graph] | tuple[Data]
        """
        return self.data.input

    @property
    def target(self):
        """
        Return the target data for the condition.

        :return: The target data.
        :rtype: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
            list[Data] | tuple[Graph] | tuple[Data]
        """
        return self.data.target
Consider calling " "problem.discretise_domain function for all domains before " diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 696567fa8..0cdf7a977 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -9,39 +9,19 @@ __all__ = [ "Condition", "ConditionInterface", + "ConditionBase", "DomainEquationCondition", "InputTargetCondition", - "TensorInputTensorTargetCondition", - "TensorInputGraphTargetCondition", - "GraphInputTensorTargetCondition", - "GraphInputGraphTargetCondition", "InputEquationCondition", - "InputTensorEquationCondition", - "InputGraphEquationCondition", "DataCondition", - "GraphDataCondition", - "TensorDataCondition", ] from pina._src.condition.condition_interface import ConditionInterface +from pina._src.condition.condition_base import ConditionBase from pina._src.condition.condition import Condition from pina._src.condition.domain_equation_condition import ( DomainEquationCondition, ) -from pina._src.condition.input_target_condition import ( - InputTargetCondition, - TensorInputTensorTargetCondition, - TensorInputGraphTargetCondition, - GraphInputTensorTargetCondition, - GraphInputGraphTargetCondition, -) -from pina._src.condition.input_equation_condition import ( - InputEquationCondition, - InputTensorEquationCondition, - InputGraphEquationCondition, -) -from pina._src.condition.data_condition import ( - DataCondition, - GraphDataCondition, - TensorDataCondition, -) +from pina._src.condition.input_target_condition import InputTargetCondition +from pina._src.condition.input_equation_condition import InputEquationCondition +from pina._src.condition.data_condition import DataCondition diff --git a/tests/test_condition.py b/tests/test_condition.py deleted file mode 100644 index 266233179..000000000 --- a/tests/test_condition.py +++ /dev/null @@ -1,154 +0,0 @@ -import torch -import pytest - -from pina import LabelTensor, Condition -from pina.condition import ( - TensorInputGraphTargetCondition, - 
TensorInputTensorTargetCondition, - GraphInputGraphTargetCondition, - GraphInputTensorTargetCondition, -) -from pina.condition import ( - InputTensorEquationCondition, - InputGraphEquationCondition, - DomainEquationCondition, -) -from pina.condition import ( - TensorDataCondition, - GraphDataCondition, -) -from pina.domain import CartesianDomain -from pina.equation import FixedValue -from pina.graph import RadiusGraph - -example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) - -input_tensor = torch.rand((10, 3)) -target_tensor = torch.rand((10, 2)) -input_lt = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) -target_lt = LabelTensor(torch.rand((10, 2)), ["a", "b"]) - -x = torch.rand(10, 20, 2) -pos = torch.rand(10, 20, 2) -radius = 0.1 -input_graph = [ - RadiusGraph( - x=x_, - pos=pos_, - radius=radius, - ) - for x_, pos_ in zip(x, pos) -] -target_graph = [ - RadiusGraph( - x=x_, - pos=pos_, - radius=radius, - ) - for x_, pos_ in zip(x, pos) -] - -x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) -pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) -radius = 0.1 -input_graph_lt = [ - RadiusGraph( - x=x[i], - pos=pos[i], - radius=radius, - ) - for i in range(len(x)) -] -target_graph_lt = [ - RadiusGraph( - x=x[i], - pos=pos[i], - radius=radius, - ) - for i in range(len(x)) -] - -input_single_graph = input_graph[0] -target_single_graph = target_graph[0] - - -def test_init_input_target(): - cond = Condition(input=input_tensor, target=target_tensor) - assert isinstance(cond, TensorInputTensorTargetCondition) - cond = Condition(input=input_tensor, target=target_tensor) - assert isinstance(cond, TensorInputTensorTargetCondition) - cond = Condition(input=input_tensor, target=target_graph) - assert isinstance(cond, TensorInputGraphTargetCondition) - cond = Condition(input=input_graph, target=target_tensor) - assert isinstance(cond, GraphInputTensorTargetCondition) - cond = Condition(input=input_graph, target=target_graph) - assert isinstance(cond, 
GraphInputGraphTargetCondition) - - cond = Condition(input=input_lt, target=input_single_graph) - assert isinstance(cond, TensorInputGraphTargetCondition) - cond = Condition(input=input_single_graph, target=target_lt) - assert isinstance(cond, GraphInputTensorTargetCondition) - cond = Condition(input=input_graph, target=target_graph) - assert isinstance(cond, GraphInputGraphTargetCondition) - cond = Condition(input=input_single_graph, target=target_single_graph) - assert isinstance(cond, GraphInputGraphTargetCondition) - - with pytest.raises(ValueError): - Condition(input_tensor, input_tensor) - with pytest.raises(ValueError): - Condition(input=3.0, target="example") - with pytest.raises(ValueError): - Condition(input=example_domain, target=example_domain) - - # Test wrong graph condition initialisation - input = [input_graph[0], input_graph_lt[0]] - target = [target_graph[0], target_graph_lt[0]] - with pytest.raises(ValueError): - Condition(input=input, target=target) - - input_graph_lt[0].x.labels = ["a", "b"] - with pytest.raises(ValueError): - Condition(input=input_graph_lt, target=target_graph_lt) - input_graph_lt[0].x.labels = ["u", "v"] - - -def test_init_domain_equation(): - cond = Condition(domain=example_domain, equation=FixedValue(0.0)) - assert isinstance(cond, DomainEquationCondition) - with pytest.raises(ValueError): - Condition(example_domain, FixedValue(0.0)) - with pytest.raises(ValueError): - Condition(domain=3.0, equation="example") - with pytest.raises(ValueError): - Condition(domain=input_tensor, equation=input_graph) - - -def test_init_input_equation(): - cond = Condition(input=input_lt, equation=FixedValue(0.0)) - assert isinstance(cond, InputTensorEquationCondition) - cond = Condition(input=input_graph_lt, equation=FixedValue(0.0)) - assert isinstance(cond, InputGraphEquationCondition) - with pytest.raises(ValueError): - cond = Condition(input=input_tensor, equation=FixedValue(0.0)) - with pytest.raises(ValueError): - 
Condition(example_domain, FixedValue(0.0)) - with pytest.raises(ValueError): - Condition(input=3.0, equation="example") - with pytest.raises(ValueError): - Condition(input=example_domain, equation=input_graph) - - -test_init_input_equation() - - -def test_init_data_condition(): - cond = Condition(input=input_lt) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_tensor) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_tensor, conditional_variables=torch.tensor(1)) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_graph) - assert isinstance(cond, GraphDataCondition) - cond = Condition(input=input_graph, conditional_variables=torch.tensor(1)) - assert isinstance(cond, GraphDataCondition) diff --git a/tests/test_condition/test_data_condition.py b/tests/test_condition/test_data_condition.py new file mode 100644 index 000000000..4a88f963c --- /dev/null +++ b/tests/test_condition/test_data_condition.py @@ -0,0 +1,332 @@ +import pytest +import torch +from pina import Condition, LabelTensor +from pina.condition import DataCondition +from pina.graph import RadiusGraph +from torch_geometric.data import Data +from pina._src.condition.data_manager import _DataManager + + +def _create_tensor_data(use_lt=False, conditional_variables=False): + input_tensor = torch.rand((10, 3)) + if use_lt: + input_tensor = LabelTensor(input_tensor, ["x", "y", "z"]) + if conditional_variables: + cond_vars = torch.rand((10, 2)) + if use_lt: + cond_vars = LabelTensor(cond_vars, ["a", "b"]) + else: + cond_vars = None + return input_tensor, cond_vars + + +def _create_graph_data(use_lt=False, conditional_variables=False): + if use_lt: + x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) + pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) + else: + x = torch.rand(10, 20, 2) + pos = torch.rand(10, 20, 2) + radius = 0.1 + input_graph = [ + RadiusGraph(pos=pos[i], radius=radius, x=x[i]) for i in range(len(x)) + ] + if 
def _create_graph_data(use_lt=False, conditional_variables=False):
    """
    Build a list of 10 random radius graphs (20 nodes each) and, optionally,
    conditional variables; tensors carry labels when ``use_lt`` is True.
    """
    x = torch.rand(10, 20, 2)
    pos = torch.rand(10, 20, 2)
    if use_lt:
        x = LabelTensor(x, ["u", "v"])
        pos = LabelTensor(pos, ["x", "y"])
    graphs = [
        RadiusGraph(pos=pos[i], radius=0.1, x=x[i]) for i in range(len(x))
    ]
    cond = None
    if conditional_variables:
        cond = torch.rand(10, 20, 1)
        if use_lt:
            cond = LabelTensor(cond, ["f"])
    return graphs, cond
@pytest.mark.parametrize("conditional_variables", [False, True])
def test_init_graph_data_condition_label_tensor(conditional_variables):
    """Graph input built from LabelTensors must yield a DataCondition whose
    graph attributes and conditional variables keep their labels."""
    # Setup for LabelTensor
    input_graph, cond_vars = _create_graph_data(
        use_lt=True, conditional_variables=conditional_variables
    )
    condition = Condition(input=input_graph, conditional_variables=cond_vars)

    assert isinstance(condition, DataCondition)

    # Validate Input list and Labels
    for graph in condition.input:
        assert isinstance(graph.x, LabelTensor)
        assert graph.x.labels == ["u", "v"]

        assert isinstance(graph.pos, LabelTensor)
        assert graph.pos.labels == ["x", "y"]

    # Validate Conditional Variables and Labels
    if conditional_variables:
        assert isinstance(condition.conditional_variables, LabelTensor)
        assert condition.conditional_variables.labels == ["f"]
    else:
        assert condition.conditional_variables is None
@pytest.mark.parametrize("conditional_variables", [False, True])
def test_getitem_tensor_data_condition_tensor(conditional_variables):
    """Single-index access on a plain-tensor DataCondition must return one
    unlabelled sample, with conditional variables only when provided."""
    # Setup for standard torch.Tensor
    input_tensor, cond_vars = _create_tensor_data(
        use_lt=False, conditional_variables=conditional_variables
    )
    condition = Condition(input=input_tensor, conditional_variables=cond_vars)

    item = condition[0]

    # Input assertions
    assert isinstance(item.input, torch.Tensor)
    assert not isinstance(item.input, LabelTensor)
    assert item.input.shape == (3,)

    # Conditional variables assertions
    if conditional_variables:
        assert isinstance(item.conditional_variables, torch.Tensor)
        assert item.conditional_variables.shape == (2,)
    else:
        assert not hasattr(item, "conditional_variables")
@pytest.mark.parametrize("conditional_variables", [False, True])
def test_getitem_graph_data_condition_tensor(conditional_variables):
    """Single-index access on a graph DataCondition must return one Data
    object with unlabelled tensors of the expected shapes."""
    # Setup specifically for standard torch.Tensor
    input_graph, cond_vars = _create_graph_data(
        use_lt=False, conditional_variables=conditional_variables
    )
    condition = Condition(input=input_graph, conditional_variables=cond_vars)

    item = condition[0]

    # Assertions for the graph data
    assert isinstance(item.input, Data)
    assert isinstance(item.input.x, torch.Tensor)
    assert not isinstance(item.input.x, LabelTensor)
    assert item.input.x.shape == (20, 2)

    # Assertions for conditional variables
    # NOTE(review): unlike the tensor variant, there is no `else` branch
    # asserting the absence of `conditional_variables` — confirm whether
    # graph items should also lack the attribute in that case.
    if conditional_variables:
        assert isinstance(item.conditional_variables, torch.Tensor)
        assert item.conditional_variables.shape == (1, 20, 1)
input_tensor, cond_vars = _create_tensor_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_tensor, conditional_variables=cond_vars) + idxs = [0, 1, 3] + items = condition[idxs] + assert isinstance(items, _DataManager) + assert hasattr(items, "input") + type_ = LabelTensor if use_lt else torch.Tensor + inputs = items.input + assert isinstance(inputs, type_) + assert inputs.shape == (3, 3) + if use_lt: + assert inputs.labels == ["x", "y", "z"] + if conditional_variables: + assert hasattr(items, "conditional_variables") + cond_vars_items = items.conditional_variables + assert isinstance(cond_vars_items, type_) + assert cond_vars_items.shape == (3, 2) + if use_lt: + assert cond_vars_items.labels == ["a", "b"] + else: + assert not hasattr(items, "conditional_variables") + + +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitems_graph_data_condition_tensor(conditional_variables): + # Setup with use_lt=False + input_graph, cond_vars = _create_graph_data( + use_lt=False, conditional_variables=conditional_variables + ) + condition = Condition(input=input_graph, conditional_variables=cond_vars) + + idxs = [0, 1, 3] + items = condition[idxs] + + # Assertions for DataManager and Graphs + assert isinstance(items, _DataManager) + graphs = items.input + assert len(graphs) == 3 + + for graph in graphs: + assert isinstance(graph.x, torch.Tensor) + assert not isinstance(graph.x, LabelTensor) + assert graph.x.shape == (20, 2) + + # Assertions for Conditional Variables + if conditional_variables: + assert isinstance(items.conditional_variables, torch.Tensor) + assert items.conditional_variables.shape == (3, 20, 1) + + +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitems_graph_data_condition_label_tensor(conditional_variables): + # Setup with use_lt=True + input_graph, cond_vars = _create_graph_data( + use_lt=True, conditional_variables=conditional_variables + ) + 
condition = Condition(input=input_graph, conditional_variables=cond_vars) + + idxs = [0, 1, 3] + items = condition[idxs] + + # Assertions for LabelTensor specific attributes in Graphs + for graph in items.input: + assert isinstance(graph.x, LabelTensor) + assert graph.x.labels == ["u", "v"] + + assert isinstance(graph.pos, LabelTensor) + assert graph.pos.labels == ["x", "y"] + + # Assertions for LabelTensor in Conditional Variables + if conditional_variables: + cv = items.conditional_variables + assert isinstance(cv, LabelTensor) + assert cv.labels == ["f"] + assert cv.shape == (3, 20, 1) diff --git a/tests/test_condition/test_domain_equation_condition.py b/tests/test_condition/test_domain_equation_condition.py new file mode 100644 index 000000000..46bc89bc3 --- /dev/null +++ b/tests/test_condition/test_domain_equation_condition.py @@ -0,0 +1,29 @@ +import pytest +from pina import Condition +from pina.domain import CartesianDomain +from pina._src.equation.equation_factory import FixedValue +from pina.condition import DomainEquationCondition + +example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) +example_equation = FixedValue(0.0) + + +def test_init_domain_equation(): + cond = Condition(domain=example_domain, equation=example_equation) + assert isinstance(cond, DomainEquationCondition) + assert cond.domain is example_domain + assert cond.equation is example_equation + assert hasattr(cond, "data") + assert cond.data is None + + +def test_len_not_implemented(): + cond = Condition(domain=example_domain, equation=FixedValue(0.0)) + with pytest.raises(NotImplementedError): + len(cond) + + +def test_getitem_not_implemented(): + cond = Condition(domain=example_domain, equation=FixedValue(0.0)) + with pytest.raises(NotImplementedError): + cond[0] diff --git a/tests/test_condition/test_input_equation_condition.py b/tests/test_condition/test_input_equation_condition.py new file mode 100644 index 000000000..4bed448b5 --- /dev/null +++ 
b/tests/test_condition/test_input_equation_condition.py @@ -0,0 +1,79 @@ +import torch +import pytest +from pina import Condition +from pina._src.condition.input_equation_condition import InputEquationCondition +from pina.equation import Equation +from pina import LabelTensor +from pina.graph import Graph +from pina._src.condition.data_manager import _DataManager + + +def _create_pts_and_equation(): + def dummy_equation(pts): + return pts["x"] ** 2 + pts["y"] ** 2 - 1 + + pts = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) + equation = Equation(dummy_equation) + return pts, equation + + +def _create_graph_and_equation(): + from pina.graph import KNNGraph + + def dummy_equation(pts): + return pts.x[:, 0] ** 2 + pts.x[:, 1] ** 2 - 1 + + x = LabelTensor(torch.randn(100, 2), labels=["u", "v"]) + pos = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) + graph = KNNGraph(x=x, pos=pos, neighbours=5, edge_attr=True) + equation = Equation(dummy_equation) + return graph, equation + + +def test_init_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + assert isinstance(condition, InputEquationCondition) + assert condition.input.shape == (100, 2) + assert condition.equation is equation + + +def test_init_graph_equation_condition(): + graph, equation = _create_graph_and_equation() + condition = Condition(input=graph, equation=equation) + assert isinstance(condition, InputEquationCondition) + assert isinstance(condition.input, Graph) + assert condition.input.x.shape == (100, 2) + assert condition.equation is equation + + +def test_wrong_init_equation_condition(): + pts, equation = _create_pts_and_equation() + # Wrong input type + with pytest.raises(ValueError): + Condition(input=torch.randn(10, 2), equation=equation) + # Wrong equation type + with pytest.raises(ValueError): + Condition(input=pts, equation="not_an_equation") + # Wrong input type (list with wrong elements) + with 
pytest.raises(ValueError): + Condition(input=[torch.randn(10, 2)], equation=equation) + + +def test_getitem_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + item = condition[0] + assert isinstance(item, _DataManager) + assert hasattr(item, "input") + assert item.input.shape == (2,) + + +def test_getitems_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + idxs = [0, 1, 3] + item = condition[idxs] + assert isinstance(item, _DataManager) + assert hasattr(item, "input") + assert item.input.shape == (3, 2) diff --git a/tests/test_condition/test_input_target_condition.py b/tests/test_condition/test_input_target_condition.py new file mode 100644 index 000000000..1f469f0cd --- /dev/null +++ b/tests/test_condition/test_input_target_condition.py @@ -0,0 +1,409 @@ +import torch +import pytest +from pina import LabelTensor, Condition +from pina.graph import RadiusGraph +from pina._src.condition.batch_manager import _BatchManager + + +def _create_tensor_data(use_lt=False): + if use_lt: + input_tensor = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) + target_tensor = LabelTensor(torch.rand((10, 2)), ["a", "b"]) + return input_tensor, target_tensor + input_tensor = torch.rand((10, 3)) + target_tensor = torch.rand((10, 2)) + return input_tensor, target_tensor + + +def _create_graph_data(tensor_input=True, use_lt=False): + if use_lt: + x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) + pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) + else: + x = torch.rand(10, 20, 2) + pos = torch.rand(10, 20, 2) + radius = 0.1 + graph = [ + RadiusGraph( + pos=pos[i], + radius=radius, + x=x[i] if not tensor_input else None, + y=x[i] if tensor_input else None, + ) + for i in range(len(x)) + ] + if use_lt: + tensor = LabelTensor(torch.rand(10, 20, 1), ["f"]) + else: + tensor = torch.rand(10, 20, 1) + return graph, tensor + + +def 
test_init_tensor_input_tensor_target_condition_tensor(): + # Setup for standard torch.Tensor + input_tensor, target_tensor = _create_tensor_data(use_lt=False) + condition = Condition(input=input_tensor, target=target_tensor) + + # Numerical assertions + assert torch.allclose( + condition.input, input_tensor + ), "Standard input tensor equality failed" + assert torch.allclose( + condition.target, target_tensor + ), "Standard target tensor equality failed" + + # Type assertions + assert isinstance(condition.input, torch.Tensor) + assert not isinstance(condition.input, LabelTensor) + assert isinstance(condition.target, torch.Tensor) + assert not isinstance(condition.target, LabelTensor) + + +def test_init_tensor_input_tensor_target_condition_label_tensor(): + # Setup for LabelTensor + input_tensor, target_tensor = _create_tensor_data(use_lt=True) + condition = Condition(input=input_tensor, target=target_tensor) + + # Type and Label assertions for Input + assert isinstance( + condition.input, LabelTensor + ), "Input did not preserve LabelTensor type" + assert condition.input.labels == [ + "x", + "y", + "z", + ], "Input labels were lost or corrupted" + + # Type and Label assertions for Target + assert isinstance( + condition.target, LabelTensor + ), "Target did not preserve LabelTensor type" + assert condition.target.labels == [ + "a", + "b", + ], "Target labels were lost or corrupted" + + # Numerical parity check still applies + assert torch.allclose(condition.input, input_tensor) + assert torch.allclose(condition.target, target_tensor) + + +def test_init_tensor_input_graph_target_condition_tensor(): + # Setup for standard torch.Tensor + target_graph, input_tensor = _create_graph_data(use_lt=False) + condition = Condition(input=input_tensor, target=target_graph) + + # Input assertions (Tensor) + assert isinstance(condition.input, torch.Tensor) + assert not isinstance(condition.input, LabelTensor) + assert torch.allclose(condition.input, input_tensor) + + # Target 
assertions (Graph List) + assert isinstance(condition.target, list) + for i, graph in enumerate(target_graph): + assert isinstance(condition.target[i].y, torch.Tensor) + assert not isinstance(condition.target[i].y, LabelTensor) + assert torch.allclose(condition.target[i].y, graph.y) + + +def test_init_tensor_input_graph_target_condition_label_tensor(): + # Setup for LabelTensor + target_graph, input_tensor = _create_graph_data(use_lt=True) + condition = Condition(input=input_tensor, target=target_graph) + + # Input assertions with label validation + assert isinstance(condition.input, LabelTensor) + assert condition.input.labels == ["f"] + assert torch.allclose(condition.input, input_tensor) + + # Target assertions with nested label validation + for i, graph in enumerate(target_graph): + target_y = condition.target[i].y + assert isinstance(target_y, LabelTensor) + assert target_y.labels == ["u", "v"] + assert torch.allclose(target_y, graph.y) + + +def test_init_graph_input_tensor_target_condition_tensor(): + # Setup for standard torch.Tensor (use_lt=False) + input_graph, target_tensor = _create_graph_data(False, use_lt=False) + condition = Condition(input=input_graph, target=target_tensor) + + # Input assertions: Check graph list integrity + assert isinstance(condition.input, list) + for i, original_graph in enumerate(input_graph): + assert torch.allclose(condition.input[i].x, original_graph.x) + assert isinstance(condition.input[i].x, torch.Tensor) + assert not isinstance(condition.input[i].x, LabelTensor) + + # Target assertions: Check raw tensor integrity + assert torch.allclose(condition.target, target_tensor) + assert isinstance(condition.target, torch.Tensor) + assert not isinstance(condition.target, LabelTensor) + + +def test_init_graph_input_tensor_target_condition_label_tensor(): + # Setup for LabelTensor (use_lt=True) + input_graph, target_tensor = _create_graph_data(False, use_lt=True) + condition = Condition(input=input_graph, target=target_tensor) + + # 
Input assertions: Check LabelTensor preservation in Graphs + for i, original_graph in enumerate(input_graph): + input_x = condition.input[i].x + assert isinstance(input_x, LabelTensor) + assert input_x.labels == original_graph.x.labels + assert torch.allclose(input_x, original_graph.x) + + # Target assertions: Check LabelTensor preservation in Target + assert isinstance(condition.target, LabelTensor) + assert condition.target.labels == ["f"] + assert torch.allclose(condition.target, target_tensor) + + +def test_wrong_init(): + input_tensor, target_tensor = _create_tensor_data() + with pytest.raises(ValueError): + Condition(input="invalid_input", target=target_tensor) + with pytest.raises(ValueError): + Condition(input=input_tensor, target="invalid_target") + with pytest.raises(ValueError): + Condition(input=[input_tensor], target=target_tensor) + with pytest.raises(ValueError): + Condition(input=input_tensor, target=[target_tensor]) + + +def test_getitem_tensor_input_tensor_target_condition_tensor(): + # Setup for standard torch.Tensor + input_tensor, target_tensor = _create_tensor_data(use_lt=False) + condition = Condition(input=input_tensor, target=target_tensor) + + # We test a single index to verify __getitem__ logic + index = 0 + item = condition[index] + + # Numerical and Type Assertions + assert torch.allclose(item.input, input_tensor[index]) + assert isinstance(item.input, torch.Tensor) + assert not isinstance(item.input, LabelTensor) + + assert torch.allclose(item.target, target_tensor[index]) + assert isinstance(item.target, torch.Tensor) + assert not isinstance(item.target, LabelTensor) + + +def test_getitem_tensor_input_tensor_target_condition_label_tensor(): + # Setup for LabelTensor + input_tensor, target_tensor = _create_tensor_data(use_lt=True) + condition = Condition(input=input_tensor, target=target_tensor) + + index = 0 + item = condition[index] + + # Verify Input LabelTensor preservation + assert isinstance(item.input, LabelTensor) + assert 
item.input.labels == input_tensor.labels + assert torch.allclose(item.input, input_tensor[index]) + + # Verify Target LabelTensor preservation + assert isinstance(item.target, LabelTensor) + assert item.target.labels == target_tensor.labels + assert torch.allclose(item.target, target_tensor[index]) + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitem_graph_input_tensor_target_condition(use_lt): + input_graph, target_tensor = _create_graph_data(False, use_lt=use_lt) + condition = Condition(input=input_graph, target=target_tensor) + assert len(condition) == len(input_graph) + for i in range(len(input_graph)): + item = condition[i] + assert torch.allclose( + item.input.x, input_graph[i].x + ), "GraphInputTensorTargetCondition __getitem__ input failed" + assert torch.allclose( + item.target, target_tensor[i] + ), "GraphInputTensorTargetCondition __getitem__ target failed" + if use_lt: + assert isinstance( + item.input.x, LabelTensor + ), "GraphInputTensorTargetCondition __getitem__ input type failed" + assert ( + item.input.x.labels == input_graph[i].x.labels + ), "GraphInputTensorTargetCondition __getitem__ input labels failed" + assert isinstance( + item.target, LabelTensor + ), "GraphInputTensorTargetCondition __getitem__ target type failed" + assert item.target.labels == [ + "f" + ], "GraphInputTensorTargetCondition __getitem__ target labels failed" + + +def test_getitem_tensor_input_graph_target_condition_tensor(): + # Setup for standard torch.Tensor + target_graph, input_tensor = _create_graph_data(use_lt=False) + condition = Condition(input=input_tensor, target=target_graph) + + # Check first item indexing + idx = 0 + item = condition[idx] + + # Input assertions (Tensor) + assert torch.allclose(item.input, input_tensor[idx]) + assert isinstance(item.input, torch.Tensor) + assert not isinstance(item.input, LabelTensor) + + # Target assertions (Graph Data) + assert torch.allclose(item.target.y, target_graph[idx].y) + assert 
isinstance(item.target.y, torch.Tensor) + assert not isinstance(item.target.y, LabelTensor) + + +def test_getitem_tensor_input_graph_target_condition_label_tensor(): + # Setup for LabelTensor + target_graph, input_tensor = _create_graph_data(use_lt=True) + condition = Condition(input=input_tensor, target=target_graph) + + idx = 0 + item = condition[idx] + + # Input LabelTensor validation + assert isinstance(item.input, LabelTensor) + assert item.input.labels == input_tensor.labels + assert torch.allclose(item.input, input_tensor[idx]) + + # Target Graph LabelTensor validation + target_y = item.target.y + assert isinstance(target_y, LabelTensor) + assert target_y.labels == ["u", "v"] + assert torch.allclose(target_y, target_graph[idx].y) + + +def test_getitems_tensor_input_tensor_target_condition_tensor(): + # Setup for standard torch.Tensor + input_tensor, target_tensor = _create_tensor_data(use_lt=False) + condition = Condition(input=input_tensor, target=target_tensor) + + indices = [1, 3, 5, 7] + items = condition[indices] + + # Verify values by comparing against manually stacked slices + expected_input = torch.stack([input_tensor[i] for i in indices]) + expected_target = torch.stack([target_tensor[i] for i in indices]) + + assert torch.allclose(items.input, expected_input) + assert torch.allclose(items.target, expected_target) + + # Ensure types remain standard torch.Tensor + assert isinstance(items.input, torch.Tensor) + assert not isinstance(items.input, LabelTensor) + assert isinstance(items.target, torch.Tensor) + + +def test_getitems_tensor_input_tensor_target_condition_label_tensor(): + # Setup for LabelTensor + input_tensor, target_tensor = _create_tensor_data(use_lt=True) + condition = Condition(input=input_tensor, target=target_tensor) + + indices = [1, 3, 5, 7] + items = condition[indices] + + # Assertions for Input LabelTensor + assert isinstance(items.input, LabelTensor) + assert items.input.labels == ["x", "y", "z"] + assert 
torch.allclose(items.input, input_tensor[indices]) + + # Assertions for Target LabelTensor + assert isinstance(items.target, LabelTensor) + assert items.target.labels == ["a", "b"] + assert torch.allclose(items.target, target_tensor[indices]) + + +def test_getitems_tensor_input_graph_target_condition_tensor(): + # Setup for standard torch.Tensor + target_graph, input_tensor = _create_graph_data(True, use_lt=False) + condition = Condition(input=input_tensor, target=target_graph) + + indices = [0, 2, 4] + items = condition[indices] + + # 1. Verify Input Batch (Tensor) + expected_input = torch.stack([input_tensor[i] for i in indices]) + assert torch.allclose(items.input, expected_input) + assert isinstance(items.input, torch.Tensor) + assert not isinstance(items.input, LabelTensor) + + # 2. Verify Target Batch (Graph List) + assert len(items.target) == len(indices) + for i, original_idx in enumerate(indices): + assert torch.allclose(items.target[i].y, target_graph[original_idx].y) + assert isinstance(items.target[i].y, torch.Tensor) + + +def test_getitems_tensor_input_graph_target_condition_label_tensor(): + # Setup for LabelTensor + target_graph, input_tensor = _create_graph_data(True, use_lt=True) + condition = Condition(input=input_tensor, target=target_graph) + + indices = [0, 2, 4] + items = condition[indices] + + # 1. Verify Input LabelTensor preservation + assert isinstance(items.input, LabelTensor) + assert items.input.labels == ["f"] + # Verify values still match + assert torch.allclose(items.input, input_tensor[indices]) + + # 2. 
Verify Target Graphs LabelTensor preservation + assert len(items.target) == len(indices) + for i, original_idx in enumerate(indices): + target_y = items.target[i].y + assert isinstance(target_y, LabelTensor) + assert target_y.labels == ["u", "v"] + # Verify numerical parity + assert torch.allclose(target_y, target_graph[original_idx].y) + + +def test_create_batch_tensor(): + input_tensor, target_tensor = _create_tensor_data() + condition = Condition(input=input_tensor, target=target_tensor) + idx = [0, 2, 4, 6] + data_to_collate = [condition.data[i] for i in idx] + batch = condition.automatic_batching_collate_fn(data_to_collate) + assert isinstance(batch, _BatchManager) + assert hasattr(batch, "input") + assert hasattr(batch, "target") + expected_input = torch.stack([input_tensor[i] for i in idx]) + expected_target = torch.stack([target_tensor[i] for i in idx]) + assert torch.allclose(batch.input, expected_input) + assert torch.allclose(batch.target, expected_target) + + batch = condition.collate_fn(idx, condition) + # assert isinstance(batch, _BatchManager) + assert hasattr(batch, "input") + assert hasattr(batch, "target") + expected_input = torch.stack([input_tensor[i] for i in idx]) + expected_target = torch.stack([target_tensor[i] for i in idx]) + assert torch.allclose(batch.input, expected_input) + assert torch.allclose(batch.target, expected_target) + + +def test_create_batch_graph(): + input_graph, target_tensor = _create_graph_data(False) + condition = Condition(input=input_graph, target=target_tensor) + idx = [1, 3, 5] + data_to_collate = [condition.data[i] for i in idx] + batch = condition.automatic_batching_collate_fn(data_to_collate) + assert isinstance(batch, _BatchManager) + assert hasattr(batch, "input") + assert hasattr(batch, "target") + expected_target = torch.cat([target_tensor[i] for i in idx]) + print(expected_target.shape, batch.target.shape) + assert torch.allclose(batch.target, expected_target) + assert batch.input.num_graphs == len(idx) + + 
batch = condition.collate_fn(idx, condition) + assert isinstance(batch, _BatchManager) + assert hasattr(batch, "input") + assert hasattr(batch, "target") + assert torch.allclose(batch.target, expected_target) + assert batch.input.num_graphs == len(idx) diff --git a/tests/test_data_manager.py b/tests/test_data_manager.py new file mode 100644 index 000000000..af46c500d --- /dev/null +++ b/tests/test_data_manager.py @@ -0,0 +1,137 @@ +import torch +from pina._src.condition.data_manager import ( + _DataManager, + _TensorDataManager, + _GraphDataManager, +) +from pina.graph import Graph +from pina.equation import Equation + + +def test_tensor_data_manager_init(): + pippo = torch.rand((10, 5)) + pluto = torch.rand((10, 7)) + paperino = torch.rand((10, 11)) + data_manager = _DataManager(pippo=pippo, pluto=pluto, paperino=paperino) + assert isinstance(data_manager, _TensorDataManager) + assert hasattr(data_manager, "pippo") + assert hasattr(data_manager, "pluto") + assert hasattr(data_manager, "paperino") + assert torch.equal(data_manager.pippo, pippo) + assert torch.equal(data_manager.pluto, pluto) + assert torch.equal(data_manager.paperino, paperino) + + paperino = Equation(lambda x: x**2) + data_manager3 = _DataManager(pippo=pippo, pluto=pluto, paperino=paperino) + assert isinstance(data_manager3, _TensorDataManager) + assert hasattr(data_manager3, "pippo") + assert hasattr(data_manager3, "pluto") + assert hasattr(data_manager3, "paperino") + assert torch.equal(data_manager3.pippo, pippo) + assert torch.equal(data_manager3.pluto, pluto) + assert isinstance(data_manager3.paperino, Equation) + + +def test_graph_data_manager_init(): + x = [torch.rand((10, 5)) for _ in range(3)] + pos = [torch.rand((10, 3)) for _ in range(3)] + edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] + graph = [ + Graph(x=x_, pos=pos_, edge_index=edge_index_) + for x_, pos_, edge_index_ in zip(x, pos, edge_index) + ] + target = torch.rand((3, 10, 1)) + data_manager = 
_DataManager(graph=graph, target=target) + assert hasattr(data_manager, "graph_key") + assert data_manager.graph_key == "graph" + assert hasattr(data_manager, "graph") + assert len(data_manager.data) == 3 + for i in range(3): + g = data_manager.graph[i] + assert torch.equal(g.x, x[i]) + assert torch.equal(g.pos, pos[i]) + assert torch.equal(g.edge_index, edge_index[i]) + assert torch.equal(g.target, target[i]) + + +def test_graph_data_manager_getattribute(): + x = [torch.rand((10, 5)) for _ in range(3)] + pos = [torch.rand((10, 3)) for _ in range(3)] + edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] + graph = [ + Graph(x=x_, pos=pos_, edge_index=edge_index_) + for x_, pos_, edge_index_ in zip(x, pos, edge_index) + ] + target = torch.rand((3, 10, 1)) + data_manager = _DataManager(graph=graph, target=target) + target_retrieved = data_manager.target + assert torch.equal(target_retrieved, target) + + +def test_graph_data_manager_getitem(): + x = [torch.rand((10, 5)) for _ in range(3)] + pos = [torch.rand((10, 3)) for _ in range(3)] + edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] + graph = [ + Graph(x=x_, pos=pos_, edge_index=edge_index_) + for x_, pos_, edge_index_ in zip(x, pos, edge_index) + ] + target = torch.rand((3, 10, 1)) + data_manager = _DataManager(graph=graph, target=target) + item = data_manager[1] + assert isinstance(item, _DataManager) + assert hasattr(item, "graph_key") + assert item.graph_key == "graph" + assert hasattr(item, "graph") + assert torch.equal(item.graph.x, x[1]) + assert torch.equal(item.graph.pos, pos[1]) + assert torch.equal(item.graph.edge_index, edge_index[1]) + assert torch.equal(item.target, target[1].unsqueeze(0)) + + +def test_graph_data_create_batch(): + x = [torch.rand((10, 5)) for _ in range(3)] + pos = [torch.rand((10, 3)) for _ in range(3)] + edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] + graph = [ + Graph(x=x_, pos=pos_, edge_index=edge_index_) + for x_, pos_, edge_index_ in 
zip(x, pos, edge_index)
+    ]
+    target = torch.rand((3, 10, 1))
+    data_manager = _DataManager(graph=graph, target=target)
+    item1 = data_manager[0]
+    item2 = data_manager[1]
+    batch_data = _GraphDataManager._create_batch([item1, item2])
+    assert hasattr(batch_data, "graph")
+    assert hasattr(batch_data, "target")
+    batched_graphs = batch_data.graph
+    batched_target = batch_data.target
+    assert batched_graphs.num_graphs == 2
+    assert batched_target.shape == (20, 1)
+    assert torch.equal(batched_target, torch.cat([target[0], target[1]], dim=0))
+    cpu_data = batch_data.to("cpu")
+    assert cpu_data.graph.num_graphs == 2
+    assert torch.equal(cpu_data.target, batched_target.to("cpu"))
+    assert torch.equal(cpu_data.graph.x, batched_graphs.x.to("cpu"))
+
+
+def test_tensor_data_create_batch():
+    pippo = torch.rand((10, 5))
+    pluto = torch.rand((10, 7))
+    paperino = torch.rand((10, 11))
+    data_manager = _DataManager(pippo=pippo, pluto=pluto, paperino=paperino)
+    item1 = data_manager[0]
+    item2 = data_manager[1]
+    batch_data = _TensorDataManager._create_batch([item1, item2])
+    assert hasattr(batch_data, "pippo")
+    assert hasattr(batch_data, "pluto")
+    assert hasattr(batch_data, "paperino")
+    assert torch.equal(
+        batch_data.pippo, torch.stack([pippo[0], pippo[1]], dim=0)
+    )
+    assert torch.equal(
+        batch_data.pluto, torch.stack([pluto[0], pluto[1]], dim=0)
+    )
+    assert torch.equal(
+        batch_data.paperino, torch.stack([paperino[0], paperino[1]], dim=0)
+    )
diff --git a/tests/test_solver/test_ensemble_supervised_solver.py b/tests/test_solver/test_ensemble_supervised_solver.py
index c5f0b9e52..4be2897d9 100644
--- a/tests/test_solver/test_ensemble_supervised_solver.py
+++ b/tests/test_solver/test_ensemble_supervised_solver.py
@@ -83,7 +83,8 @@ def forward(self, batch):
         y = self.conv(y, edge_index)
         y = self.activation(y)
         y = self.output(y)
-        return to_dense_batch(y, batch.batch)[0]
+        return y
+        # return to_dense_batch(y, batch.batch)[0]
 
 
 graph_models = [Models() for i in range(10)]
diff --git a/tests/test_solver/test_supervised_solver.py b/tests/test_solver/test_supervised_solver.py index 6f7d1ab4d..461130a6b 100644 --- a/tests/test_solver/test_supervised_solver.py +++ b/tests/test_solver/test_supervised_solver.py @@ -83,7 +83,8 @@ def forward(self, batch): y = self.conv(y, edge_index) y = self.activation(y) y = self.output(y) - return to_dense_batch(y, batch.batch)[0] + return y + # return to_dense_batch(y, batch.batch)[0] graph_model = Model() From ab2c5c35adb8f2431a279f521a61ae52f8cc8c81 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Fri, 13 Feb 2026 09:32:29 +0100 Subject: [PATCH 2/3] DataModule refactoring (mathLab#766) --- pina/_src/condition/condition_base.py | 31 +- pina/_src/condition/data_manager.py | 1 + pina/_src/core/trainer.py | 55 ++- pina/_src/data/aggregator.py | 61 +++ pina/_src/data/creator.py | 182 +++++++ pina/_src/data/data_module.py | 629 ++++++------------------- pina/_src/data/dummy_dataloader.py | 62 +++ pina/_src/problem/abstract_problem.py | 96 ++-- pina/data/__init__.py | 18 - tests/test_data/test_data_module.py | 331 ------------- tests/test_data/test_graph_dataset.py | 138 ------ tests/test_data/test_tensor_dataset.py | 86 ---- tests/test_datamodule.py | 318 +++++++++++++ 13 files changed, 870 insertions(+), 1138 deletions(-) create mode 100644 pina/_src/data/aggregator.py create mode 100644 pina/_src/data/creator.py create mode 100644 pina/_src/data/dummy_dataloader.py delete mode 100644 tests/test_data/test_data_module.py delete mode 100644 tests/test_data/test_graph_dataset.py delete mode 100644 tests/test_data/test_tensor_dataset.py create mode 100644 tests/test_datamodule.py diff --git a/pina/_src/condition/condition_base.py b/pina/_src/condition/condition_base.py index b8290d717..0d1a8cb15 100644 --- a/pina/_src/condition/condition_base.py +++ b/pina/_src/condition/condition_base.py @@ -9,6 +9,7 @@ from pina._src.condition.condition_interface import ConditionInterface from pina._src.core.graph import 
LabelBatch from pina._src.core.label_tensor import LabelTensor +from pina._src.data.dummy_dataloader import DummyDataloader class ConditionBase(ConditionInterface): @@ -33,6 +34,7 @@ def __init__(self, **kwargs): """ super().__init__() self.data = self.store_data(**kwargs) + self.has_custom_dataloader_fn = False @property def problem(self): @@ -85,7 +87,8 @@ def automatic_batching_collate_fn(cls, batch): if not batch: return {} instance_class = batch[0].__class__ - return instance_class.create_batch(batch) + batch = instance_class.create_batch(batch) + return batch @staticmethod def collate_fn(batch, condition): @@ -103,7 +106,11 @@ def collate_fn(batch, condition): return data def create_dataloader( - self, dataset, batch_size, shuffle, automatic_batching + self, + dataset, + batch_size, + automatic_batching, + **kwargs, ): """ Create a DataLoader for the condition. @@ -114,14 +121,28 @@ def create_dataloader( :rtype: torch.utils.data.DataLoader """ if batch_size == len(dataset): - pass # will be updated in the near future + return DummyDataloader(dataset) return DataLoader( dataset=dataset, - batch_size=batch_size, - shuffle=shuffle, collate_fn=( partial(self.collate_fn, condition=self) if not automatic_batching else self.automatic_batching_collate_fn ), + batch_size=batch_size, + **kwargs, ) + + def switch_dataloader_fn(self, create_dataloader_fn): + """ + Decorator to switch the dataloader function for a condition. + + :param create_dataloader_fn: The new dataloader function to use. + :type create_dataloader_fn: function + :return: The decorated function with the new dataloader function. 
+        :rtype: function
+        """
+        # Replace the create_dataloader method of the ConditionBase class with
+        # the new function
+        self.has_custom_dataloader_fn = True
+        self.create_dataloader = create_dataloader_fn
diff --git a/pina/_src/core/trainer.py b/pina/_src/core/trainer.py
index 7500be537..d18350d14 100644
--- a/pina/_src/core/trainer.py
+++ b/pina/_src/core/trainer.py
@@ -36,7 +36,7 @@ def __init__(
         test_size=0.0,
         val_size=0.0,
         compile=None,
-        repeat=None,
+        batching_mode="common_batch_size",
         automatic_batching=None,
         num_workers=None,
         pin_memory=None,
@@ -61,9 +61,8 @@ def __init__(
         :param bool compile: If ``True``, the model is compiled before training.
             Default is ``False``. For Windows users, it is always disabled.
             Not supported for python version greater or equal than 3.14.
-        :param bool repeat: Whether to repeat the dataset data in each
-            condition during training. For further details, see the
-            :class:`~pina.data.data_module.PinaDataModule` class. Default is
+        :param str batching_mode: The batching mode to use. Options are
+            ``"common_batch_size"``, ``"proportional"``, and
+            ``"separate_conditions"``. Default is ``"common_batch_size"``.
-            ``False``.
:param bool automatic_batching: If ``True``, automatic PyTorch batching is performed, otherwise the items are retrieved from the dataset @@ -87,7 +87,7 @@ def __init__( train_size=train_size, test_size=test_size, val_size=val_size, - repeat=repeat, + batching_mode=batching_mode, automatic_batching=automatic_batching, compile=compile, ) @@ -127,24 +127,44 @@ def __init__( UserWarning, ) - repeat = repeat if repeat is not None else False - automatic_batching = ( automatic_batching if automatic_batching is not None else False ) + if batch_size is None and batching_mode != "common_batch_size": + warnings.warn( + "Batching mode is set to " + f"{batching_mode} but batch_size is None. " + "Batching mode will be set to common_batch_size.", + UserWarning, + ) + batching_mode = "common_batch_size" + + if ( + batch_size is not None + and batch_size <= len(solver.problem.conditions) + and batching_mode == "proportional" + ): + warnings.warn( + "Batching mode is set to proportional but batch_size is 1. " + "Batching mode will be set to common_batch_size.", + UserWarning, + ) + batching_mode = "common_batch_size" + # set attributes self.compile = compile self.solver = solver self.batch_size = batch_size self._move_to_device() self.data_module = None + self._create_datamodule( train_size=train_size, test_size=test_size, val_size=val_size, batch_size=batch_size, - repeat=repeat, + batching_mode=batching_mode, automatic_batching=automatic_batching, pin_memory=pin_memory, num_workers=num_workers, @@ -182,7 +202,7 @@ def _create_datamodule( test_size, val_size, batch_size, - repeat, + batching_mode, automatic_batching, pin_memory, num_workers, @@ -201,8 +221,9 @@ def _create_datamodule( :param float val_size: The percentage of elements to include in the validation dataset. :param int batch_size: The number of samples per batch to load. - :param bool repeat: Whether to repeat the dataset data in each - condition during training. + :param str batching_mode: The batching mode to use. 
Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. :param bool automatic_batching: Whether to perform automatic batching with PyTorch. :param bool pin_memory: Whether to use pinned memory for faster data @@ -232,7 +253,7 @@ def _create_datamodule( test_size=test_size, val_size=val_size, batch_size=batch_size, - repeat=repeat, + batching_mode=batching_mode, automatic_batching=automatic_batching, num_workers=num_workers, pin_memory=pin_memory, @@ -284,7 +305,7 @@ def _check_input_consistency( train_size, test_size, val_size, - repeat, + batching_mode, automatic_batching, compile, ): @@ -298,8 +319,9 @@ def _check_input_consistency( test dataset. :param float val_size: The percentage of elements to include in the validation dataset. - :param bool repeat: Whether to repeat the dataset data in each - condition during training. + :param str batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. :param bool automatic_batching: Whether to perform automatic batching with PyTorch. :param bool compile: If ``True``, the model is compiled before training. @@ -309,8 +331,7 @@ def _check_input_consistency( check_consistency(train_size, float) check_consistency(test_size, float) check_consistency(val_size, float) - if repeat is not None: - check_consistency(repeat, bool) + check_consistency(batching_mode, str) if automatic_batching is not None: check_consistency(automatic_batching, bool) if compile is not None: diff --git a/pina/_src/data/aggregator.py b/pina/_src/data/aggregator.py new file mode 100644 index 000000000..605af5d46 --- /dev/null +++ b/pina/_src/data/aggregator.py @@ -0,0 +1,61 @@ +""" +Aggregator for multiple dataloaders. +""" + + +class _Aggregator: + """ + The class :class:`_Aggregator` is responsible for aggregating multiple + dataloaders into a single iterable object. 
It supports different batching + modes to accommodate various training requirements. + """ + + def __init__(self, dataloaders, batching_mode): + """ + Initialization of the :class:`_Aggregator` class. + + :param dataloaders: A dictionary mapping condition names to their + respective dataloaders. + :type dataloaders: dict[str, DataLoader] + :param batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. + :type batching_mode: str + """ + self.dataloaders = dataloaders + self.batching_mode = batching_mode + + def __len__(self): + """ + Return the length of the aggregated dataloader. + + :return: The length of the aggregated dataloader. + :rtype: int + """ + if self.batching_mode == "separate_conditions": + return sum(len(dl) for dl in self.dataloaders.values()) + return max(len(dl) for dl in self.dataloaders.values()) + + def __iter__(self): + """ + Return an iterator over the aggregated dataloader. + + :return: An iterator over the aggregated dataloader. + :rtype: iterator + """ + if self.batching_mode == "separate_conditions": + # TODO: implement separate_conditions batching mode + raise NotImplementedError( + "Batching mode 'separate_conditions' is not implemented yet." + ) + + iterators = {name: iter(dl) for name, dl in self.dataloaders.items()} + for _ in range(len(self)): + batch = {} + for name, it in iterators.items(): + try: + batch[name] = next(it) + except StopIteration: + iterators[name] = iter(self.dataloaders[name]) + batch[name] = next(iterators[name]) + yield batch diff --git a/pina/_src/data/creator.py b/pina/_src/data/creator.py new file mode 100644 index 000000000..0e84aef72 --- /dev/null +++ b/pina/_src/data/creator.py @@ -0,0 +1,182 @@ +""" +Module defining the Creator class, responsible for creating dataloaders +for multiple conditions with various batching strategies. 
+""" + +import torch +from torch.utils.data import RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler + + +class _Creator: + """ + The class :class:`_Creator` is responsible for creating dataloaders for + multiple conditions based on specified batching strategies. It supports + different batching modes to accommodate various training requirements. + """ + + def __init__( + self, + batching_mode, + batch_size, + shuffle, + automatic_batching, + num_workers, + pin_memory, + conditions, + ): + """ + Initialization of the :class:`_Creator` class. + + :param batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. + :type batching_mode: str + :param batch_size: The batch size to use for dataloaders. If + ``batching_mode`` is ``"proportional"``, this represents the total + batch size across all conditions. + :type batch_size: int | None + :param shuffle: Whether to shuffle the data in the dataloaders. + :type shuffle: bool + :param automatic_batching: Whether to use automatic batching in the + dataloaders. + :type automatic_batching: bool + :param num_workers: The number of worker processes to use for data + loading. + :type num_workers: int + :param pin_memory: Whether to pin memory in the dataloaders. + :type pin_memory: bool + :param conditions: A dictionary mapping condition names to their + respective condition objects. 
+ :type conditions: dict[str, Condition] + """ + self.batching_mode = batching_mode + self.batch_size = batch_size + self.shuffle = shuffle + self.automatic_batching = automatic_batching + self.num_workers = num_workers + self.pin_memory = pin_memory + self.conditions = conditions + + def _define_sampler(self, dataset, shuffle): + if torch.distributed.is_initialized(): + return DistributedSampler(dataset, shuffle=shuffle) + if shuffle: + return RandomSampler(dataset) + return SequentialSampler(dataset) + + def _compute_batch_sizes(self, datasets): + """ + Compute batch sizes for each condition based on the specified + batching mode. + + :param datasets: A dictionary mapping condition names to their + respective datasets. + :type datasets: dict[str, Dataset] + :return: A dictionary mapping condition names to their computed batch + sizes. + :rtype: dict[str, int] + """ + batch_sizes = {} + if self.batching_mode == "common_batch_size": + for name in datasets.keys(): + if self.batch_size is None: + batch_sizes[name] = len(datasets[name]) + else: + batch_sizes[name] = min( + self.batch_size, len(datasets[name]) + ) + return batch_sizes + if self.batching_mode == "proportional": + return self._compute_proportional_batch_sizes(datasets) + if self.batching_mode == "separate_conditions": + for name in datasets.keys(): + condition = self.conditions[name] + if self.batch_size is None: + batch_sizes[name] = len(datasets[name]) + else: + batch_sizes[name] = min( + self.batch_size, len(datasets[name]) + ) + return batch_sizes + raise ValueError(f"Unknown batching mode: {self.batching_mode}") + + def _compute_proportional_batch_sizes(self, datasets): + """ + Compute batch sizes for each condition proportionally based on the + size of their datasets. + :param datasets: A dictionary mapping condition names to their + respective datasets. + :type datasets: dict[str, Dataset] + :return: A dictionary mapping condition names to their computed batch + sizes. 
+ :rtype: dict[str, int] + """ + # Compute number of elements per dataset + elements_per_dataset = { + dataset_name: len(dataset) + for dataset_name, dataset in datasets.items() + } + # Compute the total number of elements + total_elements = sum(el for el in elements_per_dataset.values()) + # Compute the portion of each dataset + portion_per_dataset = { + name: el / total_elements + for name, el in elements_per_dataset.items() + } + # Compute batch size per dataset. Ensure at least 1 element per + # dataset. + batch_size_per_dataset = { + name: max(1, int(portion * self.batch_size)) + for name, portion in portion_per_dataset.items() + } + # Adjust batch sizes to match the specified total batch size + tot_el_per_batch = sum(el for el in batch_size_per_dataset.values()) + if self.batch_size > tot_el_per_batch: + difference = self.batch_size - tot_el_per_batch + while difference > 0: + for k, v in batch_size_per_dataset.items(): + if difference == 0: + break + if v > 1: + batch_size_per_dataset[k] += 1 + difference -= 1 + if self.batch_size < tot_el_per_batch: + difference = tot_el_per_batch - self.batch_size + while difference > 0: + for k, v in batch_size_per_dataset.items(): + if difference == 0: + break + if v > 1: + batch_size_per_dataset[k] -= 1 + difference -= 1 + return batch_size_per_dataset + + def __call__(self, datasets): + """ + Create dataloaders for each condition based on the specified batching + mode. + :param datasets: A dictionary mapping condition names to their + respective datasets. + :type datasets: dict[str, Dataset] + :return: A dictionary mapping condition names to their created + dataloaders. 
+ :rtype: dict[str, DataLoader] + """ + # Compute batch sizes per condition based on batching_mode + batch_sizes = self._compute_batch_sizes(datasets) + dataloaders = {} + if self.batching_mode == "common_batch_size": + max_len = max(len(dataset) for dataset in datasets.values()) + for name, dataset in datasets.items(): + if self.batching_mode == "common_batch_size": + dataset.max_len = max_len + dataloaders[name] = self.conditions[name].create_dataloader( + dataset=dataset, + batch_size=batch_sizes[name], + automatic_batching=self.automatic_batching, + sampler=self._define_sampler(dataset, self.shuffle), + num_workers=self.num_workers, + pin_memory=self.pin_memory, + ) + return dataloaders diff --git a/pina/_src/data/data_module.py b/pina/_src/data/data_module.py index f45236f0f..d0fb5989a 100644 --- a/pina/_src/data/data_module.py +++ b/pina/_src/data/data_module.py @@ -7,232 +7,58 @@ import warnings from lightning.pytorch import LightningDataModule import torch -from torch_geometric.data import Data -from torch.utils.data import DataLoader, SequentialSampler -from torch.utils.data.distributed import DistributedSampler -from pina._src.core.label_tensor import LabelTensor -from pina._src.data.dataset import PinaDatasetFactory, PinaTensorDataset +from torch_geometric.data import Batch +from pina._src.data.creator import _Creator +from pina._src.core.graph import LabelBatch, Graph +from pina._src.data.aggregator import _Aggregator -class DummyDataloader: - - def __init__(self, dataset): - """ - Prepare a dataloader object that returns the entire dataset in a single - batch. Depending on the number of GPUs, the dataset is managed - as follows: - - - **Distributed Environment** (multiple GPUs): Divides dataset across - processes using the rank and world size. Fetches only portion of - data corresponding to the current process. - - **Non-Distributed Environment** (single GPU): Fetches the entire - dataset. 
- - :param PinaDataset dataset: The dataset object to be processed. - - .. note:: - This dataloader is used when the batch size is ``None``. - """ - - if ( - torch.distributed.is_available() - and torch.distributed.is_initialized() - ): - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - if len(dataset) < world_size: - raise RuntimeError( - "Dimension of the dataset smaller than world size." - " Increase the size of the partition or use a single GPU" - ) - idx, i = [], rank - while i < len(dataset): - idx.append(i) - i += world_size - self.dataset = dataset.fetch_from_idx_list(idx) - else: - self.dataset = dataset.get_all_data() - - def __iter__(self): - return self - - def __len__(self): - return 1 - - def __next__(self): - return self.dataset - - -class Collator: +class _ConditionSubset: """ - This callable class is used to collate the data points fetched from the - dataset. The collation is performed based on the type of dataset used and - on the batching strategy. + This class extends the :class:`torch.utils.data.Subset` class, allowing to + fetch the data from the dataset based on a list of indices. """ - def __init__( - self, max_conditions_lengths, automatic_batching, dataset=None - ): - """ - Initialize the object, setting the collate function based on whether - automatic batching is enabled or not. - - :param dict max_conditions_lengths: ``dict`` containing the maximum - number of data points to consider in a single batch for - each condition. - :param bool automatic_batching: Whether automatic PyTorch batching is - enabled or not. For more information, see the - :class:`~pina.data.data_module.PinaDataModule` class. - :param PinaDataset dataset: The dataset where the data is stored. 
- """ - - self.max_conditions_lengths = max_conditions_lengths - # Set the collate function based on the batching strategy - # collate_pina_dataloader is used when automatic batching is disabled - # collate_torch_dataloader is used when automatic batching is enabled - self.callable_function = ( - self._collate_torch_dataloader - if automatic_batching - else (self._collate_pina_dataloader) - ) - self.dataset = dataset - - # Set the function which performs the actual collation - if isinstance(self.dataset, PinaTensorDataset): - # If the dataset is a PinaTensorDataset, use this collate function - self._collate = self._collate_tensor_dataset - else: - # If the dataset is a PinaDataset, use this collate function - self._collate = self._collate_graph_dataset - - def _collate_pina_dataloader(self, batch): - """ - Function used to create a batch when automatic batching is disabled. - - :param list[int] batch: List of integers representing the indices of - the data points to be fetched. - :return: Dictionary containing the data points fetched from the dataset. - :rtype: dict - """ - # Call the fetch_from_idx_list method of the dataset - return self.dataset.fetch_from_idx_list(batch) - - def _collate_torch_dataloader(self, batch): - """ - Function used to collate the batch - - :param list[dict] batch: List of retrieved data. - :return: Dictionary containing the data points fetched from the dataset, - collated. 
- :rtype: dict - """ - - batch_dict = {} - if isinstance(batch, dict): - return batch - conditions_names = batch[0].keys() - # Condition names - for condition_name in conditions_names: - single_cond_dict = {} - condition_args = batch[0][condition_name].keys() - for arg in condition_args: - data_list = [ - batch[idx][condition_name][arg] - for idx in range( - min( - len(batch), - self.max_conditions_lengths[condition_name], - ) - ) - ] - single_cond_dict[arg] = self._collate(data_list) - - batch_dict[condition_name] = single_cond_dict - return batch_dict - - @staticmethod - def _collate_tensor_dataset(data_list): - """ - Function used to collate the data when the dataset is a - :class:`~pina.data.dataset.PinaTensorDataset`. - - :param data_list: Elements to be collated. - :type data_list: list[torch.Tensor] | list[LabelTensor] - :return: Batch of data. - :rtype: dict - - :raises RuntimeError: If the data is not a :class:`torch.Tensor` or a - :class:`~pina.label_tensor.LabelTensor`. - """ - - if isinstance(data_list[0], LabelTensor): - return LabelTensor.stack(data_list) - if isinstance(data_list[0], torch.Tensor): - return torch.stack(data_list) - raise RuntimeError("Data must be Tensors or LabelTensor ") - - def _collate_graph_dataset(self, data_list): - """ - Function used to collate data when the dataset is a - :class:`~pina.data.dataset.PinaGraphDataset`. - - :param data_list: Elememts to be collated. - :type data_list: list[Data] | list[Graph] - :return: Batch of data. - :rtype: dict + def __init__(self, condition, indices, automatic_batching): + super().__init__() + self.condition = condition + self.indices = indices + self.automatic_batching = automatic_batching + self.length = len(self.indices) + self.max_len = self.length - :raises RuntimeError: If the data is not a - :class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`. 
- """ - if isinstance(data_list[0], LabelTensor): - return LabelTensor.cat(data_list) - if isinstance(data_list[0], torch.Tensor): - return torch.cat(data_list) - if isinstance(data_list[0], Data): - return self.dataset.create_batch(data_list) - raise RuntimeError( - "Data must be Tensors or LabelTensor or pyG " - "torch_geometric.data.Data" - ) + def __len__(self): + return self.max_len - def __call__(self, batch): + def __getitem__(self, idx): """ - Perform the collation of data fetched from the dataset. The behavoior - of the function is set based on the batching strategy during class - initialization. + Fetch the data from the dataset based on the list of indices. - :param batch: List of retrieved data or sampled indices. - :type batch: list[int] | list[dict] - :return: Dictionary containing colleted data fetched from the dataset. + :param int idx: The index of the data to be fetched. + :return: The data corresponding to the given index. :rtype: dict """ - - return self.callable_function(batch) - - -class PinaSampler: - """ - This class is used to create the sampler instance based on the shuffle - parameter and the environment in which the code is running. - """ - - def __new__(cls, dataset): - """ - Instantiate and initialize the sampler. - - :param PinaDataset dataset: The dataset from which to sample. - :return: The sampler instance. 
- :rtype: :class:`torch.utils.data.Sampler` - """ - - if ( - torch.distributed.is_available() - and torch.distributed.is_initialized() - ): - sampler = DistributedSampler(dataset) - else: - sampler = SequentialSampler(dataset) - return sampler + if idx >= self.length: + idx = idx % self.length + idx = self.indices[idx] + if not self.automatic_batching: + return idx + return self.condition[idx] + + def get_all_data(self): + data = self.condition[self.indices] + if "data" in data and isinstance(data["data"], list): + batch_fn = ( + LabelBatch.from_data_list + if isinstance(data["data"][0], Graph) + else Batch.from_data_list + ) + data["data"] = batch_fn(data["data"]) + data = { + "input": data["data"], + "target": data["data"].y, + } + return data class PinaDataModule(LightningDataModule): @@ -250,7 +76,7 @@ def __init__( val_size=0.1, batch_size=None, shuffle=True, - repeat=False, + batching_mode="common_batch_size", automatic_batching=None, num_workers=0, pin_memory=False, @@ -271,11 +97,9 @@ def __init__( Default is ``None``. :param bool shuffle: Whether to shuffle the dataset before splitting. Default ``True``. - :param bool repeat: If ``True``, in case of batch size larger than the - number of elements in a specific condition, the elements are - repeated until the batch size is reached. If ``False``, the number - of elements in the batch is the minimum between the batch size and - the number of elements in the condition. Default is ``False``. + :param str batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. Default is ``"common_batch_size"``. :param automatic_batching: If ``True``, automatic PyTorch batching is performed, which consists of extracting one element at a time from the dataset and collating them into a batch. 
This is useful @@ -302,11 +126,13 @@ def __init__( """ super().__init__() + self.problem = problem # Store fixed attributes self.batch_size = batch_size self.shuffle = shuffle - self.repeat = repeat + self.batching_mode = batching_mode self.automatic_batching = automatic_batching + self.batching_mode = batching_mode # If batch size is None, num_workers has no effect if batch_size is None and num_workers != 0: @@ -327,41 +153,87 @@ def __init__( self.pin_memory = False else: self.pin_memory = pin_memory - - # Collect data - problem.collect_data() - - # Check if the splits are correct + self.problem.move_discretisation_into_conditions() self._check_slit_sizes(train_size, test_size, val_size) - # Split input data into subsets - splits_dict = {} if train_size > 0: - splits_dict["train"] = train_size self.train_dataset = None else: # Use the super method to create the train dataloader which # raises NotImplementedError self.train_dataloader = super().train_dataloader if test_size > 0: - splits_dict["test"] = test_size self.test_dataset = None else: # Use the super method to create the train dataloader which # raises NotImplementedError self.test_dataloader = super().test_dataloader if val_size > 0: - splits_dict["val"] = val_size self.val_dataset = None else: # Use the super method to create the train dataloader which # raises NotImplementedError self.val_dataloader = super().val_dataloader - self.data_splits = self._create_splits( - problem.collected_data, splits_dict + self._create_condition_splits(problem, train_size, test_size, val_size) + self.creator = _Creator( + batching_mode=batching_mode, + batch_size=batch_size, + shuffle=shuffle, + automatic_batching=automatic_batching, + num_workers=num_workers, + pin_memory=pin_memory, + conditions=problem.conditions, ) - self.transfer_batch_to_device = self._transfer_batch_to_device + + @staticmethod + def _check_slit_sizes(train_size, test_size, val_size): + """ + Check if the splits are correct. 
The splits sizes must be positive and + the sum of the splits must be 1. + + :param float train_size: The size of the training split. + :param float test_size: The size of the testing split. + :param float val_size: The size of the validation split. + + :raises ValueError: If at least one of the splits is negative. + :raises ValueError: If the sum of the splits is different + from 1. + """ + + if train_size < 0 or test_size < 0 or val_size < 0: + raise ValueError("The splits must be positive") + if abs(train_size + test_size + val_size - 1) > 1e-6: + raise ValueError("The sum of the splits must be 1") + + def _create_condition_splits( + self, problem, train_size, test_size, val_size + ): + self.split_idxs = {} + for condition_name, condition in problem.conditions.items(): + len_condition = len(condition) + # Create the indices for shuffling and splitting + indices = ( + torch.randperm(len_condition).tolist() + if self.shuffle + else list(range(len_condition)) + ) + + # Determine split sizes + train_end = int(train_size * len_condition) + test_end = train_end + int(test_size * len_condition) + + # Split indices + train_indices = indices[:train_end] + test_indices = indices[train_end:test_end] + val_indices = indices[test_end:] + splits = {} + splits["train"], splits["test"], splits["val"] = ( + train_indices, + test_indices, + val_indices, + ) + self.split_idxs[condition_name] = splits def setup(self, stage=None): """ @@ -373,210 +245,58 @@ def setup(self, stage=None): :raises ValueError: If the stage is neither "fit" nor "test". 
""" - if stage == "fit" or stage is None: - self.train_dataset = PinaDatasetFactory( - self.data_splits["train"], - max_conditions_lengths=self.find_max_conditions_lengths( - "train" - ), - automatic_batching=self.automatic_batching, - ) - if "val" in self.data_splits.keys(): - self.val_dataset = PinaDatasetFactory( - self.data_splits["val"], - max_conditions_lengths=self.find_max_conditions_lengths( - "val" - ), + if stage in ("fit", None): + self.train_datasets = { + name: _ConditionSubset( + condition, + self.split_idxs[name]["train"], automatic_batching=self.automatic_batching, ) - elif stage == "test": - self.test_dataset = PinaDatasetFactory( - self.data_splits["test"], - max_conditions_lengths=self.find_max_conditions_lengths("test"), - automatic_batching=self.automatic_batching, - ) - else: - raise ValueError("stage must be either 'fit' or 'test'.") - - @staticmethod - def _split_condition(single_condition_dict, splits_dict): - """ - Split the condition into different stages. - - :param dict single_condition_dict: The condition to be split. - :param dict splits_dict: The dictionary containing the number of - elements in each stage. - :return: A dictionary containing the split condition. 
- :rtype: dict - """ - - len_condition = len(single_condition_dict["input"]) - - lengths = [ - int(len_condition * length) for length in splits_dict.values() - ] - - remainder = len_condition - sum(lengths) - for i in range(remainder): - lengths[i % len(lengths)] += 1 - - splits_dict = { - k: max(1, v) for k, v in zip(splits_dict.keys(), lengths) - } - to_return_dict = {} - offset = 0 - - for stage, stage_len in splits_dict.items(): - to_return_dict[stage] = { - k: v[offset : offset + stage_len] - for k, v in single_condition_dict.items() - if k != "equation" - # Equations are NEVER dataloaded + for name, condition in self.problem.conditions.items() + if len(self.split_idxs[name]["train"]) > 0 } - if offset + stage_len >= len_condition: - offset = len_condition - 1 - continue - offset += stage_len - return to_return_dict - - def _create_splits(self, collector, splits_dict): - """ - Create the dataset objects putting data in the correct splits. - :param Collector collector: The collector object containing the data. - :param dict splits_dict: The dictionary containing the number of - elements in each stage. - :return: The dictionary containing the dataset objects. 
- :rtype: dict - """ - - # ----------- Auxiliary function ------------ - def _apply_shuffle(condition_dict, len_data): - idx = torch.randperm(len_data) - for k, v in condition_dict.items(): - if k == "equation": - continue - if isinstance(v, list): - condition_dict[k] = [v[i] for i in idx] - elif isinstance(v, LabelTensor): - condition_dict[k] = LabelTensor(v.tensor[idx], v.labels) - elif isinstance(v, torch.Tensor): - condition_dict[k] = v[idx] - else: - raise ValueError(f"Data type {type(v)} not supported") - - # ----------- End auxiliary function ------------ - - split_names = list(splits_dict.keys()) - dataset_dict = {name: {} for name in split_names} - for ( - condition_name, - condition_dict, - ) in collector.items(): - len_data = len(condition_dict["input"]) - if self.shuffle: - _apply_shuffle(condition_dict, len_data) - for key, data in self._split_condition( - condition_dict, splits_dict - ).items(): - dataset_dict[key].update({condition_name: data}) - return dataset_dict - - def _create_dataloader(self, split, dataset): - """ " - Create the dataloader for the given split. - - :param str split: The split on which to create the dataloader. - :param str dataset: The dataset to be used for the dataloader. - :return: The dataloader for the given split. - :rtype: torch.utils.data.DataLoader - """ - # Suppress the warning about num_workers. - # In many cases, especially for PINNs, - # serial data loading can outperform parallel data loading. - warnings.filterwarnings( - "ignore", - message=( - "The '(train|val|test)_dataloader' does not have many workers " - "which may be a bottleneck." 
- ), - module="lightning.pytorch.trainer.connectors.data_connector", - ) - # Use custom batching (good if batch size is large) - if self.batch_size is not None: - sampler = PinaSampler(dataset) - if self.automatic_batching: - collate = Collator( - self.find_max_conditions_lengths(split), - self.automatic_batching, - dataset=dataset, - ) - else: - collate = Collator( - None, self.automatic_batching, dataset=dataset + self.val_datasets = { + name: _ConditionSubset( + condition, + self.split_idxs[name]["val"], + automatic_batching=self.automatic_batching, ) - return DataLoader( - dataset, - self.batch_size, - collate_fn=collate, - sampler=sampler, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - ) - dataloader = DummyDataloader(dataset) - dataloader.dataset = self._transfer_batch_to_device( - dataloader.dataset, self.trainer.strategy.root_device, 0 - ) - self.transfer_batch_to_device = self._transfer_batch_to_device_dummy - return dataloader - - def find_max_conditions_lengths(self, split): - """ - Define the maximum length for each conditions. - - :param dict split: The split of the dataset. - :return: The maximum length per condition. - :rtype: dict - """ + for name, condition in self.problem.conditions.items() + if len(self.split_idxs[name]["val"]) > 0 + } - max_conditions_lengths = {} - for k, v in self.data_splits[split].items(): - if self.batch_size is None: - max_conditions_lengths[k] = len(v["input"]) - elif self.repeat: - max_conditions_lengths[k] = self.batch_size - else: - max_conditions_lengths[k] = min( - len(v["input"]), self.batch_size + if stage in ("test", None): + self.test_datasets = { + name: _ConditionSubset( + condition, + self.split_idxs[name]["test"], + automatic_batching=self.automatic_batching, ) - return max_conditions_lengths - - def val_dataloader(self): - """ - Create the validation dataloader. 
- - :return: The validation dataloader - :rtype: torch.utils.data.DataLoader - """ - return self._create_dataloader("val", self.val_dataset) + for name, condition in self.problem.conditions.items() + if len(self.split_idxs[name]["test"]) > 0 + } + if stage not in ("fit", "test", None): + raise ValueError( + f"Invalid stage {stage}. Stage must be either 'fit' or 'test'." + ) def train_dataloader(self): - """ - Create the training dataloader + return _Aggregator( + self.creator(self.train_datasets), + batching_mode=self.batching_mode, + ) - :return: The training dataloader - :rtype: torch.utils.data.DataLoader - """ - return self._create_dataloader("train", self.train_dataset) + def val_dataloader(self): + return _Aggregator( + self.creator(self.val_datasets), batching_mode=self.batching_mode + ) def test_dataloader(self): - """ - Create the testing dataloader - - :return: The testing dataloader - :rtype: torch.utils.data.DataLoader - """ - return self._create_dataloader("test", self.test_dataset) + return _Aggregator( + self.creator(self.test_datasets), + batching_mode=self.batching_mode, + ) @staticmethod def _transfer_batch_to_device_dummy(batch, device, dataloader_idx): @@ -591,10 +311,9 @@ def _transfer_batch_to_device_dummy(batch, device, dataloader_idx): :return: The batch transferred to the device. :rtype: list[tuple] """ - return batch - def _transfer_batch_to_device(self, batch, device, dataloader_idx): + def transfer_batch_to_device(self, batch, device, dataloader_idx): """ Transfer the batch to the device. This method is called in the training loop and is used to transfer the batch to the device. @@ -606,53 +325,7 @@ def _transfer_batch_to_device(self, batch, device, dataloader_idx): :return: The batch transferred to the device. 
:rtype: list[tuple] """ - - batch = [ - ( - k, - super(LightningDataModule, self).transfer_batch_to_device( - v, device, dataloader_idx - ), - ) - for k, v in batch.items() - ] - - return batch - - @staticmethod - def _check_slit_sizes(train_size, test_size, val_size): - """ - Check if the splits are correct. The splits sizes must be positive and - the sum of the splits must be 1. - - :param float train_size: The size of the training split. - :param float test_size: The size of the testing split. - :param float val_size: The size of the validation split. - - :raises ValueError: If at least one of the splits is negative. - :raises ValueError: If the sum of the splits is different - from 1. - """ - - if train_size < 0 or test_size < 0 or val_size < 0: - raise ValueError("The splits must be positive") - if abs(train_size + test_size + val_size - 1) > 1e-6: - raise ValueError("The sum of the splits must be 1") - - @property - def input(self): - """ - Return all the input points coming from all the datasets. - - :return: The input points for training. - :rtype: dict - """ - - to_return = {} - if hasattr(self, "train_dataset") and self.train_dataset is not None: - to_return["train"] = self.train_dataset.input - if hasattr(self, "val_dataset") and self.val_dataset is not None: - to_return["val"] = self.val_dataset.input - if hasattr(self, "test_dataset") and self.test_dataset is not None: - to_return["test"] = self.test_dataset.input + to_return = [] + for condition_name, condition in batch.items(): + to_return.append((condition_name, condition.to(device))) return to_return diff --git a/pina/_src/data/dummy_dataloader.py b/pina/_src/data/dummy_dataloader.py new file mode 100644 index 000000000..c236e9d30 --- /dev/null +++ b/pina/_src/data/dummy_dataloader.py @@ -0,0 +1,62 @@ +""" +Module containing the ``DummyDataloader`` class +""" + +import torch + + +class DummyDataloader: + """ + A dummy dataloader that returns the entire dataset in a single batch. 
This + is used when the batch size is ``None``. It supports both distributed and + non-distributed environments. In a distributed environment, it divides the + dataset across processes using the rank and world size, fetching only the + portion of data corresponding to the current process. In a non-distributed + environment, it fetches the entire dataset. + """ + + def __init__(self, dataset): + """ + Prepare a dataloader object that returns the entire dataset in a single + batch. Depending on the number of GPUs, the dataset is managed + as follows: + + - **Distributed Environment** (multiple GPUs): Divides dataset across + processes using the rank and world size. Fetches only portion of + data corresponding to the current process. + - **Non-Distributed Environment** (single GPU): Fetches the entire + dataset. + + :param PinaDataset dataset: The dataset object to be processed. + + .. note:: + This dataloader is used when the batch size is ``None``. + """ + + if ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + ): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + if len(dataset) < world_size: + raise RuntimeError( + "Dimension of the dataset smaller than world size." 
+ " Increase the size of the partition or use a single GPU" + ) + idx, i = [], rank + while i < len(dataset): + idx.append(i) + i += world_size + self.dataset = dataset.fetch_from_idx_list(idx).to_batch() + else: + self.dataset = dataset.get_all_data().to_batch() + + def __iter__(self): + return self + + def __len__(self): + return 1 + + def __next__(self): + return self.dataset diff --git a/pina/_src/problem/abstract_problem.py b/pina/_src/problem/abstract_problem.py index cfaeb5bec..cc2b9e042 100644 --- a/pina/_src/problem/abstract_problem.py +++ b/pina/_src/problem/abstract_problem.py @@ -11,6 +11,7 @@ ) from pina._src.core.label_tensor import LabelTensor from pina._src.core.utils import merge_tensors, custom_warning_format +from pina._src.condition.condition import Condition class AbstractProblem(metaclass=ABCMeta): @@ -42,43 +43,6 @@ def __init__(self): self.domains[cond_name] = cond.domain cond.domain = cond_name - self._collected_data = {} - - @property - def collected_data(self): - """ - Return the collected data from the problem's conditions. If some domains - are not sampled, they will not be returned by collected data. - - :return: The collected data. Keys are condition names, and values are - dictionaries containing the input points and the corresponding - equations or target points. - :rtype: dict - """ - # collect data so far - self.collect_data() - # raise warning if some sample data are missing - if not self.are_all_domains_discretised: - warnings.formatwarning = custom_warning_format - warnings.filterwarnings("always", category=RuntimeWarning) - warning_message = "\n".join( - [ - f"""{" " * 13} ---> Domain {key} { - "sampled" if key in self.discretised_domains - else - "not sampled"}""" - for key in self.domains - ] - ) - warnings.warn( - "Some of the domains are still not sampled. 
Consider calling " - "problem.discretise_domain function for all domains before " - "accessing the collected data:\n" - f"{warning_message}", - RuntimeWarning, - ) - return self._collected_data - # back compatibility 0.1 @property def input_pts(self): @@ -318,34 +282,36 @@ def add_points(self, new_points_dict): [self.discretised_domains[k], v] ) - def collect_data(self): + def move_discretisation_into_conditions(self): """ - Aggregate data from the problem's conditions into a single dictionary. + Move the discretised domains into their corresponding conditions. """ - data = {} - # Iterate over the conditions and collect data - for condition_name in self.conditions: - condition = self.conditions[condition_name] - # Check if the condition has an domain attribute - if hasattr(condition, "domain"): - # Only store the discretisation points if the domain is - # in the dictionary - if condition.domain in self.discretised_domains: - samples = self.discretised_domains[condition.domain][ - self.input_variables - ] - data[condition_name] = { - "input": samples, - "equation": condition.equation, - } - else: - # If the condition does not have a domain attribute, store - # the input and target points - keys = condition.__slots__ - values = [ - getattr(condition, name) - for name in keys - if getattr(condition, name) is not None + if not self.are_all_domains_discretised: + warnings.formatwarning = custom_warning_format + warnings.filterwarnings("always", category=RuntimeWarning) + warning_message = "\n".join( + [ + f"""{" " * 13} ---> Domain {key} { + "sampled" if key in self.discretised_domains + else + "not sampled"}""" + for key in self.domains ] - data[condition_name] = dict(zip(keys, values)) - self._collected_data = data + ) + warnings.warn( + "Some of the domains are still not sampled. 
Consider calling " + "problem.discretise_domain function for all domains before " + "accessing the collected data:\n" + f"{warning_message}", + RuntimeWarning, + ) + + for name, cond in self.conditions.items(): + if hasattr(cond, "domain"): + domain = cond.domain + self.conditions[name] = Condition( + input=self.discretised_domains[cond.domain], + equation=cond.equation, + ) + self.conditions[name].domain = domain + self.conditions[name].problem = self diff --git a/pina/data/__init__.py b/pina/data/__init__.py index 2ecebecdd..f274d5bd9 100644 --- a/pina/data/__init__.py +++ b/pina/data/__init__.py @@ -7,26 +7,8 @@ from pina._src.data.data_module import ( PinaDataModule, - PinaSampler, - DummyDataloader, - Collator, - PinaSampler, -) - -from pina._src.data.dataset import ( - PinaDataset, - PinaTensorDataset, - PinaGraphDataset, - PinaDatasetFactory, ) __all__ = [ "PinaDataModule", - "PinaDataset", - "PinaSampler", - "DummyDataloader", - "Collator", - "PinaTensorDataset", - "PinaGraphDataset", - "PinaDatasetFactory", ] diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py deleted file mode 100644 index 9fd2d36ee..000000000 --- a/tests/test_data/test_data_module.py +++ /dev/null @@ -1,331 +0,0 @@ -import torch -import pytest -from pina.data import PinaDataModule -from pina.data import PinaTensorDataset, PinaGraphDataset -from pina.problem.zoo import SupervisedProblem -from pina.graph import RadiusGraph -from pina.data import DummyDataloader -from pina import Trainer -from pina.solver import SupervisedSolver -from torch_geometric.data import Batch -from torch.utils.data import DataLoader - -input_tensor = torch.rand((100, 10)) -output_tensor = torch.rand((100, 2)) - -x = torch.rand((100, 50, 10)) -pos = torch.rand((100, 50, 2)) -input_graph = [ - RadiusGraph(x=x_, pos=pos_, radius=0.2) for x_, pos_, in zip(x, pos) -] -output_graph = torch.rand((100, 50, 10)) - - -@pytest.mark.parametrize( - "input_, output_", - [(input_tensor, 
output_tensor), (input_graph, output_graph)], -) -def test_constructor(input_, output_): - problem = SupervisedProblem(input_=input_, output_=output_) - PinaDataModule(problem) - - -@pytest.mark.parametrize( - "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], -) -@pytest.mark.parametrize( - "train_size, val_size, test_size", [(0.7, 0.2, 0.1), (0.7, 0.3, 0)] -) -def test_setup_train(input_, output_, train_size, val_size, test_size): - problem = SupervisedProblem(input_=input_, output_=output_) - dm = PinaDataModule( - problem, train_size=train_size, val_size=val_size, test_size=test_size - ) - dm.setup() - assert hasattr(dm, "train_dataset") - if isinstance(input_, torch.Tensor): - assert isinstance(dm.train_dataset, PinaTensorDataset) - else: - assert isinstance(dm.train_dataset, PinaGraphDataset) - # assert len(dm.train_dataset) == int(len(input_) * train_size) - if test_size > 0: - assert hasattr(dm, "test_dataset") - assert dm.test_dataset is None - else: - assert not hasattr(dm, "test_dataset") - assert hasattr(dm, "val_dataset") - if isinstance(input_, torch.Tensor): - assert isinstance(dm.val_dataset, PinaTensorDataset) - else: - assert isinstance(dm.val_dataset, PinaGraphDataset) - # assert len(dm.val_dataset) == int(len(input_) * val_size) - - -@pytest.mark.parametrize( - "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], -) -@pytest.mark.parametrize( - "train_size, val_size, test_size", [(0.7, 0.2, 0.1), (0.0, 0.0, 1.0)] -) -def test_setup_test(input_, output_, train_size, val_size, test_size): - problem = SupervisedProblem(input_=input_, output_=output_) - dm = PinaDataModule( - problem, train_size=train_size, val_size=val_size, test_size=test_size - ) - dm.setup(stage="test") - if train_size > 0: - assert hasattr(dm, "train_dataset") - assert dm.train_dataset is None - else: - assert not hasattr(dm, "train_dataset") - if val_size > 0: - assert hasattr(dm, "val_dataset") - assert 
dm.val_dataset is None - else: - assert not hasattr(dm, "val_dataset") - - assert hasattr(dm, "test_dataset") - if isinstance(input_, torch.Tensor): - assert isinstance(dm.test_dataset, PinaTensorDataset) - else: - assert isinstance(dm.test_dataset, PinaGraphDataset) - # assert len(dm.test_dataset) == int(len(input_) * test_size) - - -@pytest.mark.parametrize( - "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], -) -def test_dummy_dataloader(input_, output_): - problem = SupervisedProblem(input_=input_, output_=output_) - solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) - trainer = Trainer( - solver, batch_size=None, train_size=0.7, val_size=0.3, test_size=0.0 - ) - dm = trainer.data_module - dm.setup() - dm.trainer = trainer - dataloader = dm.train_dataloader() - assert isinstance(dataloader, DummyDataloader) - assert len(dataloader) == 1 - data = next(dataloader) - assert isinstance(data, list) - assert isinstance(data[0], tuple) - if isinstance(input_, list): - assert isinstance(data[0][1]["input"], Batch) - else: - assert isinstance(data[0][1]["input"], torch.Tensor) - assert isinstance(data[0][1]["target"], torch.Tensor) - - dataloader = dm.val_dataloader() - assert isinstance(dataloader, DummyDataloader) - assert len(dataloader) == 1 - data = next(dataloader) - assert isinstance(data, list) - assert isinstance(data[0], tuple) - if isinstance(input_, list): - assert isinstance(data[0][1]["input"], Batch) - else: - assert isinstance(data[0][1]["input"], torch.Tensor) - assert isinstance(data[0][1]["target"], torch.Tensor) - - -@pytest.mark.parametrize( - "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], -) -@pytest.mark.parametrize("automatic_batching", [True, False]) -def test_dataloader(input_, output_, automatic_batching): - problem = SupervisedProblem(input_=input_, output_=output_) - solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) - trainer = 
Trainer( - solver, - batch_size=10, - train_size=0.7, - val_size=0.3, - test_size=0.0, - automatic_batching=automatic_batching, - ) - dm = trainer.data_module - dm.setup() - dm.trainer = trainer - dataloader = dm.train_dataloader() - assert isinstance(dataloader, DataLoader) - assert len(dataloader) == 7 - data = next(iter(dataloader)) - assert isinstance(data, dict) - if isinstance(input_, list): - assert isinstance(data["data"]["input"], Batch) - else: - assert isinstance(data["data"]["input"], torch.Tensor) - assert isinstance(data["data"]["target"], torch.Tensor) - - dataloader = dm.val_dataloader() - assert isinstance(dataloader, DataLoader) - assert len(dataloader) == 3 - data = next(iter(dataloader)) - assert isinstance(data, dict) - if isinstance(input_, list): - assert isinstance(data["data"]["input"], Batch) - else: - assert isinstance(data["data"]["input"], torch.Tensor) - assert isinstance(data["data"]["target"], torch.Tensor) - - -from pina import LabelTensor - -input_tensor = LabelTensor(torch.rand((100, 3)), ["u", "v", "w"]) -output_tensor = LabelTensor(torch.rand((100, 3)), ["u", "v", "w"]) - -x = LabelTensor(torch.rand((100, 50, 3)), ["u", "v", "w"]) -pos = LabelTensor(torch.rand((100, 50, 2)), ["x", "y"]) -input_graph = [ - RadiusGraph(x=x[i], pos=pos[i], radius=0.1) for i in range(len(x)) -] -output_graph = LabelTensor(torch.rand((100, 50, 3)), ["u", "v", "w"]) - - -@pytest.mark.parametrize( - "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], -) -@pytest.mark.parametrize("automatic_batching", [True, False]) -def test_dataloader_labels(input_, output_, automatic_batching): - problem = SupervisedProblem(input_=input_, output_=output_) - solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) - trainer = Trainer( - solver, - batch_size=10, - train_size=0.7, - val_size=0.3, - test_size=0.0, - automatic_batching=automatic_batching, - ) - dm = trainer.data_module - dm.setup() - dm.trainer = trainer - 
dataloader = dm.train_dataloader() - assert isinstance(dataloader, DataLoader) - assert len(dataloader) == 7 - data = next(iter(dataloader)) - assert isinstance(data, dict) - if isinstance(input_, list): - assert isinstance(data["data"]["input"], Batch) - assert isinstance(data["data"]["input"].x, LabelTensor) - assert data["data"]["input"].x.labels == ["u", "v", "w"] - assert data["data"]["input"].pos.labels == ["x", "y"] - else: - assert isinstance(data["data"]["input"], LabelTensor) - assert data["data"]["input"].labels == ["u", "v", "w"] - assert isinstance(data["data"]["target"], LabelTensor) - assert data["data"]["target"].labels == ["u", "v", "w"] - - dataloader = dm.val_dataloader() - assert isinstance(dataloader, DataLoader) - assert len(dataloader) == 3 - data = next(iter(dataloader)) - assert isinstance(data, dict) - if isinstance(input_, list): - assert isinstance(data["data"]["input"], Batch) - assert isinstance(data["data"]["input"].x, LabelTensor) - assert data["data"]["input"].x.labels == ["u", "v", "w"] - assert data["data"]["input"].pos.labels == ["x", "y"] - else: - assert isinstance(data["data"]["input"], torch.Tensor) - assert isinstance(data["data"]["input"], LabelTensor) - assert data["data"]["input"].labels == ["u", "v", "w"] - assert isinstance(data["data"]["target"], torch.Tensor) - assert data["data"]["target"].labels == ["u", "v", "w"] - - -def test_get_all_data(): - input = torch.stack([torch.zeros((1,)) + i for i in range(1000)]) - target = input - - problem = SupervisedProblem(input, target) - datamodule = PinaDataModule( - problem, - train_size=0.7, - test_size=0.2, - val_size=0.1, - batch_size=64, - shuffle=False, - repeat=False, - automatic_batching=None, - num_workers=0, - pin_memory=False, - ) - datamodule.setup("fit") - datamodule.setup("test") - assert len(datamodule.train_dataset.get_all_data()["data"]["input"]) == 700 - assert torch.isclose( - datamodule.train_dataset.get_all_data()["data"]["input"], input[:700] - ).all() - 
assert len(datamodule.val_dataset.get_all_data()["data"]["input"]) == 100 - assert torch.isclose( - datamodule.val_dataset.get_all_data()["data"]["input"], input[900:] - ).all() - assert len(datamodule.test_dataset.get_all_data()["data"]["input"]) == 200 - assert torch.isclose( - datamodule.test_dataset.get_all_data()["data"]["input"], input[700:900] - ).all() - - -def test_input_propery_tensor(): - input = torch.stack([torch.zeros((1,)) + i for i in range(1000)]) - target = input - - problem = SupervisedProblem(input, target) - datamodule = PinaDataModule( - problem, - train_size=0.7, - test_size=0.2, - val_size=0.1, - batch_size=64, - shuffle=False, - repeat=False, - automatic_batching=None, - num_workers=0, - pin_memory=False, - ) - datamodule.setup("fit") - datamodule.setup("test") - input_ = datamodule.input - assert isinstance(input_, dict) - assert isinstance(input_["train"], dict) - assert isinstance(input_["val"], dict) - assert isinstance(input_["test"], dict) - assert torch.isclose(input_["train"]["data"], input[:700]).all() - assert torch.isclose(input_["val"]["data"], input[900:]).all() - assert torch.isclose(input_["test"]["data"], input[700:900]).all() - - -def test_input_propery_graph(): - problem = SupervisedProblem(input_graph, output_graph) - datamodule = PinaDataModule( - problem, - train_size=0.7, - test_size=0.2, - val_size=0.1, - batch_size=64, - shuffle=False, - repeat=False, - automatic_batching=None, - num_workers=0, - pin_memory=False, - ) - datamodule.setup("fit") - datamodule.setup("test") - input_ = datamodule.input - assert isinstance(input_, dict) - assert isinstance(input_["train"], dict) - assert isinstance(input_["val"], dict) - assert isinstance(input_["test"], dict) - assert isinstance(input_["train"]["data"], list) - assert isinstance(input_["val"]["data"], list) - assert isinstance(input_["test"]["data"], list) - assert len(input_["train"]["data"]) == 70 - assert len(input_["val"]["data"]) == 10 - assert 
len(input_["test"]["data"]) == 20 diff --git a/tests/test_data/test_graph_dataset.py b/tests/test_data/test_graph_dataset.py deleted file mode 100644 index 3a63f7ec6..000000000 --- a/tests/test_data/test_graph_dataset.py +++ /dev/null @@ -1,138 +0,0 @@ -import torch -import pytest -from pina.data import PinaDatasetFactory, PinaGraphDataset -from pina.graph import KNNGraph -from torch_geometric.data import Data - -x = torch.rand((100, 20, 10)) -pos = torch.rand((100, 20, 2)) -input_ = [ - KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) - for x_, pos_ in zip(x, pos) -] -output_ = torch.rand((100, 20, 10)) - -x_2 = torch.rand((50, 20, 10)) -pos_2 = torch.rand((50, 20, 2)) -input_2_ = [ - KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) - for x_, pos_ in zip(x_2, pos_2) -] -output_2_ = torch.rand((50, 20, 10)) - - -# Problem with a single condition -conditions_dict_single = { - "data": { - "input": input_, - "target": output_, - } -} -max_conditions_lengths_single = {"data": 100} - -# Problem with multiple conditions -conditions_dict_multi = { - "data_1": { - "input": input_, - "target": output_, - }, - "data_2": { - "input": input_2_, - "target": output_2_, - }, -} - -max_conditions_lengths_multi = {"data_1": 100, "data_2": 50} - - -@pytest.mark.parametrize( - "conditions_dict, max_conditions_lengths", - [ - (conditions_dict_single, max_conditions_lengths_single), - (conditions_dict_multi, max_conditions_lengths_multi), - ], -) -def test_constructor(conditions_dict, max_conditions_lengths): - dataset = PinaDatasetFactory( - conditions_dict, - max_conditions_lengths=max_conditions_lengths, - automatic_batching=True, - ) - assert isinstance(dataset, PinaGraphDataset) - assert len(dataset) == 100 - - -@pytest.mark.parametrize( - "conditions_dict, max_conditions_lengths", - [ - (conditions_dict_single, max_conditions_lengths_single), - (conditions_dict_multi, max_conditions_lengths_multi), - ], -) -def test_getitem(conditions_dict, max_conditions_lengths): - 
dataset = PinaDatasetFactory( - conditions_dict, - max_conditions_lengths=max_conditions_lengths, - automatic_batching=True, - ) - data = dataset[50] - assert isinstance(data, dict) - assert all([isinstance(d["input"], Data) for d in data.values()]) - assert all([isinstance(d["target"], torch.Tensor) for d in data.values()]) - assert all( - [d["input"].x.shape == torch.Size((20, 10)) for d in data.values()] - ) - assert all( - [d["target"].shape == torch.Size((20, 10)) for d in data.values()] - ) - assert all( - [ - d["input"].edge_index.shape == torch.Size((2, 60)) - for d in data.values() - ] - ) - assert all([d["input"].edge_attr.shape[0] == 60 for d in data.values()]) - - data = dataset.fetch_from_idx_list([i for i in range(20)]) - assert isinstance(data, dict) - assert all([isinstance(d["input"], Data) for d in data.values()]) - assert all([isinstance(d["target"], torch.Tensor) for d in data.values()]) - assert all( - [d["input"].x.shape == torch.Size((400, 10)) for d in data.values()] - ) - assert all( - [d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()] - ) - assert all( - [ - d["input"].edge_index.shape == torch.Size((2, 1200)) - for d in data.values() - ] - ) - assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()]) - - -def test_input_single_condition(): - dataset = PinaDatasetFactory( - conditions_dict_single, - max_conditions_lengths=max_conditions_lengths_single, - automatic_batching=True, - ) - input_ = dataset.input - assert isinstance(input_, dict) - assert isinstance(input_["data"], list) - assert all([isinstance(d, Data) for d in input_["data"]]) - - -def test_input_multi_condition(): - dataset = PinaDatasetFactory( - conditions_dict_multi, - max_conditions_lengths=max_conditions_lengths_multi, - automatic_batching=True, - ) - input_ = dataset.input - assert isinstance(input_, dict) - assert isinstance(input_["data_1"], list) - assert all([isinstance(d, Data) for d in input_["data_1"]]) - assert 
isinstance(input_["data_2"], list) - assert all([isinstance(d, Data) for d in input_["data_2"]]) diff --git a/tests/test_data/test_tensor_dataset.py b/tests/test_data/test_tensor_dataset.py deleted file mode 100644 index 9e348c942..000000000 --- a/tests/test_data/test_tensor_dataset.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -import pytest -from pina.data import PinaDatasetFactory, PinaTensorDataset - -input_tensor = torch.rand((100, 10)) -output_tensor = torch.rand((100, 2)) - -input_tensor_2 = torch.rand((50, 10)) -output_tensor_2 = torch.rand((50, 2)) - -conditions_dict_single = { - "data": { - "input": input_tensor, - "target": output_tensor, - } -} - -conditions_dict_single_multi = { - "data_1": { - "input": input_tensor, - "target": output_tensor, - }, - "data_2": { - "input": input_tensor_2, - "target": output_tensor_2, - }, -} - -max_conditions_lengths_single = {"data": 100} - -max_conditions_lengths_multi = {"data_1": 100, "data_2": 50} - - -@pytest.mark.parametrize( - "conditions_dict, max_conditions_lengths", - [ - (conditions_dict_single, max_conditions_lengths_single), - (conditions_dict_single_multi, max_conditions_lengths_multi), - ], -) -def test_constructor_tensor(conditions_dict, max_conditions_lengths): - dataset = PinaDatasetFactory( - conditions_dict, - max_conditions_lengths=max_conditions_lengths, - automatic_batching=True, - ) - assert isinstance(dataset, PinaTensorDataset) - - -def test_getitem_single(): - dataset = PinaDatasetFactory( - conditions_dict_single, - max_conditions_lengths=max_conditions_lengths_single, - automatic_batching=False, - ) - - tensors = dataset.fetch_from_idx_list([i for i in range(70)]) - assert isinstance(tensors, dict) - assert list(tensors.keys()) == ["data"] - assert sorted(list(tensors["data"].keys())) == ["input", "target"] - assert isinstance(tensors["data"]["input"], torch.Tensor) - assert tensors["data"]["input"].shape == torch.Size((70, 10)) - assert isinstance(tensors["data"]["target"], 
torch.Tensor) - assert tensors["data"]["target"].shape == torch.Size((70, 2)) - - -def test_getitem_multi(): - dataset = PinaDatasetFactory( - conditions_dict_single_multi, - max_conditions_lengths=max_conditions_lengths_multi, - automatic_batching=False, - ) - tensors = dataset.fetch_from_idx_list([i for i in range(70)]) - assert isinstance(tensors, dict) - assert list(tensors.keys()) == ["data_1", "data_2"] - assert sorted(list(tensors["data_1"].keys())) == ["input", "target"] - assert isinstance(tensors["data_1"]["input"], torch.Tensor) - assert tensors["data_1"]["input"].shape == torch.Size((70, 10)) - assert isinstance(tensors["data_1"]["target"], torch.Tensor) - assert tensors["data_1"]["target"].shape == torch.Size((70, 2)) - - assert sorted(list(tensors["data_2"].keys())) == ["input", "target"] - assert isinstance(tensors["data_2"]["input"], torch.Tensor) - assert tensors["data_2"]["input"].shape == torch.Size((50, 10)) - assert isinstance(tensors["data_2"]["target"], torch.Tensor) - assert tensors["data_2"]["target"].shape == torch.Size((50, 2)) diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py new file mode 100644 index 000000000..8419a68f2 --- /dev/null +++ b/tests/test_datamodule.py @@ -0,0 +1,318 @@ +import torch +import pytest +from pina.data import PinaDataModule + +# from pina.data import PinaTensorDataset, PinaGraphDataset +from pina.problem.zoo import SupervisedProblem +from pina.graph import RadiusGraph + +# from pina.data import DummyDataloader +from pina._src.data.data_module import _ConditionSubset +from pina import Trainer +from pina.solver import SupervisedSolver +from torch_geometric.data import Batch +from torch.utils.data import DataLoader +from pina.problem.zoo import Poisson2DSquareProblem +from pina._src.data.aggregator import _Aggregator +from pina.solver import PINN + + +def _create_tensor_data(): + input_tensor = torch.rand((100, 10)) + output_tensor = torch.rand((100, 2)) + return input_tensor, output_tensor + + +def 
_create_graph_data(): + x = torch.rand((100, 50, 10)) + pos = torch.rand((100, 50, 2)) + input_graph = [ + RadiusGraph(x=x_, pos=pos_, radius=0.2) for x_, pos_, in zip(x, pos) + ] + output_graph = torch.rand((100, 50, 2)) + return input_graph, output_graph + + +def test_init_tensor(): + input_tensor, output_tensor = _create_tensor_data() + problem = SupervisedProblem(input_=input_tensor, output_=output_tensor) + dm = PinaDataModule(problem) + assert dm.problem == problem + assert dm.trainer is None + assert hasattr(dm, "split_idxs") + assert isinstance(dm.split_idxs, dict) + assert set(dm.split_idxs.keys()) == {"data"} + assert isinstance(dm.split_idxs["data"], dict) + assert set(dm.split_idxs["data"].keys()) == {"train", "val", "test"} + assert isinstance(dm.split_idxs["data"]["train"], list) + assert isinstance(dm.split_idxs["data"]["val"], list) + assert isinstance(dm.split_idxs["data"]["test"], list) + assert len(dm.split_idxs["data"]["train"]) == 70 + assert len(dm.split_idxs["data"]["val"]) == 10 + assert len(dm.split_idxs["data"]["test"]) == 20 + + +def test_init_graph(): + input_graph, output_graph = _create_graph_data() + problem = SupervisedProblem(input_=input_graph, output_=output_graph) + dm = PinaDataModule(problem) + assert dm.problem == problem + assert dm.trainer is None + assert hasattr(dm, "split_idxs") + assert isinstance(dm.split_idxs, dict) + assert set(dm.split_idxs.keys()) == {"data"} + assert isinstance(dm.split_idxs["data"], dict) + assert set(dm.split_idxs["data"].keys()) == {"train", "val", "test"} + assert isinstance(dm.split_idxs["data"]["train"], list) + assert isinstance(dm.split_idxs["data"]["val"], list) + assert isinstance(dm.split_idxs["data"]["test"], list) + assert len(dm.split_idxs["data"]["train"]) == 70 + assert len(dm.split_idxs["data"]["val"]) == 10 + assert len(dm.split_idxs["data"]["test"]) == 20 + + +def test_init_poisson(): + problem = Poisson2DSquareProblem() + problem.discretise_domain(n=10, mode="grid") + dm = 
PinaDataModule(problem) + assert dm.problem == problem + assert dm.trainer is None + assert hasattr(dm, "split_idxs") + assert isinstance(dm.split_idxs, dict) + assert set(dm.split_idxs.keys()) == {"D", "boundary"} + assert isinstance(dm.split_idxs["D"], dict) + assert set(dm.split_idxs["D"].keys()) == {"train", "val", "test"} + assert isinstance(dm.split_idxs["D"]["train"], list) + assert isinstance(dm.split_idxs["D"]["val"], list) + assert isinstance(dm.split_idxs["D"]["test"], list) + assert len(dm.split_idxs["D"]["train"]) == 70 + assert len(dm.split_idxs["D"]["val"]) == 10 + assert len(dm.split_idxs["D"]["test"]) == 20 + + assert isinstance(dm.split_idxs["boundary"], dict) + assert set(dm.split_idxs["boundary"].keys()) == {"train", "val", "test"} + assert isinstance(dm.split_idxs["boundary"]["train"], list) + assert isinstance(dm.split_idxs["boundary"]["val"], list) + assert isinstance(dm.split_idxs["boundary"]["test"], list) + assert len(dm.split_idxs["boundary"]["train"]) == 7 + assert len(dm.split_idxs["boundary"]["val"]) == 1 + assert len(dm.split_idxs["boundary"]["test"]) == 2 + + +def test_setup_tensor(): + input_tensor, output_tensor = _create_tensor_data() + problem = SupervisedProblem(input_=input_tensor, output_=output_tensor) + dm = PinaDataModule(problem) + dm.setup() + assert hasattr(dm, "train_datasets") + assert isinstance(dm.train_datasets, dict) + assert set(dm.train_datasets.keys()) == {"data"} + assert isinstance(dm.train_datasets["data"], _ConditionSubset) + assert hasattr(dm, "val_datasets") + assert isinstance(dm.val_datasets, dict) + assert set(dm.val_datasets.keys()) == {"data"} + assert isinstance(dm.val_datasets["data"], _ConditionSubset) + assert hasattr(dm, "test_datasets") + assert isinstance(dm.test_datasets, dict) + assert set(dm.test_datasets.keys()) == {"data"} + assert isinstance(dm.test_datasets["data"], _ConditionSubset) + + +def test_setup_graph(): + input_graph, output_graph = _create_graph_data() + problem = 
SupervisedProblem(input_=input_graph, output_=output_graph) + dm = PinaDataModule(problem) + dm.setup() + assert hasattr(dm, "train_datasets") + assert isinstance(dm.train_datasets, dict) + assert set(dm.train_datasets.keys()) == {"data"} + assert isinstance(dm.train_datasets["data"], _ConditionSubset) + assert hasattr(dm, "val_datasets") + assert isinstance(dm.val_datasets, dict) + assert set(dm.val_datasets.keys()) == {"data"} + assert isinstance(dm.val_datasets["data"], _ConditionSubset) + assert hasattr(dm, "test_datasets") + assert isinstance(dm.test_datasets, dict) + assert set(dm.test_datasets.keys()) == {"data"} + assert isinstance(dm.test_datasets["data"], _ConditionSubset) + + +def test_setup_poisson(): + problem = Poisson2DSquareProblem() + problem.discretise_domain(n=10, mode="grid") + dm = PinaDataModule(problem) + dm.setup() + assert hasattr(dm, "train_datasets") + assert isinstance(dm.train_datasets, dict) + assert set(dm.train_datasets.keys()) == {"D", "boundary"} + assert isinstance(dm.train_datasets["D"], _ConditionSubset) + assert isinstance(dm.train_datasets["boundary"], _ConditionSubset) + assert hasattr(dm, "val_datasets") + assert isinstance(dm.val_datasets, dict) + assert set(dm.val_datasets.keys()) == {"D", "boundary"} + assert isinstance(dm.val_datasets["D"], _ConditionSubset) + assert isinstance(dm.val_datasets["boundary"], _ConditionSubset) + assert hasattr(dm, "test_datasets") + assert isinstance(dm.test_datasets, dict) + assert set(dm.test_datasets.keys()) == {"D", "boundary"} + assert isinstance(dm.test_datasets["D"], _ConditionSubset) + assert isinstance(dm.test_datasets["boundary"], _ConditionSubset) + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +def test_dataloader_tensor(batch_size): + input_tensor, output_tensor = _create_tensor_data() + problem = SupervisedProblem(input_=input_tensor, output_=output_tensor) + trainer = Trainer( + solver=SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)), + 
batch_size=batch_size, + train_size=0.7, + val_size=0.2, + test_size=0.1, + ) + dm = trainer.data_module + dm.setup() + dataloader = dm.train_dataloader() + assert isinstance(dataloader, _Aggregator) + data = next(iter(dataloader)) + assert isinstance(data, dict) + assert isinstance(data["data"]["input"], torch.Tensor) + assert isinstance(data["data"]["target"], torch.Tensor) + assert ( + len(data["data"]["input"]) == batch_size + if batch_size is not None + else 70 + ) + + dataloader = dm.val_dataloader() + assert isinstance(dataloader, _Aggregator) + data = next(iter(dataloader)) + assert isinstance(data, dict) + assert isinstance(data["data"]["input"], torch.Tensor) + assert isinstance(data["data"]["target"], torch.Tensor) + assert ( + len(data["data"]["input"]) == batch_size + if batch_size is not None + else 10 + ) + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +def test_dataloader_graph(batch_size): + input_graph, output_graph = _create_graph_data() + problem = SupervisedProblem(input_=input_graph, output_=output_graph) + trainer = Trainer( + solver=SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)), + train_size=0.7, + val_size=0.2, + test_size=0.1, + batch_size=batch_size, + ) + dm = trainer.data_module + dm.setup() + dataloader = dm.train_dataloader() + assert isinstance(dataloader, _Aggregator) + data = next(iter(dataloader)) + assert isinstance(data, dict) + assert isinstance(data["data"]["input"], Batch) + assert isinstance(data["data"]["target"], torch.Tensor) + assert ( + len(data["data"]["input"]) == batch_size + if batch_size is not None + else 70 + ) + + dataloader = dm.val_dataloader() + assert isinstance(dataloader, _Aggregator) + data = next(iter(dataloader)) + assert isinstance(data, dict) + assert isinstance(data["data"]["input"], Batch) + assert isinstance(data["data"]["target"], torch.Tensor) + assert ( + len(data["data"]["input"]) == batch_size + if batch_size is not None + else 10 + ) + + 
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +def test_dataloader_poisson_cbs(batch_size): + problem = Poisson2DSquareProblem() + problem.discretise_domain(n=10, mode="grid") + trainer = Trainer( + solver=PINN(problem=problem, model=torch.nn.Linear(10, 10)), + batch_size=batch_size, + val_size=0.1, + test_size=0.2, + train_size=0.7, + batching_mode="common_batch_size", + ) + dm = trainer.data_module + dm.setup() + + dataloader = dm.train_dataloader() + assert isinstance(dataloader, _Aggregator) + data = next(iter(dataloader)) + assert isinstance(data, dict) + assert isinstance(data["D"]["input"], torch.Tensor) + assert isinstance(data["D"]["input"], torch.Tensor) + assert isinstance(data["boundary"]["input"], torch.Tensor) + assert isinstance(data["boundary"]["input"], torch.Tensor) + assert ( + len(data["D"]["input"]) == batch_size if batch_size is not None else 70 + ) + assert ( + len(data["boundary"]["input"]) == min(batch_size, 7) + if batch_size is not None + else 7 + ) + + dataloader = dm.val_dataloader() + assert isinstance(dataloader, _Aggregator) + data = next(iter(dataloader)) + assert isinstance(data, dict) + assert isinstance(data["D"]["input"], torch.Tensor) + assert isinstance(data["D"]["input"], torch.Tensor) + assert isinstance(data["boundary"]["input"], torch.Tensor) + assert isinstance(data["boundary"]["input"], torch.Tensor) + assert ( + len(data["D"]["input"]) == min(batch_size, 10) + if batch_size is not None + else 10 + ) + assert ( + len(data["boundary"]["input"]) == min(batch_size, 1) + if batch_size is not None + else 1 + ) + + +@pytest.mark.parametrize("batch_size", [None, 5, 20]) +def test_dataloader_poisson_proportional(batch_size): + problem = Poisson2DSquareProblem() + problem.discretise_domain(n=10, mode="grid") + trainer = Trainer( + solver=PINN(problem=problem, model=torch.nn.Linear(10, 10)), + batch_size=batch_size, + val_size=0.1, + test_size=0.2, + train_size=0.7, + batching_mode="proportional", + ) + dm = 
trainer.data_module + dm.setup() + + dataloader = dm.train_dataloader() + assert isinstance(dataloader, _Aggregator) + data = next(iter(dataloader)) + assert isinstance(data, dict) + assert isinstance(data["D"]["input"], torch.Tensor) + assert isinstance(data["D"]["input"], torch.Tensor) + assert isinstance(data["boundary"]["input"], torch.Tensor) + assert isinstance(data["boundary"]["input"], torch.Tensor) + assert ( + len(data["D"]["input"]) == batch_size - 1 + if batch_size is not None + else 70 + ) + assert len(data["boundary"]["input"]) == 1 if batch_size is not None else 7 From 85514299320ecdb408963029e70c360558e684ee Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Wed, 3 Dec 2025 17:57:18 +0100 Subject: [PATCH 3/3] implement autoregressive solver Co-authored-by: GiovanniCanali --- docs/source/_rst/_code.rst | 2 + .../autoregressive_solver.rst | 7 + .../autoregressive_solver_interface.rst | 7 + pina/_src/data/creator.py | 21 +- pina/_src/problem/abstract_problem.py | 30 +- .../solver/autoregressive_solver/__init__.py | 0 .../autoregressive_solver.py | 398 ++++++++++++++++++ .../autoregressive_solver_interface.py | 82 ++++ pina/solver/__init__.py | 7 + tests/conftest.py | 17 + tests/test_data_manager.py | 4 +- tests/test_problem.py | 38 -- .../test_solver/test_autoregressive_solver.py | 203 +++++++++ tests/test_solver/test_competitive_pinn.py | 10 +- tests/test_solver/test_ensemble_pinn.py | 9 +- .../test_ensemble_supervised_solver.py | 9 +- tests/test_solver/test_garom.py | 10 +- tests/test_solver/test_gradient_pinn.py | 10 +- tests/test_solver/test_pinn.py | 13 +- tests/test_solver/test_rba_pinn.py | 10 +- tests/test_solver/test_supervised_solver.py | 28 +- 21 files changed, 792 insertions(+), 123 deletions(-) create mode 100644 docs/source/_rst/solver/autoregressive_solver/autoregressive_solver.rst create mode 100644 docs/source/_rst/solver/autoregressive_solver/autoregressive_solver_interface.rst create mode 100644 
pina/_src/solver/autoregressive_solver/__init__.py create mode 100644 pina/_src/solver/autoregressive_solver/autoregressive_solver.py create mode 100644 pina/_src/solver/autoregressive_solver/autoregressive_solver_interface.py create mode 100644 tests/conftest.py create mode 100644 tests/test_solver/test_autoregressive_solver.py diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 64d88bc8b..7d992d1ca 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -82,6 +82,8 @@ Solvers DeepEnsembleSupervisedSolver ReducedOrderModelSolver GAROM + AutoregressiveSolverInterface + AutoregressiveSolver Models diff --git a/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver.rst b/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver.rst new file mode 100644 index 000000000..4cde8d1b9 --- /dev/null +++ b/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver.rst @@ -0,0 +1,7 @@ +Autoregressive Solver +====================== +.. currentmodule:: pina.solver.autoregressive_solver.autoregressive_solver + +.. autoclass:: pina._src.solver.autoregressive_solver.autoregressive_solver.AutoregressiveSolver + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver_interface.rst b/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver_interface.rst new file mode 100644 index 000000000..516409bd1 --- /dev/null +++ b/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver_interface.rst @@ -0,0 +1,7 @@ +Autoregressive Solver Interface +================================= +.. currentmodule:: pina.solver.autoregressive_solver.autoregressive_solver_interface + +.. 
autoclass:: pina._src.solver.autoregressive_solver.autoregressive_solver_interface.AutoregressiveSolverInterface + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/_src/data/creator.py b/pina/_src/data/creator.py index 0e84aef72..90b6d93fa 100644 --- a/pina/_src/data/creator.py +++ b/pina/_src/data/creator.py @@ -79,13 +79,16 @@ def _compute_batch_sizes(self, datasets): """ batch_sizes = {} if self.batching_mode == "common_batch_size": + + if self.batch_size is None: + batch_size = max( + dataset.length for dataset in datasets.values() + ) + else: + batch_size = self.batch_size + for name in datasets.keys(): - if self.batch_size is None: - batch_sizes[name] = len(datasets[name]) - else: - batch_sizes[name] = min( - self.batch_size, len(datasets[name]) - ) + batch_sizes[name] = min(batch_size, len(datasets[name])) return batch_sizes if self.batching_mode == "proportional": return self._compute_proportional_batch_sizes(datasets) @@ -168,8 +171,12 @@ def __call__(self, datasets): dataloaders = {} if self.batching_mode == "common_batch_size": max_len = max(len(dataset) for dataset in datasets.values()) + print(batch_sizes) for name, dataset in datasets.items(): - if self.batching_mode == "common_batch_size": + if ( + self.batching_mode == "common_batch_size" + and dataset.length != batch_sizes[name] + ): dataset.max_len = max_len dataloaders[name] = self.conditions[name].create_dataloader( dataset=dataset, diff --git a/pina/_src/problem/abstract_problem.py b/pina/_src/problem/abstract_problem.py index cc2b9e042..5dbba18c2 100644 --- a/pina/_src/problem/abstract_problem.py +++ b/pina/_src/problem/abstract_problem.py @@ -43,21 +43,21 @@ def __init__(self): self.domains[cond_name] = cond.domain cond.domain = cond_name - # back compatibility 0.1 - @property - def input_pts(self): - """ - Return a dictionary mapping condition names to their corresponding - input points. 
If some domains are not sampled, they will not be returned - and the corresponding condition will be empty. - - :return: The input points of the problem. - :rtype: dict - """ - to_return = {} - for cond_name, data in self.collected_data.items(): - to_return[cond_name] = data["input"] - return to_return + # # back compatibility 0.1 + # @property + # def input_pts(self): + # """ + # Return a dictionary mapping condition names to their corresponding + # input points. If some domains are not sampled, they will not be returned + # and the corresponding condition will be empty. + + # :return: The input points of the problem. + # :rtype: dict + # """ + # to_return = {} + # for cond_name, data in self.collected_data.items(): + # to_return[cond_name] = data["input"] + # return to_return @property def discretised_domains(self): diff --git a/pina/_src/solver/autoregressive_solver/__init__.py b/pina/_src/solver/autoregressive_solver/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/_src/solver/autoregressive_solver/autoregressive_solver.py b/pina/_src/solver/autoregressive_solver/autoregressive_solver.py new file mode 100644 index 000000000..e0b92af3d --- /dev/null +++ b/pina/_src/solver/autoregressive_solver/autoregressive_solver.py @@ -0,0 +1,398 @@ +import torch +from pina._src.solver.autoregressive_solver.autoregressive_solver_interface import ( + AutoregressiveSolverInterface, +) +from pina._src.solver.solver import SingleSolverInterface +from pina._src.loss.loss_interface import LossInterface +from pina._src.core.utils import check_consistency + + +class AutoregressiveSolver( + AutoregressiveSolverInterface, SingleSolverInterface +): + r""" + The autoregressive Solver for learning dynamical systems. + + This solver learns a one-step transition function + :math:`\mathcal{M}: \mathbb{R}^n \rightarrow \mathbb{R}^n` that maps + a state :math:`\mathbf{y}_t` to the next state :math:`\mathbf{y}_{t+1}`. 
+ + During training, the model is unrolled over multiple time steps to + learn long-term dynamics. Given an initial state :math:`\mathbf{y}_0`, + the model generates predictions recursively: + + .. math:: + \hat{\mathbf{y}}_{t+1} = \mathcal{M}(\hat{\mathbf{y}}_t), + \quad \hat{\mathbf{y}}_0 = \mathbf{y}_0 + + The loss is computed over the entire unroll window: + + .. math:: + \mathcal{L} = \sum_{t=1}^{T} w_t \|\hat{\mathbf{y}}_t - \mathbf{y}_t\|^2 + + where :math:`w_t` are exponential weights that down-weight later predictions + to stabilize training. + """ + + def __init__( + self, + problem, + model, + loss=None, + optimizer=None, + scheduler=None, + weighting=None, + use_lt=False, + reset_weights_at_epoch_start=True, + ): + """ + Initialization of the :class:`AutoregressiveSolver` class. + + :param AbstractProblem problem: The problem to be solved. + :param torch.nn.Module model: The neural network model to be used. + :param torch.nn.Module loss: The loss function to be minimized. + If ``None``, the :class:`torch.nn.MSELoss` loss is used. + Default is ``None``. + :param Optimizer optimizer: The optimizer to be used. + If ``None``, the :class:`torch.optim.Adam` optimizer is used. + Default is ``None``. + :param Scheduler scheduler: Learning rate scheduler. + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` + scheduler is used. Default is ``None``. + :param WeightingInterface weighting: The weighting schema to be used. + If ``None``, no weighting schema is used. Default is ``None``. + :param bool use_lt: Whether to use LabelTensors. Default is ``False``. + :param bool reset_weights_at_epoch_start: If ``True``, the running + averages used for adaptive weighting are reset at the start of each + epoch. Setting this parameter to ``False`` can improve training + stability, especially when data are scarce. Default is ``True``. + :raise ValueError: If the provided loss function is not compatible. 
+ :raise ValueError: If ``reset_weights_at_epoch_start`` is not a boolean. + """ + super().__init__( + problem=problem, + model=model, + optimizer=optimizer, + scheduler=scheduler, + weighting=weighting, + use_lt=use_lt, + ) + + # Check consistency + loss = loss or torch.nn.MSELoss() + check_consistency( + loss, (LossInterface, torch.nn.modules.loss._Loss), subclass=False + ) + check_consistency(reset_weights_at_epoch_start, bool) + + # Initialization + self._loss_fn = loss + self.reset_weights_at_epoch_start = reset_weights_at_epoch_start + self._running_avg = {} + self._step_count = {} + + def on_train_epoch_start(self): + """ + Clean up running averages at the start of each epoch if + ``reset_weights_at_epoch_start`` is True. + """ + if self.reset_weights_at_epoch_start: + self._running_avg.clear() + self._step_count.clear() + + def optimization_cycle(self, batch): + """ + The optimization cycle for autoregressive solvers. + + :param list[tuple[str, dict]] batch: A batch of data. Each element is a + tuple containing a condition name and a dictionary of points. + :return: The losses computed for all conditions in the batch. 
+ :rtype: dict + """ + # Store losses for each condition in the batch + condition_loss = {} + + # Loop through each condition and compute the autoregressive loss + for condition_name, points in batch: + # TODO: remove setting once AutoregressiveCondition is implemented + # TODO: pass a temporal weighting schema in the __init__ + if hasattr(self.problem.conditions[condition_name], "settings"): + settings = self.problem.conditions[condition_name].settings + eps = settings.get("eps", None) + kwargs = settings.get("kwargs", {}) + else: + eps = None + kwargs = {} + + loss = self.loss_autoregressive( + points["input"], + condition_name=condition_name, + eps=eps, + **kwargs, + ) + condition_loss[condition_name] = loss + return condition_loss + + def loss_autoregressive( + self, + input, + condition_name, + eps=None, + aggregation_strategy=None, + **kwargs, + ): + """ + Compute the loss for each autoregressive condition. + + :param input: The input tensor containing unroll windows. + :type input: torch.Tensor | LabelTensor + :param dict kwargs: Additional keyword arguments for loss computation. + :raise ValueError: If ``input`` has less than 4 dimensions. + :return: The scalar loss value for the given batch. + :rtype: torch.Tensor | LabelTensor + """ + # Check input dimensionality + if input.dim() < 4: + raise ValueError( + "The provided input tensor must have at least 4 dimensions:" + " [trajectories, windows, time_steps, *features]." + f" Got shape {input.shape}." 
+ ) + + # Initialize current state and loss list + current_state = input[:, :, 0] + losses = [] + + # Iterate through the unroll window and compute the loss for each step + for step in range(1, input.shape[2]): + + # Predict + processed_input = self.preprocess_step(current_state, **kwargs) + output = self.forward(processed_input) + predicted_state = self.postprocess_step(output, **kwargs) + + # Compute step loss + target_state = input[:, :, step] + step_loss = self._loss_fn(predicted_state, target_state, **kwargs) + losses.append(step_loss) + + # Update current state for the next step + current_state = predicted_state + + # Stack step losses into a tensor of shape [time_steps - 1] + step_losses = torch.stack(losses).as_subclass(torch.Tensor) + + # Compute adaptive weights based on running averages of step losses + with torch.no_grad(): + condition_name = condition_name or "default" + weights = self._get_weights(condition_name, step_losses, eps) + + # Aggregate the weighted step losses into a single scalar loss value + if aggregation_strategy is None: + aggregation_strategy = torch.mean + + return aggregation_strategy(step_losses * weights) + + def preprocess_step(self, current_state, **kwargs): + """ + Pre-process the current state before passing it to the model's forward. + + :param current_state: The current state to be preprocessed. + :type current_state: torch.Tensor | LabelTensor + :param dict kwargs: Additional keyword arguments for pre-processing. + :return: The preprocessed state for the given step. + :rtype: torch.Tensor | LabelTensor + """ + return current_state + + def postprocess_step(self, predicted_state, **kwargs): + """ + Post-process the state predicted by the model. + + :param predicted_state: The predicted state tensor from the model. + :type predicted_state: torch.Tensor | LabelTensor + :param dict kwargs: Additional keyword arguments for post-processing. + :return: The post-processed predicted state tensor. 
+ :rtype: torch.Tensor | LabelTensor + """ + return predicted_state + + def _get_weights(self, condition_name, step_losses, eps): + """ + Return cached weights or compute new ones. + + :param str condition_name: The name of the current condition. + :param torch.Tensor step_losses: The tensor of per-step losses. + :param float eps: The weighting parameter. + :return: The weights tensor. + :rtype: torch.Tensor + """ + # Determine the key for caching based on the condition name + key = condition_name or "default" + + # Initialize the key if not in the running averages. + if key not in self._running_avg: + self._running_avg[key] = step_losses.detach().clone() + self._step_count[key] = 1 + + # Update running averages and counts + else: + self._step_count[key] += 1 + value = step_losses.detach() - self._running_avg[key] + self._running_avg[key] += value / self._step_count[key] + + return self._compute_adaptive_weights(self._running_avg[key], eps) + + def _compute_adaptive_weights(self, step_losses, eps): + """ + Compute temporal adaptive weights. + + :param torch.Tensor step_losses: The tensor of per-step losses. + :param float eps: The weighting parameter. + :return: The weights tensor. + :rtype: torch.Tensor + """ + # If eps is None, return uniform weights + if eps is None: + return torch.ones_like(step_losses) + + # Compute cumulative loss and apply exponential weighting + cumulative_loss = -eps * torch.cumsum(step_losses, dim=0) + + return torch.exp(cumulative_loss) + + def predict(self, initial_state, n_steps, **kwargs): + """ + Generate predictions by recursively calling the model's forward. + + :param initial_state: The initial state from which to start prediction. + The initial state must be of shape ``[trajectories, 1, *features]``. + :type initial_state: torch.Tensor | LabelTensor + :param int n_steps: The number of autoregressive steps to predict. + :param dict kwargs: Additional keyword arguments. 
+ :raise ValueError: If the provided initial_state tensor has less than 3 + dimensions. + :return: The predicted trajectory, including the initial state. It has + shape ``[trajectories, n_steps + 1, *features]``, where the first + step corresponds to the initial state. + :rtype: torch.Tensor | LabelTensor + """ + # Set model to evaluation mode for prediction + self.eval() + + # Check intial state dimensionality + if initial_state.dim() < 3: + raise ValueError( + "The provided initial_state tensor must have at least 3" + "dimensions: [trajectories, time_steps, *features]." + f" Got shape {initial_state.shape}." + ) + + # Initialize the list of predictions with the initial state + predictions = [initial_state] + + # Generate predictions recursively for n_steps + with torch.no_grad(): + for _ in range(n_steps): + input = self.preprocess_step(predictions[-1], **kwargs) + output = self.forward(input) + next_state = self.postprocess_step(output, **kwargs) + predictions.append(next_state) + + return torch.stack(predictions, dim=2) + + # TODO: integrate in the Autoregressive Condition once implemented + @staticmethod + def unroll(data, unroll_length, n_unrolls=None, randomize=True): + """ + Create unrolling time windows from temporal data. + + This function takes as input a tensor of shape + ``[trajectories, time_steps, *features]`` and produces a tensor of shape + ``[trajectories, windows, unroll_length, *features]``. + Each window contains a sequence of subsequent states used for computing + the multi-step loss during training. + + :param data: The temporal data tensor to be unrolled. + :type data: torch.Tensor | LabelTensor + :param int unroll_length: The number of time steps in each window. + :param int n_unrolls: The maximum number of windows to return. + If ``None``, all valid windows are returned. Default is ``None``. + :param bool randomize: If ``True``, starting indices are randomly + permuted before applying ``n_unrolls``. Default is ``True``. 
+ :raise ValueError: If the input ``data`` has less than 3 dimensions. + :raise ValueError: If ``unroll_length`` is greater or equal to the + number of time steps in ``data``. + :return: A tensor of unrolled windows. + :rtype: torch.Tensor | LabelTensor + """ + # Check input dimensionality + if data.dim() < 3: + raise ValueError( + "The provided data tensor must have at least 3 dimensions:" + " [trajectories, time_steps, *features]." + f" Got shape {data.shape}." + ) + + # Determine valid starting indices for unroll windows + start_idx = AutoregressiveSolver._get_start_idx( + n_steps=data.shape[1], + unroll_length=unroll_length, + n_unrolls=n_unrolls, + randomize=randomize, + ) + + # Create unroll windows by slicing the data tensor at starting indices + windows = [data[:, s : s + unroll_length] for s in start_idx] + + return torch.stack(windows, dim=1) + + @staticmethod + def _get_start_idx(n_steps, unroll_length, n_unrolls=None, randomize=True): + """ + Determine starting indices for unroll windows. + + :param int n_steps: The total number of time steps in the data. + :param int unroll_length: The number of time steps in each window. + :param int n_unrolls: The maximum number of windows to return. + If ``None``, all valid windows are returned. Default is ``None``. + :param bool randomize: If ``True``, starting indices are randomly + permuted before applying ``n_unrolls``. Default is ``True``. + :raise ValueError: If ``unroll_length`` is greater or equal to the + number of time steps in ``data``. + :return: A tensor of starting indices for unroll windows. + :rtype: torch.Tensor + """ + # Calculate the last valid starting index for unroll windows + last_idx = n_steps - unroll_length + + # Raise error if no valid windows can be created + if last_idx < 0: + raise ValueError( + f"Cannot create unroll windows: unroll_length ({unroll_length})" + " cannot be greater or equal to the number of time_steps" + f" ({n_steps})." 
+ ) + + # Generate ordered starting indices for unroll windows + indices = torch.arange(last_idx + 1) + + # Permute indices if randomization is enabled + if randomize: + indices = indices[torch.randperm(len(indices))] + + # Limit the number of windows if n_unrolls is specified + if n_unrolls is not None and n_unrolls < len(indices): + indices = indices[:n_unrolls] + + return indices + + @property + def loss(self): + """ + The loss function to be minimized. + + :return: The loss function to be minimized. + :rtype: torch.nn.Module + """ + return self._loss_fn diff --git a/pina/_src/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/_src/solver/autoregressive_solver/autoregressive_solver_interface.py new file mode 100644 index 000000000..7029995fd --- /dev/null +++ b/pina/_src/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -0,0 +1,82 @@ +"""Module for the Autoregressive Solver Interface.""" + +from abc import abstractmethod +from pina._src.condition.data_condition import DataCondition +from pina._src.solver.solver import SolverInterface + + +class AutoregressiveSolverInterface(SolverInterface): + # TODO: fix once the AutoregressiveCondition is implemented. + """ + Abstract interface for all autoregressive solvers. + + Any solver implementing this interface is expected to be designed to learn + dynamical systems in an autoregressive manner. The solver should handle + conditions of type :class:`~pina.condition.data_condition.DataCondition`. + """ + + accepted_conditions_types = (DataCondition,) + + @abstractmethod + def preprocess_step(self, current_state, **kwargs): + """ + Pre-process the current state before passing it to the model's forward. + + :param current_state: The current state to be preprocessed. + :type current_state: torch.Tensor | LabelTensor + :param dict kwargs: Additional keyword arguments for pre-processing. + :return: The preprocessed state for the given step. 
+ :rtype: torch.Tensor | LabelTensor + """ + + @abstractmethod + def postprocess_step(self, predicted_state, **kwargs): + """ + Post-process the state predicted by the model. + + :param predicted_state: The predicted state tensor from the model. + :type predicted_state: torch.Tensor | LabelTensor + :param dict kwargs: Additional keyword arguments for post-processing. + :return: The post-processed predicted state tensor. + :rtype: torch.Tensor | LabelTensor + """ + + # TODO: remove once the AutoregressiveCondition is implemented. + @abstractmethod + def loss_autoregressive(self, input, **kwargs): + """ + Compute the loss for each autoregressive condition. + + :param input: The input tensor containing unroll windows. + :type input: torch.Tensor | LabelTensor + :param dict kwargs: Additional keyword arguments for loss computation. + :return: The scalar loss value for the given batch. + :rtype: torch.Tensor | LabelTensor + """ + + @abstractmethod + def predict(self, starting_value, num_steps, **kwargs): + """ + Generate predictions by recursively applying the model. + + :param starting_value: The initial state from which to start prediction. + The initial state must be of shape ``[trajectories, 1, features]``, + where the trajectory dimension can be used for batching. + :type starting_value: torch.Tensor | LabelTensor + :param int num_steps: The number of autoregressive steps to predict. + :param dict kwargs: Additional keyword arguments. + :return: The predicted trajectory, including the initial state. It has + shape ``[trajectories, num_steps + 1, features]``, where the first + step corresponds to the initial state. + :rtype: torch.Tensor | LabelTensor + """ + + @property + @abstractmethod + def loss(self): + """ + The loss function to be minimized. + + :return: The loss function to be minimized. 
+ :rtype: torch.nn.Module + """ diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py index a93914099..619e59d04 100644 --- a/pina/solver/__init__.py +++ b/pina/solver/__init__.py @@ -27,6 +27,8 @@ "DeepEnsembleSupervisedSolver", "DeepEnsemblePINN", "GAROM", + "AutoregressiveSolver", + "AutoregressiveSolverInterface", ] from pina._src.solver.solver import ( @@ -64,3 +66,8 @@ ) from pina._src.solver.garom import GAROM + +from pina._src.solver.autoregressive_solver.autoregressive_solver import ( + AutoregressiveSolver, + AutoregressiveSolverInterface, +) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..5bff82941 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,17 @@ +import shutil +from pathlib import Path +import pytest + + +@pytest.fixture +def clean_tmp_dir(tmp_path): + path = Path(tmp_path) + + if path.exists(): + shutil.rmtree(path) + + path.mkdir(parents=True, exist_ok=True) + yield path + + if path.exists(): + shutil.rmtree(path) diff --git a/tests/test_data_manager.py b/tests/test_data_manager.py index af46c500d..9bab62b57 100644 --- a/tests/test_data_manager.py +++ b/tests/test_data_manager.py @@ -101,7 +101,7 @@ def test_graph_data_create_batch(): data_manager = _DataManager(graph=graph, target=target) item1 = data_manager[0] item2 = data_manager[1] - batch_data = _GraphDataManager._create_batch([item1, item2]) + batch_data = _GraphDataManager.create_batch([item1, item2]) assert hasattr(batch_data, "graph") assert hasattr(batch_data, "target") batched_graphs = batch_data.graph @@ -122,7 +122,7 @@ def test_tensor_data_create_batch(): data_manager = _DataManager(pippo=pippo, pluto=pluto, paperino=paperino) item1 = data_manager[0] item2 = data_manager[1] - batch_data = _TensorDataManager._create_batch([item1, item2]) + batch_data = _TensorDataManager.create_batch([item1, item2]) assert hasattr(batch_data, "pippo") assert hasattr(batch_data, "pluto") assert hasattr(batch_data, "paperino") diff --git 
a/tests/test_problem.py b/tests/test_problem.py index bdd6a1d4d..53ee3bc57 100644 --- a/tests/test_problem.py +++ b/tests/test_problem.py @@ -49,24 +49,6 @@ def test_variables_correct_order_sampling(): ) -def test_input_pts(): - n = 10 - poisson_problem = Poisson() - poisson_problem.discretise_domain(n, "grid") - assert sorted(list(poisson_problem.input_pts.keys())) == sorted( - list(poisson_problem.conditions.keys()) - ) - - -def test_collected_data(): - n = 10 - poisson_problem = Poisson() - poisson_problem.discretise_domain(n, "grid") - assert sorted(list(poisson_problem.collected_data.keys())) == sorted( - list(poisson_problem.conditions.keys()) - ) - - def test_add_points(): poisson_problem = Poisson() poisson_problem.discretise_domain(1, "random", domains=["D"]) @@ -110,23 +92,3 @@ def test_wrong_custom_sampling_logic(mode): # Necessary cleanup if "new" in poisson_problem.domains: del poisson_problem.domains["new"] - - -def test_aggregate_data(): - poisson_problem = Poisson() - poisson_problem.conditions["data"] = Condition( - input=LabelTensor(torch.tensor([[0.0, 1.0]]), labels=["x", "y"]), - target=LabelTensor(torch.tensor([[0.0]]), labels=["u"]), - ) - poisson_problem.discretise_domain(1, "random", domains="all") - poisson_problem.collect_data() - assert isinstance(poisson_problem.collected_data, dict) - for name, conditions in poisson_problem.conditions.items(): - assert name in poisson_problem.collected_data.keys() - if isinstance(conditions, InputTargetCondition): - assert "input" in poisson_problem.collected_data[name].keys() - assert "target" in poisson_problem.collected_data[name].keys() - elif isinstance(conditions, DomainEquationCondition): - assert "input" in poisson_problem.collected_data[name].keys() - assert "target" not in poisson_problem.collected_data[name].keys() - assert "equation" in poisson_problem.collected_data[name].keys() diff --git a/tests/test_solver/test_autoregressive_solver.py b/tests/test_solver/test_autoregressive_solver.py 
new file mode 100644 index 000000000..2216be9bf --- /dev/null +++ b/tests/test_solver/test_autoregressive_solver.py @@ -0,0 +1,203 @@ +import shutil +import pytest +import torch +from torch._dynamo.eval_frame import OptimizedModule + +from pina import Condition, Trainer, LabelTensor +from pina.solver import AutoregressiveSolver +from pina.condition import DataCondition +from pina.problem import AbstractProblem +from pina.model import FeedForward + + +# Hyperparameters and settings +n_traj = 5 +t_steps = 10 +n_feats = 2 +unroll_length = 3 +n_unrolls = 4 + + +# TODO: test this in AutoregressiveCondition once it's implemented +# Utility function to create synthetic data for testing +def create_data(n_traj, t_steps, n_feats, unroll_length, n_unrolls, use_lt): + + init_state = torch.rand(n_traj, n_feats) + traj = torch.stack([0.95**i * init_state for i in range(t_steps)], dim=1) + + data = AutoregressiveSolver.unroll( + data=traj, + unroll_length=unroll_length, + n_unrolls=n_unrolls, + randomize=True, + ) + labels = [f"feat_{i}" for i in range(n_feats)] + return LabelTensor(data, labels=labels) + + +# Data +data = create_data( + n_traj=n_traj, + t_steps=t_steps, + n_feats=n_feats, + unroll_length=unroll_length, + n_unrolls=n_unrolls, + use_lt=True, +) + + +# Problem +class Problem(AbstractProblem): + + input_variables = [f"feat_{i}" for i in range(n_feats)] + output_variables = [f"feat_{i}" for i in range(n_feats)] + conditions = {} + + def __init__(self, data): + super().__init__() + self.data = data + self.conditions = {"autoregressive": Condition(input=self.data)} + self.conditions_settings = { + "autoregressive": {"eps": 0.1} + } # TODO: remove once the autoregressive condition is implemented + + +problem = Problem(data) +model = FeedForward(n_feats, n_feats, 128, 2) + + +@pytest.mark.parametrize("use_lt", [True, False]) +@pytest.mark.parametrize("bool_value", [True, False]) +def test_constructor(use_lt, bool_value): + + solver = AutoregressiveSolver( + 
problem=problem, + model=model, + reset_weights_at_epoch_start=bool_value, + use_lt=use_lt, + ) + + assert solver.accepted_conditions_types == ( + DataCondition, + ) # TODO: update once the AutoregressiveCondition is implemented + + +@pytest.mark.parametrize("use_lt", [True, False]) +@pytest.mark.parametrize("batch_size", [None, 1, 2, 5]) +@pytest.mark.parametrize("compile", [True, False]) +@pytest.mark.parametrize("bool_value", [True, False]) +def test_solver_train(use_lt, batch_size, compile, bool_value): + solver = AutoregressiveSolver( + model=model, + problem=problem, + reset_weights_at_epoch_start=bool_value, + use_lt=use_lt, + ) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=1.0, + val_size=0.0, + test_size=0.0, + compile=compile, + ) + trainer.train() + + +@pytest.mark.parametrize("use_lt", [True, False]) +@pytest.mark.parametrize("batch_size", [None, 1, 2, 5]) +@pytest.mark.parametrize("compile", [True, False]) +@pytest.mark.parametrize("bool_value", [True, False]) +def test_solver_validation(use_lt, batch_size, compile, bool_value): + solver = AutoregressiveSolver( + model=model, + problem=problem, + reset_weights_at_epoch_start=bool_value, + use_lt=use_lt, + ) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=0.9, + val_size=0.1, + test_size=0.0, + compile=compile, + ) + trainer.train() + if trainer.compile: + assert isinstance(solver.model, OptimizedModule) + + +@pytest.mark.parametrize("use_lt", [True, False]) +@pytest.mark.parametrize("batch_size", [None, 1, 2, 5]) +@pytest.mark.parametrize("compile", [True, False]) +@pytest.mark.parametrize("bool_value", [True, False]) +def test_solver_test(use_lt, batch_size, compile, bool_value): + solver = AutoregressiveSolver( + model=model, + problem=problem, + reset_weights_at_epoch_start=bool_value, + use_lt=use_lt, + ) + trainer = Trainer( + solver=solver, + max_epochs=2, + 
accelerator="cpu", + batch_size=batch_size, + train_size=0.6, + val_size=0.2, + test_size=0.2, + compile=compile, + ) + trainer.test() + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_train_load_restore(use_lt): + dir = "tests/test_solver/tmp" + solver = AutoregressiveSolver( + model=model, + problem=problem, + reset_weights_at_epoch_start=False, + use_lt=use_lt, + ) + trainer = Trainer( + solver=solver, + max_epochs=5, + accelerator="cpu", + batch_size=None, + train_size=0.7, + val_size=0.2, + test_size=0.1, + default_root_dir=dir, + ) + trainer.train() + + # restore + new_trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") + new_trainer.train( + ckpt_path=f"{dir}/lightning_logs/version_0/checkpoints/" + + "epoch=4-step=5.ckpt" + ) + + # loading + new_solver = AutoregressiveSolver.load_from_checkpoint( + f"{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt", + problem=problem, + model=model, + ) + + test_pts = LabelTensor( + torch.rand(n_traj, t_steps, n_feats), problem.input_variables + ) + assert new_solver.forward(test_pts).shape == (n_traj, t_steps, n_feats) + assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape + torch.testing.assert_close( + new_solver.forward(test_pts), solver.forward(test_pts) + ) + + shutil.rmtree("tests/test_solver/tmp") diff --git a/tests/test_solver/test_competitive_pinn.py b/tests/test_solver/test_competitive_pinn.py index 67902197a..b44a5c6ce 100644 --- a/tests/test_solver/test_competitive_pinn.py +++ b/tests/test_solver/test_competitive_pinn.py @@ -113,9 +113,8 @@ def test_solver_test(problem, batch_size, compile): @pytest.mark.parametrize("problem", [problem, inverse_problem]) -def test_train_load_restore(problem): - dir = "tests/test_solver/tmp" - problem = problem +def test_train_load_restore(clean_tmp_dir, problem): + dir = clean_tmp_dir solver = CompPINN(problem=problem, model=model) trainer = Trainer( solver=solver, @@ -151,8 +150,3 @@ def 
test_train_load_restore(problem): torch.testing.assert_close( new_solver.forward(test_pts), solver.forward(test_pts) ) - - # rm directories - import shutil - - shutil.rmtree("tests/test_solver/tmp") diff --git a/tests/test_solver/test_ensemble_pinn.py b/tests/test_solver/test_ensemble_pinn.py index e34ad3643..8d76ee553 100644 --- a/tests/test_solver/test_ensemble_pinn.py +++ b/tests/test_solver/test_ensemble_pinn.py @@ -107,8 +107,8 @@ def test_solver_test(batch_size, compile): ) -def test_train_load_restore(): - dir = "tests/test_solver/tmp" +def test_train_load_restore(clean_tmp_dir): + dir = clean_tmp_dir solver = DeepEnsemblePINN(models=models, problem=problem) trainer = Trainer( solver=solver, @@ -141,8 +141,3 @@ def test_train_load_restore(): torch.testing.assert_close( new_solver.forward(test_pts), solver.forward(test_pts) ) - - # rm directories - import shutil - - shutil.rmtree("tests/test_solver/tmp") diff --git a/tests/test_solver/test_ensemble_supervised_solver.py b/tests/test_solver/test_ensemble_supervised_solver.py index 4be2897d9..71c78690f 100644 --- a/tests/test_solver/test_ensemble_supervised_solver.py +++ b/tests/test_solver/test_ensemble_supervised_solver.py @@ -235,8 +235,8 @@ def test_solver_test_graph(batch_size, use_lt): trainer.test() -def test_train_load_restore(): - dir = "tests/test_solver/tmp/" +def test_train_load_restore(clean_tmp_dir): + dir = clean_tmp_dir problem = LabelTensorProblem() solver = DeepEnsembleSupervisedSolver(problem=problem, models=models) trainer = Trainer( @@ -270,8 +270,3 @@ def test_train_load_restore(): torch.testing.assert_close( new_solver.forward(test_pts), solver.forward(test_pts) ) - - # rm directories - import shutil - - shutil.rmtree("tests/test_solver/tmp") diff --git a/tests/test_solver/test_garom.py b/tests/test_solver/test_garom.py index 62575825c..1c09b01b7 100644 --- a/tests/test_solver/test_garom.py +++ b/tests/test_solver/test_garom.py @@ -163,9 +163,8 @@ def test_solver_test(batch_size, compile): 
) -def test_train_load_restore(): - dir = "tests/test_solver/tmp/" - problem = TensorProblem() +def test_train_load_restore(clean_tmp_dir): + dir = clean_tmp_dir solver = GAROM( problem=TensorProblem(), generator=Generator(), @@ -201,8 +200,3 @@ def test_train_load_restore(): test_pts = torch.rand(20, 1) assert new_solver.forward(test_pts).shape == (20, 2) assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape - - # rm directories - import shutil - - shutil.rmtree("tests/test_solver/tmp") diff --git a/tests/test_solver/test_gradient_pinn.py b/tests/test_solver/test_gradient_pinn.py index c28fc347e..43f70060a 100644 --- a/tests/test_solver/test_gradient_pinn.py +++ b/tests/test_solver/test_gradient_pinn.py @@ -119,9 +119,8 @@ def test_solver_test(problem, batch_size, compile): @pytest.mark.parametrize("problem", [problem, inverse_problem]) -def test_train_load_restore(problem): - dir = "tests/test_solver/tmp" - problem = problem +def test_train_load_restore(clean_tmp_dir, problem): + dir = clean_tmp_dir solver = GradientPINN(model=model, problem=problem) trainer = Trainer( solver=solver, @@ -157,8 +156,3 @@ def test_train_load_restore(problem): torch.testing.assert_close( new_solver.forward(test_pts), solver.forward(test_pts) ) - - # rm directories - import shutil - - shutil.rmtree("tests/test_solver/tmp") diff --git a/tests/test_solver/test_pinn.py b/tests/test_solver/test_pinn.py index 76094b473..4630a44f4 100644 --- a/tests/test_solver/test_pinn.py +++ b/tests/test_solver/test_pinn.py @@ -101,9 +101,8 @@ def test_solver_test(problem, batch_size, compile): @pytest.mark.parametrize("problem", [problem, inverse_problem]) -def test_train_load_restore(problem): - dir = "tests/test_solver/tmp" - problem = problem +def test_train_load_restore(clean_tmp_dir, problem): + dir = clean_tmp_dir solver = PINN(model=model, problem=problem) trainer = Trainer( solver=solver, @@ -116,6 +115,9 @@ def test_train_load_restore(problem): default_root_dir=dir, ) 
trainer.train() + import os + + print(os.listdir(f"{dir}/lightning_logs/version_0/checkpoints/")) # restore new_trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") @@ -137,8 +139,3 @@ def test_train_load_restore(problem): torch.testing.assert_close( new_solver.forward(test_pts), solver.forward(test_pts) ) - - # rm directories - import shutil - - shutil.rmtree("tests/test_solver/tmp") diff --git a/tests/test_solver/test_rba_pinn.py b/tests/test_solver/test_rba_pinn.py index b464f3a7c..8f9165fdf 100644 --- a/tests/test_solver/test_rba_pinn.py +++ b/tests/test_solver/test_rba_pinn.py @@ -122,9 +122,8 @@ def test_solver_test(problem, batch_size, loss, compile): @pytest.mark.parametrize("problem", [problem, inverse_problem]) -def test_train_load_restore(problem): - dir = "tests/test_solver/tmp" - problem = problem +def test_train_load_restore(clean_tmp_dir, problem): + dir = clean_tmp_dir solver = RBAPINN(model=model, problem=problem) trainer = Trainer( solver=solver, @@ -160,8 +159,3 @@ def test_train_load_restore(problem): torch.testing.assert_close( new_solver.forward(test_pts), solver.forward(test_pts) ) - - # rm directories - import shutil - - shutil.rmtree("tests/test_solver/tmp") diff --git a/tests/test_solver/test_supervised_solver.py b/tests/test_solver/test_supervised_solver.py index 461130a6b..c39e6034e 100644 --- a/tests/test_solver/test_supervised_solver.py +++ b/tests/test_solver/test_supervised_solver.py @@ -212,8 +212,27 @@ def test_solver_test_graph(batch_size, use_lt): trainer.test() -def test_train_load_restore(): - dir = "tests/test_solver/tmp/" +import shutil +from pathlib import Path +import pytest + + +@pytest.fixture +def clean_tmp_dir(): + path = Path("tests/test_solver/tmp/") + + if path.exists(): + shutil.rmtree(path) + + path.mkdir(parents=True, exist_ok=True) + yield path + + if path.exists(): + shutil.rmtree(path) + + +def test_train_load_restore(clean_tmp_dir): + dir = clean_tmp_dir problem = LabelTensorProblem() solver = 
SupervisedSolver(problem=problem, model=model) trainer = Trainer( @@ -248,8 +267,3 @@ def test_train_load_restore(): torch.testing.assert_close( new_solver.forward(test_pts), solver.forward(test_pts) ) - - # rm directories - import shutil - - shutil.rmtree("tests/test_solver/tmp")