From d33683a53c313c1f3c928f454a9a727eaa48cdf1 Mon Sep 17 00:00:00 2001 From: GiovanniCanali Date: Tue, 21 Apr 2026 18:12:11 +0200 Subject: [PATCH 1/3] new internal structure for conditions and data managers --- docs/source/_rst/_code.rst | 26 +- docs/source/_rst/condition/base_condition.rst | 9 + docs/source/_rst/condition/batch_manager.rst | 9 + docs/source/_rst/condition/condition.rst | 6 +- .../_rst/condition/condition_interface.rst | 4 +- docs/source/_rst/condition/data_condition.rst | 4 +- docs/source/_rst/condition/data_manager.rst | 9 + .../_rst/condition/data_manager_interface.rst | 9 + .../condition/domain_equation_condition.rst | 2 + .../_rst/condition/graph_data_manager.rst | 9 + .../condition/input_equation_condition.rst | 2 + .../_rst/condition/input_target_condition.rst | 2 + .../_rst/condition/tensor_data_manager.rst | 9 + pina/_src/condition/base_condition.py | 153 ++++++++ pina/_src/condition/batch_manager.py | 20 +- pina/_src/condition/condition.py | 15 +- pina/_src/condition/condition_base.py | 148 -------- pina/_src/condition/condition_interface.py | 97 +++-- pina/_src/condition/data_condition.py | 79 ++-- pina/_src/condition/data_manager.py | 351 ++---------------- pina/_src/condition/data_manager_interface.py | 53 +++ .../condition/domain_equation_condition.py | 114 +++--- pina/_src/condition/graph_data_manager.py | 246 ++++++++++++ .../condition/input_equation_condition.py | 80 ++-- pina/_src/condition/input_target_condition.py | 84 ++--- pina/_src/condition/tensor_data_manager.py | 110 ++++++ pina/condition/__init__.py | 16 +- 27 files changed, 966 insertions(+), 700 deletions(-) create mode 100644 docs/source/_rst/condition/base_condition.rst create mode 100644 docs/source/_rst/condition/batch_manager.rst create mode 100644 docs/source/_rst/condition/data_manager.rst create mode 100644 docs/source/_rst/condition/data_manager_interface.rst create mode 100644 docs/source/_rst/condition/graph_data_manager.rst create mode 100644 docs/source/_rst/condition/tensor_data_manager.rst create mode 100644 pina/_src/condition/base_condition.py delete mode 100644 pina/_src/condition/condition_base.py create mode 100644 pina/_src/condition/data_manager_interface.py create mode 100644 pina/_src/condition/graph_data_manager.py create mode 100644 pina/_src/condition/tensor_data_manager.py diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 211398d9d..7433ab5a1 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -52,12 +52,24 @@ Conditions .. toctree:: :titlesonly: - ConditionInterface + Condition Interface + Base Condition Condition - DataCondition - DomainEquationCondition - InputEquationCondition - InputTargetCondition + Data Condition + Domain Equation Condition + Input Equation Condition + Input Target Condition + +Batch and Data Managers +-------------------------- +.. toctree:: + :titlesonly: + + Batch Manager + Data Manager Interface + Data Manager + Graph Data Manager + Tensor Data Manager Solvers -------------- @@ -203,7 +215,7 @@ Equations and Differential Operators Differential Operators -Equations Zoo +Equation Zoo --------------------------------------- .. toctree:: @@ -234,7 +246,7 @@ Problems SpatialProblem TimeDependentProblem -Problems Zoo +Problem Zoo -------------- .. 
toctree:: diff --git a/docs/source/_rst/condition/base_condition.rst b/docs/source/_rst/condition/base_condition.rst new file mode 100644 index 000000000..2ba4113bd --- /dev/null +++ b/docs/source/_rst/condition/base_condition.rst @@ -0,0 +1,9 @@ +Base Condition +================ +.. currentmodule:: pina.condition.base_condition + +.. automodule:: pina._src.condition.base_condition + +.. autoclass:: pina._src.condition.base_condition.BaseCondition + :members: + :show-inheritance: diff --git a/docs/source/_rst/condition/batch_manager.rst b/docs/source/_rst/condition/batch_manager.rst new file mode 100644 index 000000000..f651260bf --- /dev/null +++ b/docs/source/_rst/condition/batch_manager.rst @@ -0,0 +1,9 @@ +Batch Manager +====================== +.. currentmodule:: pina.condition.batch_manager + +.. automodule:: pina._src.condition.batch_manager + +.. autoclass:: pina._src.condition.batch_manager._BatchManager + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/condition/condition.rst b/docs/source/_rst/condition/condition.rst index cea9371f7..0f8070506 100644 --- a/docs/source/_rst/condition/condition.rst +++ b/docs/source/_rst/condition/condition.rst @@ -1,7 +1,9 @@ -Conditions +Condition ============= .. currentmodule:: pina.condition.condition +.. automodule:: pina._src.condition.condition + .. autoclass:: pina._src.condition.condition.Condition :members: - :show-inheritance: \ No newline at end of file + :show-inheritance: diff --git a/docs/source/_rst/condition/condition_interface.rst b/docs/source/_rst/condition/condition_interface.rst index 6c675c275..a81de1afa 100644 --- a/docs/source/_rst/condition/condition_interface.rst +++ b/docs/source/_rst/condition/condition_interface.rst @@ -1,7 +1,9 @@ -ConditionInterface +Condition Interface ====================== .. currentmodule:: pina.condition.condition_interface +.. automodule:: pina._src.condition.condition_interface + .. autoclass:: pina._src.condition.condition_interface.ConditionInterface :members: :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/condition/data_condition.rst b/docs/source/_rst/condition/data_condition.rst index e9f2baab2..d614fbb7b 100644 --- a/docs/source/_rst/condition/data_condition.rst +++ b/docs/source/_rst/condition/data_condition.rst @@ -1,7 +1,9 @@ -Data Conditions +Data Condition ================== .. currentmodule:: pina.condition.data_condition +.. automodule:: pina._src.condition.data_condition + .. autoclass:: pina._src.condition.data_condition.DataCondition :members: :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/condition/data_manager.rst b/docs/source/_rst/condition/data_manager.rst new file mode 100644 index 000000000..66e177854 --- /dev/null +++ b/docs/source/_rst/condition/data_manager.rst @@ -0,0 +1,9 @@ +Data Manager +====================== +.. currentmodule:: pina.condition.data_manager + +.. automodule:: pina._src.condition.data_manager + +.. autoclass:: pina._src.condition.data_manager._DataManager + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/condition/data_manager_interface.rst b/docs/source/_rst/condition/data_manager_interface.rst new file mode 100644 index 000000000..b1adac823 --- /dev/null +++ b/docs/source/_rst/condition/data_manager_interface.rst @@ -0,0 +1,9 @@ +Data Manager Interface +========================= +.. currentmodule:: pina.condition.data_manager_interface + +.. automodule:: pina._src.condition.data_manager_interface + +.. 
autoclass:: pina._src.condition.data_manager_interface._DataManagerInterface + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/condition/domain_equation_condition.rst b/docs/source/_rst/condition/domain_equation_condition.rst index 10f1395ca..2c372f13f 100644 --- a/docs/source/_rst/condition/domain_equation_condition.rst +++ b/docs/source/_rst/condition/domain_equation_condition.rst @@ -2,6 +2,8 @@ Domain Equation Condition =========================== .. currentmodule:: pina.condition.domain_equation_condition +.. automodule:: pina._src.condition.domain_equation_condition + .. autoclass:: pina._src.condition.domain_equation_condition.DomainEquationCondition :members: :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/condition/graph_data_manager.rst b/docs/source/_rst/condition/graph_data_manager.rst new file mode 100644 index 000000000..b8b6ba39e --- /dev/null +++ b/docs/source/_rst/condition/graph_data_manager.rst @@ -0,0 +1,9 @@ +Graph Data Manager +====================== +.. currentmodule:: pina.condition.graph_data_manager + +.. automodule:: pina._src.condition.graph_data_manager + +.. autoclass:: pina._src.condition.graph_data_manager._GraphDataManager + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/condition/input_equation_condition.rst b/docs/source/_rst/condition/input_equation_condition.rst index 9c54da106..da0a48476 100644 --- a/docs/source/_rst/condition/input_equation_condition.rst +++ b/docs/source/_rst/condition/input_equation_condition.rst @@ -2,6 +2,8 @@ Input Equation Condition =========================== .. currentmodule:: pina.condition.input_equation_condition +.. automodule:: pina._src.condition.input_equation_condition + .. autoclass:: pina._src.condition.input_equation_condition.InputEquationCondition :members: :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/condition/input_target_condition.rst b/docs/source/_rst/condition/input_target_condition.rst index 808dd0f06..da8333714 100644 --- a/docs/source/_rst/condition/input_target_condition.rst +++ b/docs/source/_rst/condition/input_target_condition.rst @@ -2,6 +2,8 @@ Input Target Condition =========================== .. currentmodule:: pina.condition.input_target_condition +.. automodule:: pina._src.condition.input_target_condition + .. autoclass:: pina._src.condition.input_target_condition.InputTargetCondition :members: :show-inheritance: diff --git a/docs/source/_rst/condition/tensor_data_manager.rst b/docs/source/_rst/condition/tensor_data_manager.rst new file mode 100644 index 000000000..e45e86c8c --- /dev/null +++ b/docs/source/_rst/condition/tensor_data_manager.rst @@ -0,0 +1,9 @@ +Tensor Data Manager +====================== +.. currentmodule:: pina.condition.tensor_data_manager + +.. automodule:: pina._src.condition.tensor_data_manager + +.. 
autoclass:: pina._src.condition.tensor_data_manager._TensorDataManager + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/_src/condition/base_condition.py b/pina/_src/condition/base_condition.py new file mode 100644 index 000000000..013c5bf24 --- /dev/null +++ b/pina/_src/condition/base_condition.py @@ -0,0 +1,153 @@ +"""Module for the Base Condition class.""" + +from functools import partial +import torch +from torch_geometric.data import Batch +from torch.utils.data import DataLoader +from pina._src.condition.condition_interface import ConditionInterface +from pina._src.core.graph import LabelBatch +from pina._src.core.label_tensor import LabelTensor +from pina._src.core.utils import check_consistency +from pina._src.data.dummy_dataloader import DummyDataloader +from pina._src.problem.problem_interface import ProblemInterface + + +class BaseCondition(ConditionInterface): + """ + Base class for all conditions, implementing common functionality. + + All specific condition types should inherit from this class and implement + the abstract methods of + :class:`~pina.condition.condition_interface.ConditionInterface`. + + This class is not meant to be instantiated directly. + """ + + # Available collate functions for automatic batching + collate_fn_dict = { + "tensor": torch.stack, + "label_tensor": LabelTensor.stack, + "graph": LabelBatch.from_data_list, + "data": Batch.from_data_list, + } + + def __init__(self, **kwargs): + """ + Initialization of the :class:`BaseCondition` class. + + :param dict kwargs: The keyword arguments representing the data to be + stored in the condition. + """ + super().__init__() + self.data = self.store_data(**kwargs) + self.has_custom_dataloader_fn = False + + def __len__(self): + """ + Return the number of data points in the condition. + + :return: The number of data points. + :rtype: int + """ + return len(self.data) + + def __getitem__(self, idx): + """ + Return the data point at the specified index. + + :param int idx: The index of the data point to retrieve. + :return: The data point at the specified index. + :rtype: Any + """ + return self.data[idx] + + def create_dataloader( + self, dataset, batch_size, automatic_batching, **kwargs + ): + """ + Create the DataLoader for the condition. + + :param Dataset dataset: The dataset for the DataLoader. + :param int batch_size: The batch size for the DataLoader. + :param bool automatic_batching: Whether to use automatic batching. + :param dict kwargs: Additional keyword arguments for the DataLoader. + :return: The DataLoader for the condition. + :rtype: torch.utils.data.DataLoader + """ + # If batching the entire dataset, return a DummyDataloader + if batch_size == len(dataset): + return DummyDataloader(dataset) + + # Otherwise, return a regular DataLoader with the appropriate collate + return DataLoader( + dataset=dataset, + collate_fn=( + partial(self.collate_fn, condition=self) + if not automatic_batching + else self.automatic_batching_collate_fn + ), + batch_size=batch_size, + **kwargs, + ) + + def switch_dataloader_fn(self, create_dataloader_fn): + """ + Switch the dataloader function for the condition. + + :param Callable create_dataloader_fn: The new dataloader function to use + for the condition. + :return: The new dataloader function for the condition. 
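The dataloader selection in create_dataloader above follows a common pattern: when a single batch spans the whole dataset, collation can happen once up front instead of once per epoch. A minimal, standalone sketch of the same idea, assuming plain tensors; FullBatchLoader and make_loader are illustrative names, not PINA's DummyDataloader API:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    class FullBatchLoader:
        """Collate the whole dataset once and yield it as a single batch."""

        def __init__(self, dataset):
            # Pre-collate; every epoch reuses the same batch object.
            self.batch = torch.stack(
                [dataset[i][0] for i in range(len(dataset))]
            )

        def __iter__(self):
            yield self.batch

        def __len__(self):
            return 1

    def make_loader(dataset, batch_size, collate_fn=None, **kwargs):
        # Full-dataset batching short-circuits to the one-shot loader.
        if batch_size == len(dataset):
            return FullBatchLoader(dataset)
        return DataLoader(
            dataset, batch_size=batch_size, collate_fn=collate_fn, **kwargs
        )

    loader = make_loader(TensorDataset(torch.randn(8, 3)), batch_size=8)
    full_batch = next(iter(loader))  # shape: (8, 3)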
+ :rtype: Callable + """ + self.has_custom_dataloader_fn = True + self.create_dataloader = create_dataloader_fn + + @classmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function for automatic batching to be used in the DataLoader. + + :param list batch: A list of items from the dataset. + :return: A collated batch. + :rtype: dict + """ + # If the batch is empty, return an empty dictionary + if not batch: + return {} + + # Otherwise, collate the batch using the appropriate collate function + instance_class = batch[0].__class__ + return instance_class.create_batch(batch) + + @staticmethod + def collate_fn(batch, condition): + """ + Collate function for custom batching to be used in the DataLoader. + + :param list batch: A list of items from the dataset. + :param BaseCondition condition: The condition instance. + :return: A collated batch. + :rtype: dict + """ + return condition.data[batch].to_batch() + + @property + def problem(self): + """ + The problem associated with this condition. + + :return: The problem associated with this condition. + :rtype: BaseProblem + """ + return self._problem + + @problem.setter + def problem(self, value): + """ + Set the problem associated with this condition. + + :param BaseProblem value: The problem to associate with this condition. + :raises ValueError: If the problem is not an instance of BaseProblem. + """ + check_consistency(value, ProblemInterface) + self._problem = value diff --git a/pina/_src/condition/batch_manager.py b/pina/_src/condition/batch_manager.py index 105eec6eb..cdea44616 100644 --- a/pina/_src/condition/batch_manager.py +++ b/pina/_src/condition/batch_manager.py @@ -1,17 +1,15 @@ -""" -Module for managing batches of data with device transfer capabilities. -""" +"""Module for the Batch Manager class.""" class _BatchManager(dict): """ - A dictionary-based batch manager that supports dot-notation - and moving tensors to devices. + Dict-like container for batched data with attribute-style access and + convenience methods for device placement. """ def to(self, device): """ - Move all tensors in the batch to the specified device. + Move all compatible values in the batch to the specified device. :param device: The target device. :type device: torch.device | str @@ -21,19 +19,25 @@ def to(self, device): for key, value in self.items(): if hasattr(value, "to"): moved_value = value.to(device) - self[key] = moved_value # Updates both dict and attribute + self[key] = moved_value + return self def __getattribute__(self, name): """ - Alias attribute access to dictionary keys. + Provide attribute-style access to dictionary keys. :param str name: The name of the attribute to retrieve. + :raises AttributeError: If the attribute is not found as a standard + attribute or a dictionary key. :return: The value associated with the attribute name. :rtype: Any """ + # First, attempt to retrieve the attribute using the standard method. try: return super().__getattribute__(name) + + # If not found, attempt to retrieve the attribute as a dictionary key. 
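The attribute fallback implemented here is what gives batches their dot access. A self-contained sketch of the same dict-with-attribute-access idea, including the device-transfer helper; AttrBatch is a hypothetical stand-in for _BatchManager, not PINA's class:

    import torch

    class AttrBatch(dict):
        """Dict whose keys are also readable as attributes."""

        def __getattribute__(self, name):
            try:
                # Regular attributes (methods, etc.) win first.
                return super().__getattribute__(name)
            except AttributeError:
                try:
                    # Fall back to dictionary keys.
                    return self[name]
                except KeyError as exc:
                    raise AttributeError(name) from exc

        def to(self, device):
            # Move every value that supports device placement.
            for key, value in self.items():
                if hasattr(value, "to"):
                    self[key] = value.to(device)
            return self

    batch = AttrBatch(input=torch.randn(4, 2), target=torch.randn(4, 1))
    batch.to("cpu")
    assert batch.input.shape == (4, 2)  # dot access hits the dict key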
except AttributeError: try: return self[name] diff --git a/pina/_src/condition/condition.py b/pina/_src/condition/condition.py index 71cb80e2f..e4bc62d66 100644 --- a/pina/_src/condition/condition.py +++ b/pina/_src/condition/condition.py @@ -1,11 +1,11 @@ """Module for the Condition class.""" +from pina._src.condition.input_equation_condition import InputEquationCondition +from pina._src.condition.input_target_condition import InputTargetCondition from pina._src.condition.data_condition import DataCondition from pina._src.condition.domain_equation_condition import ( DomainEquationCondition, ) -from pina._src.condition.input_equation_condition import InputEquationCondition -from pina._src.condition.input_target_condition import InputTargetCondition class Condition: @@ -26,7 +26,6 @@ class Condition: arguments, the class automatically selects the appropriate internal implementation. - Available `Condition` types: - :class:`~pina.condition.input_target_condition.InputTargetCondition`: @@ -34,9 +33,8 @@ class Condition: data. The model is trained to reproduce the ``target`` values given the ``input``. Supported data types include :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or - :class:`~torch_geometric.data.Data`. - The class automatically selects the appropriate implementation based on - the types of ``input`` and ``target``. + :class:`~torch_geometric.data.Data`. The class automatically selects the + appropriate implementation based on the types of ``input`` and ``target``. - :class:`~pina.condition.domain_equation_condition.DomainEquationCondition` : represents a general physics-informed condition defined by a ``domain`` @@ -60,9 +58,8 @@ class Condition: specified when the model depends on additional parameters. Supported data types include :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or - :class:`~torch_geometric.data.Data`. - The class automatically selects the appropriate implementation based on - the type of the ``input``. + :class:`~torch_geometric.data.Data`. The class automatically selects the + appropriate implementation based on the type of the ``input``. .. note:: diff --git a/pina/_src/condition/condition_base.py b/pina/_src/condition/condition_base.py deleted file mode 100644 index 4a7c8c1c8..000000000 --- a/pina/_src/condition/condition_base.py +++ /dev/null @@ -1,148 +0,0 @@ -""" -Base class for conditions. -""" - -from functools import partial -import torch -from torch_geometric.data import Batch -from torch.utils.data import DataLoader -from pina._src.condition.condition_interface import ConditionInterface -from pina._src.core.graph import LabelBatch -from pina._src.core.label_tensor import LabelTensor -from pina._src.data.dummy_dataloader import DummyDataloader - - -class ConditionBase(ConditionInterface): - """ - Base abstract class for all conditions in PINA. - This class provides common functionality for handling data storage, - batching, and interaction with the associated problem. - """ - - collate_fn_dict = { - "tensor": torch.stack, - "label_tensor": LabelTensor.stack, - "graph": LabelBatch.from_data_list, - "data": Batch.from_data_list, - } - - def __init__(self, **kwargs): - """ - Initialization of the :class:`ConditionBase` class. - - :param kwargs: Keyword arguments representing the data to be stored. 
- """ - super().__init__() - self.data = self.store_data(**kwargs) - self.has_custom_dataloader_fn = False - - @property - def problem(self): - """ - Return the problem associated with this condition. - - :return: Problem associated with this condition. - :rtype: ~pina.problem.base_problem.BaseProblem - """ - return self._problem - - @problem.setter - def problem(self, value): - """ - Set the problem associated with this condition. - - :param pina.problem.base_problem.BaseProblem value: The problem to - associate with this condition. - """ - self._problem = value - - def __len__(self): - """ - Return the number of data points in the condition. - - :return: Number of data points. - :rtype: int - """ - return len(self.data) - - def __getitem__(self, idx): - """ - Return the data point(s) at the specified index. - - :param idx: Index(es) of the data point(s) to retrieve. - :type idx: int | list[int] - :return: Data point(s) at the specified index. - """ - return self.data[idx] - - @classmethod - def automatic_batching_collate_fn(cls, batch): - """ - Collate function for automatic batching to be used in DataLoader. - :param batch: A list of items from the dataset. - :type batch: list - :return: A collated batch. - :rtype: dict - """ - if not batch: - return {} - instance_class = batch[0].__class__ - batch = instance_class.create_batch(batch) - return batch - - @staticmethod - def collate_fn(batch, condition): - """ - Collate function for custom batching to be used in DataLoader. - - :param batch: A list of items from the dataset. - :type batch: list - :param condition: The condition instance. - :type condition: ConditionBase - :return: A collated batch. - :rtype: dict - """ - data = condition.data[batch].to_batch() - return data - - def create_dataloader( - self, - dataset, - batch_size, - automatic_batching, - **kwargs, - ): - """ - Create a DataLoader for the condition. - - :param int batch_size: The batch size for the DataLoader. - :param bool shuffle: Whether to shuffle the data. Default is ``False``. - :return: The DataLoader for the condition. - :rtype: torch.utils.data.DataLoader - """ - if batch_size == len(dataset): - return DummyDataloader(dataset) - return DataLoader( - dataset=dataset, - collate_fn=( - partial(self.collate_fn, condition=self) - if not automatic_batching - else self.automatic_batching_collate_fn - ), - batch_size=batch_size, - **kwargs, - ) - - def switch_dataloader_fn(self, create_dataloader_fn): - """ - Decorator to switch the dataloader function for a condition. - - :param create_dataloader_fn: The new dataloader function to use. - :type create_dataloader_fn: function - :return: The decorated function with the new dataloader function. - :rtype: function - """ - # Replace the create_dataloader method of the ConditionBase class with - # the new function - self.has_custom_dataloader_fn = True - self.create_dataloader = create_dataloader_fn diff --git a/pina/_src/condition/condition_interface.py b/pina/_src/condition/condition_interface.py index 68898b082..9183d196f 100644 --- a/pina/_src/condition/condition_interface.py +++ b/pina/_src/condition/condition_interface.py @@ -5,53 +5,106 @@ class ConditionInterface(metaclass=ABCMeta): """ - Abstract base class for PINA conditions. All specific conditions must - inherit from this interface. + Abstract interface for all conditions. Refer to :class:`pina.condition.condition.Condition` for a thorough description of all available conditions and how to instantiate them. 
""" @abstractmethod - def __init__(self, **kwargs): + def __len__(self): """ - Initialization of the :class:`ConditionInterface` class. + Return the number of data points in the condition. + + :return: The number of data points. + :rtype: int """ - @property @abstractmethod - def problem(self): + def __getitem__(self, idx): """ - Return the problem associated with this condition. + Return the data point at the specified index. - :return: Problem associated with this condition. - :rtype: ~pina.problem.base_problem.BaseProblem + :param int idx: The index of the data point to retrieve. + :return: The data point at the specified index. + :rtype: Any """ - @problem.setter @abstractmethod - def problem(self, value): + def store_data(self, **kwargs): """ - Set the problem associated with this condition. + Store the data for the condition in a suitable format. - :param pina.problem.base_problem.BaseProblem value: The problem - to associate with this condition + :param dict kwargs: The keyword arguments containing the data to be + stored. + :return: The stored data in a suitable format. + :rtype: Any """ @abstractmethod - def __len__(self): + def create_dataloader( + self, dataset, batch_size, automatic_batching, **kwargs + ): """ - Return the number of data points in the condition. + Create the DataLoader for the condition. - :return: Number of data points. - :rtype: int + :param Dataset dataset: The dataset for the DataLoader. + :param int batch_size: The batch size for the DataLoader. + :param bool automatic_batching: Whether to use automatic batching. + :param dict kwargs: Additional keyword arguments for the DataLoader. + :return: The DataLoader for the condition. + :rtype: torch.utils.data.DataLoader """ @abstractmethod - def __getitem__(self, idx): + def switch_dataloader_fn(self, create_dataloader_fn): + """ + Switch the dataloader function for the condition. + + :param Callable create_dataloader_fn: The new dataloader function to use + for the condition. + :return: The new dataloader function for the condition. + :rtype: Callable """ - Return the data point(s) at the specified index. - :param int idx: Index of the data point(s) to retrieve. - :return: Data point(s) at the specified index. + @classmethod + @abstractmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function for automatic batching to be used in the DataLoader. + + :param list batch: A list of items from the dataset. + :return: A collated batch. + :rtype: dict + """ + + @staticmethod + @abstractmethod + def collate_fn(batch, condition): + """ + Collate function for custom batching to be used in the DataLoader. + + :param list batch: A list of items from the dataset. + :param BaseCondition condition: The condition instance. + :return: A collated batch. + :rtype: dict + """ + + @property + @abstractmethod + def problem(self): + """ + The problem associated with this condition. + + :return: The problem associated with this condition. + :rtype: BaseProblem + """ + + @problem.setter + @abstractmethod + def problem(self, value): + """ + Set the problem associated with this condition. + + :param BaseProblem value: The problem to associate with this condition. 
""" diff --git a/pina/_src/condition/data_condition.py b/pina/_src/condition/data_condition.py index f37b3dc31..da34f838b 100644 --- a/pina/_src/condition/data_condition.py +++ b/pina/_src/condition/data_condition.py @@ -1,14 +1,15 @@ -"""Module for the DataCondition class.""" +"""Module for the Data Condition class.""" import torch from torch_geometric.data import Data -from pina._src.condition.condition_base import ConditionBase +from pina._src.condition.base_condition import BaseCondition from pina._src.core.label_tensor import LabelTensor from pina._src.core.graph import Graph from pina._src.condition.data_manager import _DataManager +from pina._src.core.utils import check_consistency -class DataCondition(ConditionBase): +class DataCondition(BaseCondition): """ The class :class:`DataCondition` defines an unsupervised condition based on ``input`` data. This condition is typically used in data-driven problems, @@ -27,94 +28,80 @@ class DataCondition(ConditionBase): >>> condition = Condition(input=pts, conditional_variables=cond_vars) """ - # Available input data types + # Available fields, input and conditional variables data types __fields__ = ["input", "conditional_variables"] - _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) + _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph) _avail_conditional_variables_cls = (torch.Tensor, LabelTensor) def __new__(cls, input, conditional_variables=None): """ Check the types of ``input`` and ``conditional_variables`` and - instantiate a class of :class:`DataCondition` accordingly. + instantiate an instance of :class:`DataCondition` accordingly. - :param input: The input data for the condition. + :param input: The input data associated with the condition. :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] - :param conditional_variables: The conditional variables for the - condition. Default is ``None``. + :param conditional_variables: The conditional variables associated with + the condition. Default is ``None``. :type conditional_variables: torch.Tensor | LabelTensor - :return: The subclass of DataCondition. - :rtype: pina.condition.data_condition.TensorDataCondition | - pina.condition.data_condition.GraphDataCondition :raises ValueError: If ``input`` is not of type :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, - or :class:`~torch_geometric.data.Data`. + or :class:`~torch_geometric.data.Data`, nor is it a list or tuple of + :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`. + :raises ValueError: If ``conditional_variables`` is not of type + :class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor`. + :return: A new instance of :class:`DataCondition`. + :rtype: DataCondition """ - if cls != DataCondition: - return super().__new__(cls) - - # Check input type - if not isinstance(input, cls._avail_input_cls): - raise ValueError( - "Invalid input type. Expected one of the following: " - "torch.Tensor, LabelTensor, Graph, Data or " - "an iterable of the previous types." - ) + # Check input type - if iterable, ensure it is either Data or Graph if isinstance(input, (list, tuple)): - for item in input: - if not isinstance(item, (Data, Graph)): - raise ValueError( - "if input is a list or tuple, all its elements must" - " be of type Graph or Data." 
- ) + check_consistency(input, (Data, Graph)) + else: + check_consistency(input, cls._avail_input_cls) # Check conditional_variables type if conditional_variables is not None: - if not isinstance( + check_consistency( conditional_variables, cls._avail_conditional_variables_cls - ): - raise ValueError( - "Invalid conditional_variables type. Expected one of the " - "following: torch.Tensor, LabelTensor." - ) + ) return super().__new__(cls) def store_data(self, **kwargs): """ - Store the input data and conditional variables in a dictionary. + Store the input data and the conditional variables in a dictionary-like + structure. - :param input: The input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | - Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] - :param conditional_variables: The conditional variables for the - condition. - :type conditional_variables: torch.Tensor | LabelTensor - :return: A dictionary containing the stored data. - :rtype: dict + :param dict kwargs: The keyword arguments containing the data to be + stored. + :return: A dictionary-like structure containing the stored data. + :rtype: _DataManager """ + # Store input and conditional variables in a dictionary-like structure data_dict = {"input": kwargs.get("input")} cond_vars = kwargs.get("conditional_variables", None) if cond_vars is not None: data_dict["conditional_variables"] = cond_vars + return _DataManager(**data_dict) @property def conditional_variables(self): """ - Return the conditional variables for the condition. + The conditional variables associated with the condition. :return: The conditional variables. :rtype: torch.Tensor | LabelTensor | None """ if hasattr(self.data, "conditional_variables"): return self.data.conditional_variables + return None @property def input(self): """ - Return the input data for the condition. + The input data associated with the condition. :return: The input data. :rtype: torch.Tensor | LabelTensor | Graph | Data | diff --git a/pina/_src/condition/data_manager.py b/pina/_src/condition/data_manager.py index 2f7095fa1..723a4f059 100644 --- a/pina/_src/condition/data_manager.py +++ b/pina/_src/condition/data_manager.py @@ -1,349 +1,50 @@ -""" -Module for managing data in conditions. -""" +"""Module for the Data Manager factory class.""" import torch -from torch_geometric.data import Data -from torch_geometric.data.batch import Batch -from pina import LabelTensor -from pina._src.core.graph import Graph, LabelBatch +from pina._src.core.label_tensor import LabelTensor from pina._src.equation.base_equation import BaseEquation -from .batch_manager import _BatchManager +from pina._src.condition.graph_data_manager import _GraphDataManager +from pina._src.condition.tensor_data_manager import _TensorDataManager class _DataManager: """ - Abstract base class for data managers. + Factory class for data manager implementations. - This class dynamically selects between :class:`_TensorDataManager` and - :class:`_GraphDataManager` based on the types of the input data. + This class dispatches object creation to either + :class:`~pina.condition.tensor_data_manager._TensorDataManager` or + :class:`~pina.condition.graph_data_manager._GraphDataManager` depending on + the types of the provided keyword arguments. """ def __new__(cls, **kwargs): """ - Dynamically instantiate the appropriate subclass based on the types - of the input data. 
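Dispatching inside __new__, as the rewritten factory below does, keeps one public constructor while returning a type-specialized implementation. A minimal sketch of the mechanism under invented names (the real class routes to _TensorDataManager or _GraphDataManager):

    import torch

    class TensorManager:
        def __init__(self, **kwargs):
            self.data = kwargs

    class ObjectManager:
        def __init__(self, **kwargs):
            self.data = kwargs

    class Manager:
        """Factory: construction dispatches on the value types."""

        def __new__(cls, **kwargs):
            tensor_only = all(
                isinstance(v, torch.Tensor) for v in kwargs.values()
            )
            impl = TensorManager if tensor_only else ObjectManager
            # The returned object is not a Manager subclass, so
            # Python skips Manager.__init__ entirely.
            return impl(**kwargs)

    m = Manager(input=torch.randn(3, 2), target=torch.randn(3, 1))
    assert isinstance(m, TensorManager)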
- - If all values in ``kwargs`` are instances of - :class:`torch.Tensor`, :class:`LabelTensor` then - :class:`_TensorDataManager` is instantiated. - - Otherwise, :class:`_GraphDataManager` is instantiated. + Create the appropriate data manager implementation based on the provided + keyword arguments. - :param dict kwargs: The keyword arguments containing the data. - :return: An instance of :class:`_TensorDataManager` or - :class:`_GraphDataManager`. + If all values in ``kwargs`` are instances of :class:`torch.Tensor`, + :class:`~pina.label_tensor.LabelTensor`, or + :class:`~pina.equation.base_equation.BaseEquation`, an instance of + :class:`~pina.condition.tensor_data_manager._TensorDataManager` is + created. Otherwise, an instance of + :class:`~pina.condition.graph_data_manager._GraphDataManager` is + created. + + :param dict kwargs: The keyword arguments for the data manager. + :return: A concrete data manager instance. :rtype: _TensorDataManager | _GraphDataManager """ - # If not called directly, proceed with normal instantiation + # Guard subclass instantiation if cls is not _DataManager: return super().__new__(cls) - # Does the data contain only tensors/LabelTensors/Equations? + # Check if there are only tensors / equations is_tensor_only = all( isinstance(v, (torch.Tensor, LabelTensor, BaseEquation)) for v in kwargs.values() ) - # Choose the appropriate subclass, GraphDataManager or TensorDataManager - subclass = _TensorDataManager if is_tensor_only else _GraphDataManager - return super().__new__(subclass) - - def __init__(self, **kwargs): - """ - Initialize the data manager with the provided keyword arguments. - - :param dict kwargs: The keyword arguments containing the data. - """ - self.keys = list(kwargs.keys()) - - -class _TensorDataManager(_DataManager): - """ - Data manager for tensor data. Handles data stored as `torch.Tensor` or - `LabelTensor`. - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.data = kwargs - - for k, v in kwargs.items(): - setattr(self, k, v) - - def __len__(self): - """ - Return the number of samples in the tensor data manager. - - :return: Number of samples. - :rtype: int - """ - return self.data[self.keys[0]].shape[0] - - def __getitem__(self, idx): - """ - Return a data item or a subset of data items by index. - - :param idx: Index or indices of the data items to retrieve. - :type idx: int | slice | list[int] | torch.Tensor - :return: A new :class:`_TensorDataManager` instance containing the - selected data items. - :rtype: _TensorDataManager - """ - # Mapping efficiente degli elementi - new_data = { - k: (self.data[k][idx] if k in self.keys else self.data[k]) - for k in self.keys - } - return _TensorDataManager(**new_data) - - @staticmethod - def create_batch(items): - """ - Create a batch from a list of :class:`_TensorDataManager` items. - - :param list items: List of :class:`_TensorDataManager` items to batch. - :return: A new :class:`_BatchManager` instance containing the batched - data. 
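The tensor manager's indexing, shown here in its pre-refactor form, boils down to row-selecting every stored field at once and wrapping the result in a new manager. A standalone sketch assuming all fields share the leading sample dimension; TensorStore is an illustrative name:

    import torch

    class TensorStore:
        """Dict of equally sized tensors with row-wise indexing."""

        def __init__(self, **tensors):
            self.tensors = tensors

        def __len__(self):
            return next(iter(self.tensors.values())).shape[0]

        def __getitem__(self, idx):
            # Row-select every field; returns a new store over the subset.
            return TensorStore(
                **{k: v[idx] for k, v in self.tensors.items()}
            )

    store = TensorStore(input=torch.randn(10, 3), target=torch.randn(10, 1))
    subset = store[[0, 4, 7]]
    assert len(subset) == 3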
- :rtype: _BatchManager - """ - if not items: - return None - first = items[0] - batch_data = _BatchManager() - - for k in first.keys: - vals = [it.data[k] for it in items] - sample = vals[0] - - if isinstance(sample, (torch.Tensor, LabelTensor)): - batch_fn = ( - LabelTensor.stack - if isinstance(sample, LabelTensor) - else torch.stack - ) - batch_data[k] = batch_fn(vals) - batch_data[k] = batch_fn(vals, dim=0) - else: - batch_data[k] = sample - return batch_data - - def to_batch(self): - """ - Create a batch from the current tensor data manager. - - :return: A new :class:`_BatchManager` instance containing the batched - data. - :rtype: _BatchManager - """ - batch_data = _BatchManager() - for k in self.keys: - batch_data[k] = self.data[k] - return batch_data - - -class _GraphDataManager(_DataManager): - """ - Data manager for graph data. Handles data stored as :class:`Graph`, - :class:`Data`, or lists/tuples of these types. Moreover , it can also manage - associated tensors stored as :class:`torch.Tensor` or :class:`LabelTensor`. - """ - - def __init__(self, **kwargs): - """ - Initialize the graph data manager with the provided keyword arguments. - - :param dict kwargs: The keyword arguments containing the data. - """ - super().__init__(**kwargs) - self.graph_key = next( - k - for k, v in kwargs.items() - if isinstance(v, (Graph, Data, list, tuple)) - ) - - self.keys = [ - k - for k in self.keys - if k != self.graph_key - and isinstance(kwargs[k], (torch.Tensor, LabelTensor)) - ] - # Prepare graphs and assign tensors - self.data = self._prepare_graphs(kwargs) - - def _prepare_graphs(self, kwargs): - """ - Store tensors in the corresponding graphs. - - :param dict kwargs: The keyword arguments containing the graphs and - associated tensors. - :return: A list of graphs with tensors assigned. - :rtype: list[Graph] | list[Data] - """ - graphs = kwargs.pop(self.graph_key) - if not isinstance(graphs, (list, tuple)): - graphs = [graphs] - - n_graphs = len(graphs) - for name, tensor in kwargs.items(): - # Verify consistency between number of graphs and tensor samples - if n_graphs != tensor.shape[0]: - raise ValueError( - f"Number of graphs ({n_graphs}) does not match " - f"number of samples for key '{name}' " - f"({kwargs[name].shape[0]})." - ) - # Assign tensors to graphs - for i, g in enumerate(graphs): - setattr(g, name, tensor[i]) - - return graphs - - def __len__(self): - """ - Return the number of graphs in the graph data manager. - - :return: Number of graphs. - :rtype: int - """ - return len(self.data) - - def __getattr__(self, name): - """ - Override attribute access to retrieve tensors or graphs. If the graph - key is requested, return the list of graphs. If a tensor key is - requested, stack the tensors from all graphs and return the result. - - :param str name: The name of the attribute to retrieve. - :return: The requested tensor or graph. 
- :rtype: torch.Tensor | LabelTensor | Graph | list[Graph] | Data | - """ - # If the requested attribute is a tensor key, stack the tensors from - # all graphs - if name in self.keys: - tensors = [getattr(g, name) for g in self.data] - batch_fn = ( - LabelTensor.stack - if isinstance(tensors[0], LabelTensor) - else torch.stack - ) - return batch_fn(tensors) - - # If the requested attribute is the graph key, return the graphs - if name == self.graph_key: - return self.data if len(self.data) > 1 else self.data[0] - - return super().__getattribute__(name) - - @classmethod - def _init_from_graphs_list(cls, graphs, graph_key, keys): - """ - Initialize a :class:`_GraphDataManager` instance from a list of graphs. - This is used internally to create subsets of the data manager, without - going through the full initialization process. - - :param list graphs: List of graphs to initialize the data manager with. - :param str graph_key: Key under which the graphs are stored. - :param list keys: List of tensor keys associated with the graphs. - :return: A new :class:`_GraphDataManager` instance. - :rtype: _GraphDataManager - """ - # Create a new instance without calling __init__ - obj = _GraphDataManager.__new__(_GraphDataManager) - obj.graph_key = graph_key - obj.keys = keys - obj.data = graphs - return obj - - def __getitem__(self, idx): - """ - Retrieve a graph or a subset of graphs by index. - - :param idx: Index or indices of the graphs to retrieve. - :type idx: int | slice | list[int] | torch.Tensor - :return: A new :class:`_GraphDataManager` instance containing the - selected graphs. - :rtype: _GraphDataManager - """ - # Manage int and slice directly - if isinstance(idx, (int, slice)): - selected = self.data[idx] - # Manage list or tensor of indices - elif isinstance(idx, (list, torch.Tensor)): - selected = [self.data[i] for i in idx] - else: - raise TypeError(f"Invalid index type: {type(idx)}") - - # Ensure selected is a list - if not isinstance(selected, list): - selected = [selected] - - # Return a new _GraphDataManager instance with the selected graphs - return _GraphDataManager._init_from_graphs_list( - selected, - # tensor_keys=self._tensor_keys, - graph_key=self.graph_key, - keys=self.keys, - ) - - def to_batch(self): - """ - Create a batch from the current graph data manager. - - :return: A new :class:`_BatchManager` instance containing the batched - data. - :rtype: _BatchManager - """ - batching_fn = ( - LabelBatch.from_data_list - if isinstance(self.data[0], Graph) - else Batch.from_data_list - ) - - batched_graph = batching_fn(self.data) - batch_data = _BatchManager() - for k in self.keys: - if k == self.graph_key: - continue - batch_data[k] = getattr(batched_graph, k) - delattr(batched_graph, k) - batch_data[self.graph_key] = batched_graph - return batch_data - - @staticmethod - def create_batch(items): - """ - Optimized batch creation. - """ - if not items: - return None - - first = items[0] - graph_key = first.graph_key - # Determine batching function once - is_labeled = isinstance(first.data[0], Graph) - batching_fn = ( - LabelBatch.from_data_list if is_labeled else Batch.from_data_list - ) - - # Efficient list comprehension for extraction - # If to_batch() is called on self, self.data might be a list already. - # If _create_batch is called on multiple managers, we grab the first - # graph from each. 
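The _init_from_graphs_list helper seen here rests on a general Python idiom: cls.__new__(cls) allocates an instance without running __init__, so subsets can be built from already-processed data without repeating validation. A generic sketch of the idiom with hypothetical names:

    class Record:
        def __init__(self, raw):
            # Imagine expensive parsing/validation here.
            self.fields = dict(raw)

        @classmethod
        def _from_parsed(cls, fields):
            # Build an instance from already-processed data,
            # skipping __init__ entirely.
            obj = cls.__new__(cls)
            obj.fields = fields
            return obj

    fast = Record._from_parsed({"a": 1})
    assert fast.fields == {"a": 1}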
- graphs_to_batch = [item.data[0] for item in items] - batched_graph = batching_fn(graphs_to_batch) - - batch_data = _BatchManager() - - # Use a set for O(1) lookups if keys is large - keys_to_transfer = set(first.keys) - if graph_key in keys_to_transfer: - keys_to_transfer.remove(graph_key) - - for k in keys_to_transfer: - # Check if attribute exists once to avoid AttributeError overhead - val = getattr(batched_graph, k, None) - if val is not None: - batch_data[k] = val - delattr(batched_graph, k) + # Choose the appropriate subclass + subclass = _TensorDataManager if is_tensor_only else _GraphDataManager - batch_data[graph_key] = batched_graph - return batch_data + return subclass(**kwargs) diff --git a/pina/_src/condition/data_manager_interface.py b/pina/_src/condition/data_manager_interface.py new file mode 100644 index 000000000..2e51dd3a1 --- /dev/null +++ b/pina/_src/condition/data_manager_interface.py @@ -0,0 +1,53 @@ +"""Module for the Tensor-Data Manager interface.""" + +from abc import ABCMeta, abstractmethod + + +class _DataManagerInterface(metaclass=ABCMeta): + """ + Abstract interface for all data managers. + """ + + @abstractmethod + def __len__(self): + """ + Return the number of samples in the data manager. + + :return: The number of samples. + :rtype: int + """ + + @abstractmethod + def __getitem__(self, idx): + """ + Return the item at the specified indices. + + :param idx: The indices of the data point to retrieve. + :type idx: int | slice | list[int] | torch.Tensor + :return: A new :class:`_DataManager` instance containing the + selected data items. + :rtype: _DataManager + """ + + @abstractmethod + def to_batch(self): + """ + Create a batch from the current data manager. + + :return: A new :class:`~pina.condition.data_manager._DataManager` + instance with batched data. + :rtype: _DataManager + """ + + @staticmethod + @abstractmethod + def create_batch(items): + """ + Create a batch from a list of :class:`_DataManager` items. + + :param list[_DataManager] items: A list of + :class:`_DataManager` items to batch. + :return: A new instance of :class:`_DataManager` containing the + batched data. + :rtype: _DataManager + """ diff --git a/pina/_src/condition/domain_equation_condition.py b/pina/_src/condition/domain_equation_condition.py index 42b448ce6..73307159b 100644 --- a/pina/_src/condition/domain_equation_condition.py +++ b/pina/_src/condition/domain_equation_condition.py @@ -1,11 +1,12 @@ -"""Module for the DomainEquationCondition class.""" +"""Module for the Domain-Equation Condition class.""" -from pina._src.condition.condition_base import ConditionBase +from pina._src.condition.base_condition import BaseCondition from pina._src.domain.domain_interface import DomainInterface from pina._src.equation.base_equation import BaseEquation +from pina._src.core.utils import check_consistency -class DomainEquationCondition(ConditionBase): +class DomainEquationCondition(BaseCondition): """ The class :class:`DomainEquationCondition` defines a condition based on a ``domain`` and an ``equation``. 
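Implementing the _DataManagerInterface contract introduced above amounts to providing sized, indexable access plus batch conversion. A toy implementation over a plain list, assuming nothing about PINA's internals; ManagerInterface and ListManager are invented names:

    from abc import ABCMeta, abstractmethod

    class ManagerInterface(metaclass=ABCMeta):
        @abstractmethod
        def __len__(self): ...

        @abstractmethod
        def __getitem__(self, idx): ...

        @abstractmethod
        def to_batch(self): ...

    class ListManager(ManagerInterface):
        def __init__(self, items):
            self.items = list(items)

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            # Always return a manager over the selection.
            sel = idx if isinstance(idx, list) else [idx]
            return ListManager([self.items[i] for i in sel])

        def to_batch(self):
            return {"batch": self.items}

    m = ListManager([1, 2, 3])[[0, 2]]
    assert m.to_batch() == {"batch": [1, 3]}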
This condition is typically used in @@ -28,68 +29,95 @@ class DomainEquationCondition(ConditionBase): >>> condition = Condition(domain=domain, equation=Equation(dummy_equation)) """ - # Available slots + # Available fields, domain and equation data types __fields__ = ["domain", "equation"] - _avail_domain_cls = (DomainInterface, str) _avail_equation_cls = BaseEquation - def __new__(cls, domain, equation): - """ - Check the types of ``domain`` and ``equation`` and instantiate an - instance of :class:`DomainEquationCondition`. - - :return: An instance of :class:`DomainEquationCondition`. - :rtype: pina.condition.domain_equation_condition.DomainEquationCondition - :raises ValueError: If ``domain`` is not of type - :class:`DomainInterface` or - ``equation`` is not of type :class:` - """ - if not isinstance(domain, cls._avail_domain_cls): - raise ValueError( - "The domain must be an instance of DomainInterface." - ) - - if not isinstance(equation, cls._avail_equation_cls): - raise ValueError( - "The equation must be an instance of BaseEquation." - ) - - return super().__new__(cls) - def __len__(self): """ - Raise NotImplementedError since the number of points is determined by - the domain sampling strategy. + Return the number of data points in the condition. :raises NotImplementedError: Always raised since the number of points is - determined by the domain sampling strategy. + determined by the domain sampling strategy and is not fixed. """ raise NotImplementedError( - "`__len__` method is not implemented for " - "`DomainEquationCondition` since the number of points is " - "determined by the domain sampling strategy." + "The number of data points in a DomainEquationCondition is not " + "fixed and is determined by the domain sampling strategy. " + "Therefore, the :meth:`__len__` method is not implemented for this " + "condition." ) def __getitem__(self, idx): """ - Raise NotImplementedError since data retrieval is not applicable. + Return the data point at the specified index. - :param int idx: Index of the data point(s) to retrieve. - :raises NotImplementedError: Always raised since data retrieval is not - applicable for this condition. + :raises NotImplementedError: Always raised since the data points are not + stored in a list-like structure and cannot be accessed by index. """ raise NotImplementedError( - "`__getitem__` method is not implemented for " - "`DomainEquationCondition`" + "Data points in a DomainEquationCondition are not stored in a " + "list-like structure and cannot be accessed by index. Therefore, " + "the :meth:`__getitem__` method is not implemented for this " + "condition." ) def store_data(self, **kwargs): """ - Store data for the condition. No data is stored for this condition. + Store the domain and the equation for the condition. It sets the + attributes ``domain`` and ``equation`` of the condition instance based + on the provided keyword arguments. - :return: An empty dictionary since no data is stored. - :rtype: dict + :param dict kwargs: The keyword arguments containing the data to be + stored. """ + # Store domain and equation as attributes of the condition instance setattr(self, "domain", kwargs.get("domain")) setattr(self, "equation", kwargs.get("equation")) + + @property + def equation(self): + """ + The equation associated with the condition. + + :return: The equation. + :rtype: BaseEquation + """ + return self._equation + + @equation.setter + def equation(self, value): + """ + Set the equation associated with this condition. 
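The equation and domain properties below validate on assignment, so a wrong type fails loudly at construction time rather than deep inside training. A minimal sketch of the validated-property pattern, where a plain isinstance check stands in for check_consistency:

    class Equation:
        pass

    class PhysicsCondition:
        @property
        def equation(self):
            return self._equation

        @equation.setter
        def equation(self, value):
            # Reject anything that is not an Equation up front.
            if not isinstance(value, Equation):
                raise ValueError("equation must be an Equation instance")
            self._equation = value

    cond = PhysicsCondition()
    cond.equation = Equation()  # OK
    # cond.equation = 42        # would raise ValueError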
+ + :param BaseEquation value: The equation to associate with the condition. + :raises ValueError: If ``value`` is not an instance of + :class:`~pina.equation.base_equation.BaseEquation`. + """ + # Check consistency + check_consistency(value, self._avail_equation_cls) + self._equation = value + + @property + def domain(self): + """ + The domain associated with the condition. + + :return: The domain. + :rtype: DomainInterface + """ + return self._domain + + @domain.setter + def domain(self, value): + """ + Set the domain associated with this condition. + + :param DomainInterface value: The domain to associate with the + condition. + :raises ValueError: If ``value`` is neither a string nor an instance of + :class:`~pina.domain.domain_interface.DomainInterface`. + """ + # Check consistency + check_consistency(value, self._avail_domain_cls) + self._domain = value diff --git a/pina/_src/condition/graph_data_manager.py b/pina/_src/condition/graph_data_manager.py new file mode 100644 index 000000000..b05ac5c7a --- /dev/null +++ b/pina/_src/condition/graph_data_manager.py @@ -0,0 +1,246 @@ +"""Module for the Graph-Data Manager class.""" + +import torch +from torch_geometric.data import Data +from torch_geometric.data.batch import Batch +from pina._src.core.label_tensor import LabelTensor +from pina._src.core.graph import Graph, LabelBatch +from pina._src.condition.batch_manager import _BatchManager +from pina._src.condition.data_manager_interface import _DataManagerInterface + + +class _GraphDataManager(_DataManagerInterface): + """ + Data manager for graph-based data. It handles inputs stored as + :class:`Graph`, :class:`Data`, or lists / tuples of these types. + """ + + def __init__(self, **kwargs): + """ + Initialization of the :class:`_GraphDataManager` class. + + :param dict kwargs: The keyword arguments for the graph data manager. + """ + # Initialize keys + self.keys = list(kwargs.keys()) + + # Find graph-based data + self.graph_key = next( + k + for k, v in kwargs.items() + if isinstance(v, (Graph, Data, list, tuple)) + ) + + # Find tensor data + self.keys = [ + k + for k in self.keys + if k != self.graph_key + and isinstance(kwargs[k], (torch.Tensor, LabelTensor)) + ] + + # Prepare graphs and assign tensors + self.data = self._prepare_graphs(kwargs) + + def __len__(self): + """ + Return the number of samples in the graph data manager. + + :return: The number of samples. + :rtype: int + """ + return len(self.data) + + def __getitem__(self, idx): + """ + Return the item at the specified indices. + + :param idx: The indices of the graphs to retrieve. + :type idx: int | slice | list[int] | torch.Tensor + :raises TypeError: If an index with invalid type is passed. + :return: A new :class:`_GraphDataManager` instance containing the + selected graphs. + :rtype: _GraphDataManager + """ + # Selection for integers or slices + if isinstance(idx, (int, slice)): + selected = self.data[idx] + + # Selection for lists or tensors + elif isinstance(idx, (list, torch.Tensor)): + selected = [self.data[i] for i in idx] + + # Raise TypeError if index type is invalid + else: + raise TypeError(f"Invalid index type: {type(idx)}") + + # Ensure selected is a list + if not isinstance(selected, list): + selected = [selected] + + return _GraphDataManager._init_from_graphs_list( + selected, graph_key=self.graph_key, keys=self.keys + ) + + def __getattr__(self, name): + """ + Provide dynamic access to stored graph and tensor data. + + If ``name`` corresponds to the graph key, return the list of graph + objects. 
If it matches a tensor key, retrieve the corresponding + tensors from all graphs and stack them along the batch dimension. + + :param str name: The name of the attribute to access. + :return: The requested graph data or stacked tensor values. + :rtype: torch.Tensor | LabelTensor | list[Graph] | list[Data] + """ + # Stack tensors from all graphs if name is a tensor key + if name in self.keys: + tensors = [getattr(g, name) for g in self.data] + batch_fn = ( + LabelTensor.stack + if isinstance(tensors[0], LabelTensor) + else torch.stack + ) + return batch_fn(tensors) + + # Otherwise, return graphs + if name == self.graph_key: + return self.data if len(self.data) > 1 else self.data[0] + + return super().__getattribute__(name) + + def _prepare_graphs(self, kwargs): + """ + Attach tensor data to the corresponding graph objects. + + :param kwargs: The keyword arguments containing graph data and + associated tensor features. + :raises ValueError: If the number of graphs does not match the number of + samples in the associated feature tensors. + :return: A list of graphs with the corresponding tensors assigned. + :rtype: list[Graph] | list[Data] + """ + # Get graph-based data and store in a list + graphs = kwargs.pop(self.graph_key) + if not isinstance(graphs, (list, tuple)): + graphs = [graphs] + + # Iterate over items + for name, tensor in kwargs.items(): + + # Verify the consistency between the number of graphs and samples + if len(graphs) != tensor.shape[0]: + raise ValueError( + f"Number of graphs ({len(graphs)}) does not match " + f"number of samples for key '{name}' " + f"({kwargs[name].shape[0]})." + ) + + # Assign tensors to graphs + for i, g in enumerate(graphs): + setattr(g, name, tensor[i]) + + return graphs + + def to_batch(self): + """ + Create a batch from the current graph data manager. + + :return: A new instance of :class:`_BatchManager` with batched data. + :rtype: _BatchManager + """ + # Define the batch function + batching_fn = ( + LabelBatch.from_data_list + if isinstance(self.data[0], Graph) + else Batch.from_data_list + ) + + # Create the batch manager + batch_data = _BatchManager() + batched_graph = batching_fn(self.data) + for k in self.keys: + if k == self.graph_key: + continue + batch_data[k] = getattr(batched_graph, k) + delattr(batched_graph, k) + batch_data[self.graph_key] = batched_graph + + return batch_data + + @staticmethod + def create_batch(items): + """ + Create a batch from a list of :class:`_GraphDataManager` items. + + :param list[_GraphDataManager] items: A list of + :class:`_GraphDataManager` items to batch. + :return: A new instance of :class:`_BatchManager` containing the batched + data.
+ :rtype: _BatchManager + """ + # Return None if no items are provided + if not items: + return None + + # Retrieve the first _GraphDataManager of the list and corresponding key + first = items[0] + graph_key = first.graph_key + + # Initialize the batch manager + batch_data = _BatchManager() + + # Define batch function + batching_fn = ( + LabelBatch.from_data_list + if isinstance(first.data[0], Graph) + else Batch.from_data_list + ) + + # Batch over graphs + batched_graph = batching_fn([item.data[0] for item in items]) + + # Use a set for O(1) lookups if keys are large + keys_to_transfer = set(first.keys) + if graph_key in keys_to_transfer: + keys_to_transfer.remove(graph_key) + + # Iterate over the keys of the _GraphDataManager + for k in keys_to_transfer: + + # Extract values + val = getattr(batched_graph, k, None) + if val is not None: + batch_data[k] = val + delattr(batched_graph, k) + + # Assign key to batch + batch_data[graph_key] = batched_graph + + return batch_data + + @classmethod + def _init_from_graphs_list(cls, graphs, graph_key, keys): + """ + Create a :class:`_GraphDataManager` instance directly from a list of + graph objects. + + This method bypasses the standard initialization logic and is used + internally to construct new instances (e.g., subsets) from already + processed graph data. + + :param list graphs: A list of graph objects. + :param str graph_key: The name of the attribute used to store the + graphs. + :param list keys: A list of tensor keys associated with the graphs. + :return: A new instance of :class:`_GraphDataManager`. + :rtype: _GraphDataManager + """ + # Create a new instance without calling __init__ + obj = _GraphDataManager.__new__(_GraphDataManager) + obj.graph_key = graph_key + obj.keys = keys + obj.data = graphs + + return obj diff --git a/pina/_src/condition/input_equation_condition.py b/pina/_src/condition/input_equation_condition.py index 965501e1a..26958fb08 100644 --- a/pina/_src/condition/input_equation_condition.py +++ b/pina/_src/condition/input_equation_condition.py @@ -1,13 +1,14 @@ -"""Module for the InputEquationCondition class and its subclasses.""" +"""Module for the Input-Equation Condition class.""" -from pina._src.condition.condition_base import ConditionBase +from pina._src.condition.base_condition import BaseCondition from pina._src.core.label_tensor import LabelTensor from pina._src.core.graph import Graph from pina._src.equation.base_equation import BaseEquation from pina._src.condition.data_manager import _DataManager +from pina._src.core.utils import check_consistency -class InputEquationCondition(ConditionBase): +class InputEquationCondition(BaseCondition): """ The class :class:`InputEquationCondition` defines a condition based on ``input`` data and an ``equation``. This condition is typically used in @@ -29,55 +30,55 @@ class InputEquationCondition(ConditionBase): >>> condition = Condition(input=pts, equation=Equation(dummy_equation)) """ - # Available input data types + # Available fields, input and equation data types __fields__ = ["input", "equation"] _avail_input_cls = (LabelTensor, Graph) _avail_equation_cls = BaseEquation def __new__(cls, input, equation): """ - Check the types of ``input`` and ``equation`` and instantiate a class - of :class:`InputEquationCondition` accordingly. - - :param input: The input data for the condition. - :type input: LabelTensor | Graph | list[Graph] | tuple[Graph] - :param BaseEquation equation: The equation to be satisfied over the - specified ``input`` data. 
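The graph batching above delegates to torch_geometric, which concatenates node features across graphs and records a node-to-graph assignment vector. A short sketch of that behavior, assuming torch_geometric is installed; the weight attribute mimics what _prepare_graphs attaches per graph:

    import torch
    from torch_geometric.data import Data, Batch

    # Two tiny graphs; edge_index has shape [2, num_edges].
    graphs = [
        Data(x=torch.randn(3, 2),
             edge_index=torch.tensor([[0, 1], [1, 2]])),
        Data(x=torch.randn(4, 2),
             edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]])),
    ]
    # Attach one per-graph tensor attribute, as the manager does.
    for i, g in enumerate(graphs):
        g.weight = torch.tensor([float(i)])

    batch = Batch.from_data_list(graphs)
    # Node features are concatenated; custom attributes too.
    assert batch.x.shape == (7, 2)
    assert batch.weight.shape == (2,)
    assert batch.num_graphs == 2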
- :return: The subclass of InputEquationCondition. - :rtype: pina.condition.input_equation_condition. - InputTensorEquationCondition | - pina.condition.input_equation_condition.InputGraphEquationCondition - - :raises ValueError: If input is not of type :class:`~pina.graph.Graph` - or :class:`~pina.label_tensor.LabelTensor`. + Check the types of ``input`` and ``equation`` and instantiate an + instance of :class:`InputEquationCondition` accordingly. + + :param input: The input data associated with the condition. + :type input: LabelTensor | Graph | list[Graph] | tuple[Graph] + :param BaseEquation equation: The equation associated with the + condition. + :raises ValueError: If ``input`` is not an instance of + :class:`~pina.label_tensor.LabelTensor`, or + :class:`~pina.graph.Graph`, nor a list or tuple of + :class:`~pina.graph.Graph`. + :raises ValueError: If ``equation`` is not an instance of + :class:`~pina.equation.base_equation.BaseEquation`. + :return: A new instance of :class:`InputEquationCondition`. + :rtype: InputEquationCondition """ - - # CHeck input type - if not isinstance(input, cls._avail_input_cls): - raise ValueError( - "The input data object must be a LabelTensor or a Graph object." - ) - - # Check equation type - if not isinstance(equation, cls._avail_equation_cls): - raise ValueError( - "The equation must be an instance of BaseEquation." - ) + # Check input type - equation is checked in the setter + if isinstance(input, (list, tuple)): + check_consistency(input, Graph) + else: + check_consistency(input, cls._avail_input_cls) return super().__new__(cls) def store_data(self, **kwargs): """ - Store the input data in a :class:`_DataManager` object. - :param dict kwargs: The keyword arguments containing the input data. + Store the input data in a dictionary-like structure. + + :param dict kwargs: The keyword arguments containing the data to be + stored. + :return: A dictionary-like structure containing the stored data. + :rtype: _DataManager """ + # Save the equation as an attribute of the condition instance setattr(self, "equation", kwargs.pop("equation")) + return _DataManager(**kwargs) @property def input(self): """ - Return the input data for the condition. + The input data associated with the condition. :return: The input data. :rtype: LabelTensor | Graph | list[Graph] | tuple[Graph] @@ -87,9 +88,9 @@ def input(self): @property def equation(self): """ - Return the equation associated with this condition. + The equation associated with the condition. - :return: Equation associated with this condition. + :return: The equation. :rtype: BaseEquation """ return self._equation @@ -99,9 +100,10 @@ def equation(self, value): """ Set the equation associated with this condition. - :param BaseEquation value: The equation to associate with this - condition + :param BaseEquation value: The equation to associate with the condition. + :raises ValueError: If ``value`` is not an instance of + :class:`~pina.equation.base_equation.BaseEquation`. """ - if not isinstance(value, BaseEquation): - raise TypeError("The equation must be an instance of BaseEquation.") + # Check consistency + check_consistency(value, self._avail_equation_cls) self._equation = value diff --git a/pina/_src/condition/input_target_condition.py b/pina/_src/condition/input_target_condition.py index dd81cd252..4b641e528 100644 --- a/pina/_src/condition/input_target_condition.py +++ b/pina/_src/condition/input_target_condition.py @@ -1,16 +1,15 @@ -""" -This module contains condition classes for supervised learning tasks. 
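A short usage sketch of the condition handled by this file, mirroring the tests later in the patch (the lambda equation is illustrative): the equation is stored as an attribute by store_data, while the input is served by the data manager.

import torch
from pina import Condition, LabelTensor
from pina.equation import Equation

pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
eq = Equation(lambda pts: pts["x"] ** 2 + pts["y"] ** 2 - 1)

condition = Condition(input=pts, equation=eq)
assert condition.equation is eq  # popped and stored by store_data
assert condition.input.labels == ["x", "y"]  # served by the data manager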
-""" +"""Module for the Input-Target Condition class.""" import torch from torch_geometric.data import Data from pina._src.core.label_tensor import LabelTensor from pina._src.core.graph import Graph -from pina._src.condition.condition_base import ConditionBase +from pina._src.condition.base_condition import BaseCondition from pina._src.condition.data_manager import _DataManager +from pina._src.core.utils import check_consistency -class InputTargetCondition(ConditionBase): +class InputTargetCondition(BaseCondition): """ The :class:`InputTargetCondition` class represents a supervised condition defined by both ``input`` and ``target`` data. The model is trained to @@ -32,69 +31,62 @@ class InputTargetCondition(ConditionBase): >>> condition = Condition(input=input, target=graph) """ - # Available input and target data types + # Available fields, input, and target data types __fields__ = ["input", "target"] - _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) - _avail_output_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) + _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph) + _avail_target_cls = (torch.Tensor, LabelTensor, Data, Graph) def __new__(cls, input, target): """ - Check the types of ``input`` and ``target`` data and instantiate the - :class:`InputTargetCondition`. + Check the types of ``input`` and ``target`` data and instantiate an + instance of :class:`InputTargetCondition` accordingly. - :param input: The input data for the condition. + :param input: The input data associated with the condition. :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] - :param target: The target data for the condition. + :param target: The target data associated with the condition. :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] - :return: An instance of :class:`InputTargetCondition`. - :rtype: pina.condition.input_target_condition.InputTargetCondition - :raises ValueError: If ``input`` or ``target`` are not of supported types. + :raises ValueError: If ``input`` is not of type :class:`torch.Tensor`, + :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, + or :class:`~torch_geometric.data.Data`, nor is it a list or tuple of + :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`. + :raises ValueError: If ``target`` is not of type :class:`torch.Tensor`, + :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, + or :class:`~torch_geometric.data.Data`, nor is it a list or tuple of + :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`. + :return: A new instance of :class:`InputTargetCondition`. + :rtype: InputTargetCondition """ - - if not isinstance(input, cls._avail_input_cls): - raise ValueError( - "Invalid input type. Expected one of the following: " - "torch.Tensor, LabelTensor, Graph, Data or " - "list/tuple of Graph/Data objects." - ) + # Check input type - if iterable, ensure it is either Data or Graph if isinstance(input, (list, tuple)): - for item in input: - if not isinstance(item, (Graph, Data)): - raise ValueError( - "If target is a list or tuple, all its elements " - "must be of type Graph or Data." - ) - - if not isinstance(target, cls._avail_output_cls): - raise ValueError( - "Invalid target type. Expected one of the following: " - "torch.Tensor, LabelTensor, Graph, Data or " - "list/tuple of Graph/Data objects." 
- )
+ check_consistency(input, (Data, Graph))
+ else:
+ check_consistency(input, cls._avail_input_cls)
+
+ # Check target type - if iterable, ensure it is either Data or Graph
 if isinstance(target, (list, tuple)):
- for item in target:
- if not isinstance(item, (Graph, Data)):
- raise ValueError(
- "If target is a list or tuple, all its elements "
- "must be of type Graph or Data."
- )
+ check_consistency(target, (Data, Graph))
+ else:
+ check_consistency(target, cls._avail_target_cls)

 return super().__new__(cls)

 def store_data(self, **kwargs):
 """
- Store the input and target data in a :class:`_DataManager` object.
- :param dict kwargs: The keyword arguments containing the input and
- target data.
+ Store the input and target data in a dictionary-like structure.
+
+ :param dict kwargs: The keyword arguments containing the data to be
+ stored.
+ :return: A dictionary-like structure containing the stored data.
+ :rtype: _DataManager
 """
 return _DataManager(**kwargs)

 @property
 def input(self):
 """
- Return the input data for the condition.
+ The input data associated with the condition.

 :return: The input data.
 :rtype: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
@@ -105,7 +97,7 @@ def input(self):
 @property
 def target(self):
 """
- Return the target data for the condition.
+ The target data associated with the condition.

 :return: The target data.
 :rtype: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
diff --git a/pina/_src/condition/tensor_data_manager.py b/pina/_src/condition/tensor_data_manager.py
new file mode 100644
index 000000000..a1ec0b023
--- /dev/null
+++ b/pina/_src/condition/tensor_data_manager.py
@@ -0,0 +1,110 @@
+"""Module for the Tensor-Data Manager class."""
+
+import torch
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.condition.batch_manager import _BatchManager
+from pina._src.condition.data_manager_interface import _DataManagerInterface
+
+
+class _TensorDataManager(_DataManagerInterface):
+ """
+ Data manager for tensor-based data. It handles inputs stored as
+ :class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor`.
+ """
+
+ def __init__(self, **kwargs):
+ """
+ Initialization of the :class:`_TensorDataManager` class.
+
+ :param dict kwargs: The keyword arguments for the tensor data manager.
+ """
+ self.keys = list(kwargs.keys())
+ self.data = kwargs
+
+ # Set attributes from kwargs
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+ def __len__(self):
+ """
+ Return the number of samples in the tensor data manager.
+
+ :return: The number of samples.
+ :rtype: int
+ """
+ return self.data[self.keys[0]].shape[0]
+
+ def __getitem__(self, idx):
+ """
+ Return the item at the specified indices.
+
+ :param idx: The indices of the data point to retrieve.
+ :type idx: int | slice | list[int] | torch.Tensor
+ :return: A new :class:`_TensorDataManager` instance containing the
+ selected data items.
+ :rtype: _TensorDataManager
+ """
+ # Get data at selected indices
+ new_data = {k: self.data[k][idx] for k in self.keys}
+
+ return _TensorDataManager(**new_data)
+
+ def to_batch(self):
+ """
+ Create a batch from the current tensor data manager.
+
+ :return: A new instance of :class:`_BatchManager` with batched data.
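The slicing rule in __getitem__ above can be stated compactly: every stored tensor is indexed with the same idx, so a subset keeps all keys aligned. A minimal sketch of that pattern with plain dictionaries and tensors (not part of the patch):

import torch

data = {"input": torch.rand(10, 3), "target": torch.rand(10, 2)}
idx = [0, 1, 3]

# Select the same rows from every stored tensor, as __getitem__ does
subset = {k: v[idx] for k, v in data.items()}
assert subset["input"].shape == (3, 3)
assert subset["target"].shape == (3, 2)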
+ :rtype: _BatchManager + """ + # Create the batch manager + batch_data = _BatchManager() + for k in self.keys: + batch_data[k] = self.data[k] + + return batch_data + + @staticmethod + def create_batch(items): + """ + Create a batch from a list of :class:`_TensorDataManager` items. + + :param list[_TensorDataManager] items: A list of + :class:`_TensorDataManager` items to batch. + :return: A new instance of :class:`_BatchManager` containing the batched + data. + :rtype: _BatchManager + """ + # Return None if no items are provided + if not items: + return None + + # Retrieve the first _TensorDataManager of the list + first = items[0] + + # Initialize the batch manager + batch_data = _BatchManager() + + # Iterate over the keys of the _TensorDataManager + for k in first.keys: + + # Extract values and a sample used to determine the batch function + vals = [it.data[k] for it in items] + sample = vals[0] + + # Define the batch function based on the data type + if isinstance(sample, (torch.Tensor, LabelTensor)): + batch_fn = ( + LabelTensor.stack + if isinstance(sample, LabelTensor) + else torch.stack + ) + batch_data[k] = batch_fn(vals) + + # If no tensor is provided, just take the first value + else: + batch_data[k] = sample + + return batch_data diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 0cdf7a977..64b72901f 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -7,17 +7,22 @@ """ __all__ = [ - "Condition", "ConditionInterface", - "ConditionBase", + "BaseCondition", + "Condition", "DomainEquationCondition", "InputTargetCondition", "InputEquationCondition", "DataCondition", + "_DataManagerInterface", + "_DataManager", + "_GraphDataManager", + "_TensorDataManager", + "_BatchManager", ] from pina._src.condition.condition_interface import ConditionInterface -from pina._src.condition.condition_base import ConditionBase +from pina._src.condition.base_condition import BaseCondition from pina._src.condition.condition import Condition from pina._src.condition.domain_equation_condition import ( DomainEquationCondition, @@ -25,3 +30,8 @@ from pina._src.condition.input_target_condition import InputTargetCondition from pina._src.condition.input_equation_condition import InputEquationCondition from pina._src.condition.data_condition import DataCondition +from pina._src.condition.batch_manager import _BatchManager +from pina._src.condition.data_manager_interface import _DataManagerInterface +from pina._src.condition.data_manager import _DataManager +from pina._src.condition.tensor_data_manager import _TensorDataManager +from pina._src.condition.graph_data_manager import _GraphDataManager From 8ba9c3c0134fcd8db522124de8d2ca55f4c27ada Mon Sep 17 00:00:00 2001 From: GiovanniCanali Date: Thu, 23 Apr 2026 00:15:56 +0200 Subject: [PATCH 2/3] enhance tests for condition module --- tests/test_condition/test_data_condition.py | 565 +++++++------- .../test_domain_equation_condition.py | 58 +- .../test_input_equation_condition.py | 207 ++++-- .../test_input_target_condition.py | 698 +++++++++--------- 4 files changed, 787 insertions(+), 741 deletions(-) diff --git a/tests/test_condition/test_data_condition.py b/tests/test_condition/test_data_condition.py index 4a88f963c..5aa6abaae 100644 --- a/tests/test_condition/test_data_condition.py +++ b/tests/test_condition/test_data_condition.py @@ -1,332 +1,301 @@ -import pytest import torch -from pina import Condition, LabelTensor -from pina.condition import DataCondition -from pina.graph import RadiusGraph -from 
torch_geometric.data import Data -from pina._src.condition.data_manager import _DataManager +import pytest +from pina.graph import RadiusGraph, Graph +from pina import LabelTensor, Condition +from pina.condition import ( + DataCondition, + _BatchManager, + _TensorDataManager, + _GraphDataManager, +) -def _create_tensor_data(use_lt=False, conditional_variables=False): - input_tensor = torch.rand((10, 3)) +# Helper function to create tensor data +def _create_tensor_data(use_lt, conditional_variables): + + # If LabelTensor is used, create tensors with labels if use_lt: - input_tensor = LabelTensor(input_tensor, ["x", "y", "z"]) - if conditional_variables: - cond_vars = torch.rand((10, 2)) - if use_lt: - cond_vars = LabelTensor(cond_vars, ["a", "b"]) - else: - cond_vars = None + input_tensor = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) + cond_vars = LabelTensor(torch.rand((10, 2)), ["a", "b"]) + cond_vars = cond_vars if conditional_variables else None + + return input_tensor, cond_vars + + # Standard torch.Tensor without labels + input_tensor = torch.rand((10, 3)) + cond_vars = torch.rand((10, 2)) + cond_vars = cond_vars if conditional_variables else None + return input_tensor, cond_vars -def _create_graph_data(use_lt=False, conditional_variables=False): +# Helper function to create graph data +def _create_graph_data(use_lt, conditional_variables): + + # If LabelTensor is used, create graph data with LabelTensors if use_lt: x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) + cond_vars = LabelTensor(torch.rand(10, 20, 1), ["f"]) + + # Standard torch.Tensor without labels else: x = torch.rand(10, 20, 2) pos = torch.rand(10, 20, 2) - radius = 0.1 - input_graph = [ - RadiusGraph(pos=pos[i], radius=radius, x=x[i]) for i in range(len(x)) - ] - if conditional_variables: - if use_lt: - cond_vars = LabelTensor(torch.rand(10, 20, 1), ["f"]) - else: - cond_vars = torch.rand(10, 20, 1) - else: - cond_vars = None - return input_graph, cond_vars + cond_vars = torch.rand(10, 20, 1) + # Create a list of Graphs + graph = [RadiusGraph(pos=pos[i], radius=0.1, x=x[i]) for i in range(len(x))] -@pytest.mark.parametrize("conditional_variables", [False, True]) -def test_init_tensor_data_condition_tensor(conditional_variables): - # Setup for standard torch.Tensor - input_tensor, cond_vars = _create_tensor_data( - use_lt=False, conditional_variables=conditional_variables - ) - condition = Condition(input=input_tensor, conditional_variables=cond_vars) - - assert isinstance(condition, DataCondition) - - # Input assertions - assert isinstance(condition.input, torch.Tensor) - assert not isinstance(condition.input, LabelTensor) - - # Conditional variables assertions - if conditional_variables: - assert condition.conditional_variables is not None - assert isinstance(condition.conditional_variables, torch.Tensor) - assert not isinstance(condition.conditional_variables, LabelTensor) - else: - assert condition.conditional_variables is None + # Create conditional variables if needed + cond_vars = cond_vars if conditional_variables else None + return graph, cond_vars -@pytest.mark.parametrize("conditional_variables", [False, True]) -def test_init_tensor_data_condition_label_tensor(conditional_variables): - # Setup for LabelTensor - input_tensor, cond_vars = _create_tensor_data( - use_lt=True, conditional_variables=conditional_variables - ) - condition = Condition(input=input_tensor, conditional_variables=cond_vars) - - assert isinstance(condition, DataCondition) - - # 
Input assertions with label validation - assert isinstance(condition.input, LabelTensor) - assert condition.input.labels == ["x", "y", "z"] - - # Conditional variables assertions with label validation - if conditional_variables: - assert isinstance(condition.conditional_variables, LabelTensor) - assert condition.conditional_variables.labels == ["a", "b"] + +# Helper function to check tensor types +def _assert_tensor_type(t, use_lt): + if use_lt: + assert isinstance(t, LabelTensor) else: - assert condition.conditional_variables is None + assert isinstance(t, torch.Tensor) and not isinstance(t, LabelTensor) -@pytest.mark.parametrize("conditional_variables", [False, True]) -def test_init_graph_data_condition_tensor(conditional_variables): - # Setup for standard torch.Tensor - input_graph, cond_vars = _create_graph_data( - use_lt=False, conditional_variables=conditional_variables - ) - condition = Condition(input=input_graph, conditional_variables=cond_vars) - - assert isinstance(condition, DataCondition) - - # Validate Input list - assert isinstance(condition.input, list) - for graph in condition.input: - assert isinstance(graph, Data) - assert isinstance(graph.x, torch.Tensor) - assert not isinstance(graph.x, LabelTensor) - assert isinstance(graph.pos, torch.Tensor) - - # Validate Conditional Variables - if conditional_variables: - assert isinstance(condition.conditional_variables, torch.Tensor) - assert not isinstance(condition.conditional_variables, LabelTensor) - else: - assert condition.conditional_variables is None +# Helper function to check input graph +def _assert_graph_type(graph_list, use_lt): + assert isinstance(graph_list, list) + for graph in graph_list: + _assert_tensor_type(graph.x, use_lt) +@pytest.mark.parametrize("use_lt", [True, False]) @pytest.mark.parametrize("conditional_variables", [False, True]) -def test_init_graph_data_condition_label_tensor(conditional_variables): - # Setup for LabelTensor - input_graph, cond_vars = _create_graph_data( - use_lt=True, conditional_variables=conditional_variables - ) - condition = Condition(input=input_graph, conditional_variables=cond_vars) - - assert isinstance(condition, DataCondition) - - # Validate Input list and Labels - for graph in condition.input: - assert isinstance(graph.x, LabelTensor) - assert graph.x.labels == ["u", "v"] - - assert isinstance(graph.pos, LabelTensor) - assert graph.pos.labels == ["x", "y"] - - # Validate Conditional Variables and Labels - if conditional_variables: - assert isinstance(condition.conditional_variables, LabelTensor) - assert condition.conditional_variables.labels == ["f"] - else: - assert condition.conditional_variables is None +@pytest.mark.parametrize("case", ["tensor", "graph"]) +def test_constructor(case, use_lt, conditional_variables): + + # Tensor input case + if case == "tensor": + + # Define the condition + input_tensor, cond_vars = _create_tensor_data( + use_lt, conditional_variables + ) + condition = Condition( + input=input_tensor, conditional_variables=cond_vars + ) + + # Assert correct types + assert isinstance(condition, DataCondition) + _assert_tensor_type(condition.input, use_lt) + if cond_vars is not None: + _assert_tensor_type(condition.conditional_variables, use_lt) + + # Assert numerical parity + assert torch.allclose(condition.input, input_tensor) + if cond_vars is not None: + assert torch.allclose(condition.conditional_variables, cond_vars) + + # Assert labels if LabelTensor is used + if use_lt: + assert condition.input.labels == ["x", "y", "z"] + if cond_vars is not 
None: + assert condition.conditional_variables.labels == ["a", "b"] + + # Graph input case + elif case == "graph": + + # Define the condition + input_graph, cond_vars = _create_graph_data( + use_lt, conditional_variables + ) + condition = Condition( + input=input_graph, conditional_variables=cond_vars + ) + + # Assert correct types + assert isinstance(condition, DataCondition) + _assert_graph_type(condition.input, use_lt) + if cond_vars is not None: + _assert_tensor_type(condition.conditional_variables, use_lt) + + # Assert numerical parity for graph inputs + for i in range(len(input_graph)): + assert torch.allclose(condition.input[i].x, input_graph[i].x) + assert torch.allclose(condition.input[i].pos, input_graph[i].pos) + + # Assert numerical parity for conditional variables + if cond_vars is not None: + assert torch.allclose(condition.conditional_variables, cond_vars) + + # Assert labels if LabelTensor is used + if use_lt: + for graph in condition.input: + assert graph.x.labels == ["u", "v"] + assert graph.pos.labels == ["x", "y"] + if cond_vars is not None: + assert condition.conditional_variables.labels == ["f"] + # Prepare for invalid input tests + input_ = input_tensor if case == "tensor" else input_graph -def test_wrong_init_data_condition(): - input_tensor, cond_vars = _create_tensor_data() - # Wrong input type + # Should fail if the input is neither a tensor nor a graph with pytest.raises(ValueError): Condition(input="invalid_input", conditional_variables=cond_vars) - # Wrong conditional_variables type - with pytest.raises(ValueError): - Condition(input=input_tensor, conditional_variables="invalid_cond_vars") - # Wrong input type (list with wrong elements) - with pytest.raises(ValueError): - Condition(input=[input_tensor], conditional_variables=cond_vars) - # Wrong conditional_variables type (list) - with pytest.raises(ValueError): - Condition(input=input_tensor, conditional_variables=[cond_vars]) - - -@pytest.mark.parametrize("conditional_variables", [False, True]) -def test_getitem_tensor_data_condition_tensor(conditional_variables): - # Setup for standard torch.Tensor - input_tensor, cond_vars = _create_tensor_data( - use_lt=False, conditional_variables=conditional_variables - ) - condition = Condition(input=input_tensor, conditional_variables=cond_vars) - - item = condition[0] - - # Input assertions - assert isinstance(item.input, torch.Tensor) - assert not isinstance(item.input, LabelTensor) - assert item.input.shape == (3,) - - # Conditional variables assertions - if conditional_variables: - assert isinstance(item.conditional_variables, torch.Tensor) - assert item.conditional_variables.shape == (2,) - else: - assert not hasattr(item, "conditional_variables") - - -@pytest.mark.parametrize("conditional_variables", [False, True]) -def test_getitem_tensor_data_condition_label_tensor(conditional_variables): - # Setup for LabelTensor - input_tensor, cond_vars = _create_tensor_data( - use_lt=True, conditional_variables=conditional_variables - ) - condition = Condition(input=input_tensor, conditional_variables=cond_vars) - - item = condition[0] - - # Input assertions with label validation - assert isinstance(item.input, LabelTensor) - assert item.input.shape == (3,) - assert item.input.labels == ["x", "y", "z"] - - # Conditional variables assertions with label validation - if conditional_variables: - assert isinstance(item.conditional_variables, LabelTensor) - assert item.conditional_variables.shape == (2,) - assert item.conditional_variables.labels == ["a", "b"] - else: - assert 
not hasattr(item, "conditional_variables") - - -@pytest.mark.parametrize("conditional_variables", [False, True]) -def test_getitem_graph_data_condition_tensor(conditional_variables): - # Setup specifically for standard torch.Tensor - input_graph, cond_vars = _create_graph_data( - use_lt=False, conditional_variables=conditional_variables - ) - condition = Condition(input=input_graph, conditional_variables=cond_vars) - - item = condition[0] - - # Assertions for the graph data - assert isinstance(item.input, Data) - assert isinstance(item.input.x, torch.Tensor) - assert not isinstance(item.input.x, LabelTensor) - assert item.input.x.shape == (20, 2) - - # Assertions for conditional variables - if conditional_variables: - assert isinstance(item.conditional_variables, torch.Tensor) - assert item.conditional_variables.shape == (1, 20, 1) - - -@pytest.mark.parametrize("conditional_variables", [False, True]) -def test_getitem_graph_data_condition_label_tensor(conditional_variables): - # Setup specifically for LabelTensor - input_graph, cond_vars = _create_graph_data( - use_lt=True, conditional_variables=conditional_variables - ) - condition = Condition(input=input_graph, conditional_variables=cond_vars) - item = condition[0] - graph = item.input - - # Assertions for LabelTensor attributes - assert isinstance(graph.x, LabelTensor) - assert graph.x.labels == ["u", "v"] - assert graph.x.shape == (20, 2) - - assert isinstance(graph.pos, LabelTensor) - assert graph.pos.labels == ["x", "y"] - - # Assertions for labeled conditional variables - if conditional_variables: - cond_var = item.conditional_variables - assert isinstance(cond_var, LabelTensor) - assert cond_var.labels == ["f"] - assert cond_var.shape == (1, 20, 1) + # Should fail if the conditional_variables is neither a tensor nor a graph + with pytest.raises(ValueError): + Condition(input=input_, conditional_variables="invalid_cond_vars") + # Should fail if the input is a list of tensors + if case == "tensor": + with pytest.raises(ValueError): + Condition(input=[input_], conditional_variables=cond_vars) -@pytest.mark.parametrize("use_lt", [False, True]) -@pytest.mark.parametrize("conditional_variables", [False, True]) -def test_getitems_tensor_data_condition(use_lt, conditional_variables): - input_tensor, cond_vars = _create_tensor_data( - use_lt=use_lt, conditional_variables=conditional_variables - ) - condition = Condition(input=input_tensor, conditional_variables=cond_vars) - idxs = [0, 1, 3] - items = condition[idxs] - assert isinstance(items, _DataManager) - assert hasattr(items, "input") - type_ = LabelTensor if use_lt else torch.Tensor - inputs = items.input - assert isinstance(inputs, type_) - assert inputs.shape == (3, 3) - if use_lt: - assert inputs.labels == ["x", "y", "z"] - if conditional_variables: - assert hasattr(items, "conditional_variables") - cond_vars_items = items.conditional_variables - assert isinstance(cond_vars_items, type_) - assert cond_vars_items.shape == (3, 2) - if use_lt: - assert cond_vars_items.labels == ["a", "b"] - else: - assert not hasattr(items, "conditional_variables") + # Should fail if the conditional_variables is a list of tensors + if case == "tensor": + with pytest.raises(ValueError): + Condition(input=input_, conditional_variables=[cond_vars]) +@pytest.mark.parametrize("use_lt", [True, False]) @pytest.mark.parametrize("conditional_variables", [False, True]) -def test_getitems_graph_data_condition_tensor(conditional_variables): - # Setup with use_lt=False - input_graph, cond_vars = _create_graph_data( 
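The indexing behaviour exercised by these tests follows one pattern: indexing a condition with a list of indices yields a data manager whose attributes expose the sliced fields. A short sketch of the expected behaviour, with shapes matching the helpers above:

import torch
from pina import Condition

condition = Condition(
    input=torch.rand(10, 3), conditional_variables=torch.rand(10, 2)
)

items = condition[[0, 1, 3]]
assert items.input.shape == (3, 3)
assert items.conditional_variables.shape == (3, 2)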
- use_lt=False, conditional_variables=conditional_variables - ) - condition = Condition(input=input_graph, conditional_variables=cond_vars) - - idxs = [0, 1, 3] - items = condition[idxs] - - # Assertions for DataManager and Graphs - assert isinstance(items, _DataManager) - graphs = items.input - assert len(graphs) == 3 - - for graph in graphs: - assert isinstance(graph.x, torch.Tensor) - assert not isinstance(graph.x, LabelTensor) - assert graph.x.shape == (20, 2) - - # Assertions for Conditional Variables - if conditional_variables: - assert isinstance(items.conditional_variables, torch.Tensor) - assert items.conditional_variables.shape == (3, 20, 1) - - +@pytest.mark.parametrize("case", ["tensor", "graph"]) +def test_get_item(case, use_lt, conditional_variables): + + # Tensor input case + if case == "tensor": + + # Define the condition + input_tensor, cond_vars = _create_tensor_data( + use_lt, conditional_variables + ) + condition = Condition( + input=input_tensor, conditional_variables=cond_vars + ) + + # Extract item using __getitem__ + index = 0 + item = condition[index] + + # Assert correct types + assert isinstance(item, _TensorDataManager) + _assert_tensor_type(item.input, use_lt) + if cond_vars is not None: + _assert_tensor_type(item.conditional_variables, use_lt) + + # Assert numerical parity + assert torch.allclose(item.input, input_tensor[index]) + if cond_vars is not None: + assert torch.allclose(item.conditional_variables, cond_vars[index]) + + # Graph input case + elif case == "graph": + + # Define the condition + input_graph, cond_vars = _create_graph_data( + use_lt, conditional_variables + ) + condition = Condition( + input=input_graph, conditional_variables=cond_vars + ) + + # Extract item using __getitem__ + index = 0 + item = condition[index] + + # Assert correct types + assert isinstance(item, _GraphDataManager) + assert isinstance(item.input, Graph) + _assert_tensor_type(item.input.x, use_lt) + if cond_vars is not None: + _assert_tensor_type(item.conditional_variables, use_lt) + + # Assert numerical parity + assert torch.allclose(item.input.x, input_graph[index].x) + assert torch.allclose(item.input.pos, input_graph[index].pos) + if cond_vars is not None: + assert torch.allclose(item.conditional_variables, cond_vars[index]) + + +@pytest.mark.parametrize("use_lt", [True, False]) @pytest.mark.parametrize("conditional_variables", [False, True]) -def test_getitems_graph_data_condition_label_tensor(conditional_variables): - # Setup with use_lt=True - input_graph, cond_vars = _create_graph_data( - use_lt=True, conditional_variables=conditional_variables - ) - condition = Condition(input=input_graph, conditional_variables=cond_vars) - - idxs = [0, 1, 3] - items = condition[idxs] - - # Assertions for LabelTensor specific attributes in Graphs - for graph in items.input: - assert isinstance(graph.x, LabelTensor) - assert graph.x.labels == ["u", "v"] - - assert isinstance(graph.pos, LabelTensor) - assert graph.pos.labels == ["x", "y"] - - # Assertions for LabelTensor in Conditional Variables - if conditional_variables: - cv = items.conditional_variables - assert isinstance(cv, LabelTensor) - assert cv.labels == ["f"] - assert cv.shape == (3, 20, 1) +@pytest.mark.parametrize("case", ["tensor", "graph"]) +def test_create_batch(case, use_lt, conditional_variables): + + # Tensor case + if case == "tensor": + input_, cond_vars = _create_tensor_data(use_lt, conditional_variables) + + # Graph case + elif case == "graph": + input_, cond_vars = _create_graph_data(use_lt, 
conditional_variables) + + # Define the condition + condition = Condition(input=input_, conditional_variables=cond_vars) + + # Create batches using automatic batching or condition's collate_fn + idx = [0, 2] + data_to_collate = [condition.data[i] for i in idx] + batch_auto = condition.automatic_batching_collate_fn(data_to_collate) + batch_collate = condition.collate_fn(idx, condition) + + # Check that the automatic batch has been properly created + assert isinstance(batch_auto, _BatchManager) + assert hasattr(batch_auto, "input") + if cond_vars is not None: + assert hasattr(batch_auto, "conditional_variables") + + # Check that the collate_fn batch has been properly created + assert isinstance(batch_collate, dict) + assert hasattr(batch_collate, "input") + if cond_vars is not None: + assert hasattr(batch_collate, "conditional_variables") + + # Retrieve tensor class for expected batch creation + cls = LabelTensor if use_lt else torch + + # Validate batch contents for tensor case + if case == "tensor": + + # Create expected input batch + expected_input = cls.stack([input_[i] for i in idx]) + if cond_vars is not None: + exp_cond = cls.stack([cond_vars[i] for i in idx]) + + # Assert that the automatic batch input is correct + assert torch.allclose(batch_auto.input, expected_input) + assert batch_auto.input.shape == expected_input.shape + if cond_vars is not None: + assert torch.allclose(batch_auto.conditional_variables, exp_cond) + assert batch_auto.conditional_variables.shape == exp_cond.shape + + # Assert that the collate_fn batch input is correct + assert torch.allclose(batch_collate.input, expected_input) + assert batch_collate.input.shape == expected_input.shape + if cond_vars is not None: + assert torch.allclose(batch_collate.conditional_variables, exp_cond) + assert batch_collate.conditional_variables.shape == exp_cond.shape + + # Validate batch contents for graph case + elif case == "graph": + + # Create expected input batch + expected_input = [condition.data[i].input for i in idx] + if cond_vars is not None: + exp_cond = cls.cat([cond_vars[i] for i in idx]) + + # Assert that the automatic batch input is correct + for i, graph in enumerate(expected_input): + assert torch.allclose(batch_auto.input[i].x, graph.x) + assert batch_auto.input.num_graphs == len(idx) + if cond_vars is not None: + assert torch.allclose(batch_auto.conditional_variables, exp_cond) + assert batch_auto.conditional_variables.shape == exp_cond.shape + + # Assert that the collate_fn batch input is correct + for i, graph in enumerate(expected_input): + assert torch.allclose(batch_collate.input[i].x, graph.x) + assert batch_collate.input.num_graphs == len(idx) + if cond_vars is not None: + assert torch.allclose(batch_collate.conditional_variables, exp_cond) + assert batch_collate.conditional_variables.shape == exp_cond.shape diff --git a/tests/test_condition/test_domain_equation_condition.py b/tests/test_condition/test_domain_equation_condition.py index d2afbceae..760737454 100644 --- a/tests/test_condition/test_domain_equation_condition.py +++ b/tests/test_condition/test_domain_equation_condition.py @@ -4,26 +4,52 @@ from pina.equation.zoo import FixedValue from pina.condition import DomainEquationCondition -example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) -example_equation = FixedValue(0.0) +# Define a simple domain and equation for testing +domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) +equation = FixedValue(0.0) -def test_init_domain_equation(): - cond = Condition(domain=example_domain, 
equation=example_equation) - assert isinstance(cond, DomainEquationCondition) - assert cond.domain is example_domain - assert cond.equation is example_equation - assert hasattr(cond, "data") - assert cond.data is None +def test_constructor(): -def test_len_not_implemented(): - cond = Condition(domain=example_domain, equation=FixedValue(0.0)) - with pytest.raises(NotImplementedError): - len(cond) + # Define the condition + condition = Condition(domain=domain, equation=equation) + + # Assert correct types + assert isinstance(condition, DomainEquationCondition) + + # Assert that the domain and equation are stored correctly + assert condition.domain is domain + assert condition.equation is equation + + # Assert that the data attribute is set to None + assert hasattr(condition, "data") + assert condition.data is None + + # Should fail if domain is not an instance of DomainInterface or a string + with pytest.raises(ValueError): + Condition(domain=123, equation=equation) + + # Should fail if equation is not an instance of BaseEquation + with pytest.raises(ValueError): + Condition(domain=domain, equation=123) -def test_getitem_not_implemented(): - cond = Condition(domain=example_domain, equation=FixedValue(0.0)) +def test_get_item(): + + # Define the condition + condition = Condition(domain=domain, equation=equation) + + # Should raise NotImplementedError when trying to access by index with pytest.raises(NotImplementedError): - cond[0] + condition[0] + + +def test_create_batch(): + + # Define the condition + condition = Condition(domain=domain, equation=equation) + + # Should raise TypeError when trying to access condition.data since None + with pytest.raises(TypeError): + _ = [condition.data[i] for i in [0, 2, 4, 6]] diff --git a/tests/test_condition/test_input_equation_condition.py b/tests/test_condition/test_input_equation_condition.py index 4bed448b5..a6366e86a 100644 --- a/tests/test_condition/test_input_equation_condition.py +++ b/tests/test_condition/test_input_equation_condition.py @@ -1,79 +1,160 @@ import torch import pytest -from pina import Condition -from pina._src.condition.input_equation_condition import InputEquationCondition from pina.equation import Equation -from pina import LabelTensor -from pina.graph import Graph -from pina._src.condition.data_manager import _DataManager +from pina import LabelTensor, Condition +from pina.graph import RadiusGraph, Graph +from pina.condition import ( + InputEquationCondition, + _TensorDataManager, + _GraphDataManager, + _BatchManager, +) + +# Generate input and equation data for testing - tensor case +input_tensor = LabelTensor(torch.rand((10, 2)), ["x", "y"]) +equation_tensor = Equation(lambda pts: pts["x"] ** 2 + pts["y"] ** 2 - 1) + +# Generate input and equation data for testing - graph case +input_graph_list = [ + RadiusGraph( + x=LabelTensor(torch.rand(10, 2), labels=["u", "v"]), + pos=LabelTensor(torch.rand(10, 2), labels=["x", "y"]), + radius=0.1, + edge_attr=True, + ) + for _ in range(3) +] +equation_graph = Equation(lambda pts: pts.x["u"] ** 2 + pts.x["v"] ** 2 - 1) + + +@pytest.mark.parametrize("case", ["tensor", "graph"]) +def test_constructor(case): + + # Tensor case + if case == "tensor": + input_, equation = input_tensor, equation_tensor + + # Graph case + elif case == "graph": + input_, equation = input_graph_list, equation_graph + + # Define the condition + condition = Condition(input=input_, equation=equation) + + # Assert correct types + assert isinstance(condition, InputEquationCondition) + # Assert that the equation is 
stored correctly + assert condition.equation is equation -def _create_pts_and_equation(): - def dummy_equation(pts): - return pts["x"] ** 2 + pts["y"] ** 2 - 1 + # Assert correct input type + if case == "tensor": + assert isinstance(condition.input, LabelTensor) + elif case == "graph": + assert isinstance(condition.input, list) + for graph in condition.input: + assert isinstance(graph, Graph) - pts = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) - equation = Equation(dummy_equation) - return pts, equation + # Should fail if input is not an instance of LabelTensor or Graph + with pytest.raises(ValueError): + Condition(input=torch.rand(10, 2), equation=equation) + # Should fail if equation is not an instance of BaseEquation + with pytest.raises(ValueError): + Condition(input=input_, equation="not_an_equation") -def _create_graph_and_equation(): - from pina.graph import KNNGraph + # Should fail if input is a list with wrong elements + with pytest.raises(ValueError): + Condition( + input=[LabelTensor(torch.rand(10, 2), ["x", "y"])], + equation=equation, + ) - def dummy_equation(pts): - return pts.x[:, 0] ** 2 + pts.x[:, 1] ** 2 - 1 - x = LabelTensor(torch.randn(100, 2), labels=["u", "v"]) - pos = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) - graph = KNNGraph(x=x, pos=pos, neighbours=5, edge_attr=True) - equation = Equation(dummy_equation) - return graph, equation +@pytest.mark.parametrize("case", ["tensor", "graph"]) +def test_get_item(case): + # Tensor case + if case == "tensor": + input_, equation = input_tensor, equation_tensor -def test_init_tensor_equation_condition(): - pts, equation = _create_pts_and_equation() - condition = Condition(input=pts, equation=equation) - assert isinstance(condition, InputEquationCondition) - assert condition.input.shape == (100, 2) - assert condition.equation is equation + # Graph case + elif case == "graph": + input_, equation = input_graph_list, equation_graph + # Define the condition + condition = Condition(input=input_, equation=equation) -def test_init_graph_equation_condition(): - graph, equation = _create_graph_and_equation() - condition = Condition(input=graph, equation=equation) - assert isinstance(condition, InputEquationCondition) - assert isinstance(condition.input, Graph) - assert condition.input.x.shape == (100, 2) - assert condition.equation is equation + # Extract item using __getitem__ + index = 0 + item = condition[index] + # Assert correct types and numerical parity + if case == "tensor": + assert isinstance(item, _TensorDataManager) + assert isinstance(item.input, LabelTensor) + assert torch.allclose(item.input, input_[index]) -def test_wrong_init_equation_condition(): - pts, equation = _create_pts_and_equation() - # Wrong input type - with pytest.raises(ValueError): - Condition(input=torch.randn(10, 2), equation=equation) - # Wrong equation type - with pytest.raises(ValueError): - Condition(input=pts, equation="not_an_equation") - # Wrong input type (list with wrong elements) - with pytest.raises(ValueError): - Condition(input=[torch.randn(10, 2)], equation=equation) - - -def test_getitem_tensor_equation_condition(): - pts, equation = _create_pts_and_equation() - condition = Condition(input=pts, equation=equation) - item = condition[0] - assert isinstance(item, _DataManager) - assert hasattr(item, "input") - assert item.input.shape == (2,) - - -def test_getitems_tensor_equation_condition(): - pts, equation = _create_pts_and_equation() - condition = Condition(input=pts, equation=equation) - idxs = [0, 1, 3] - item = 
condition[idxs] - assert isinstance(item, _DataManager) - assert hasattr(item, "input") - assert item.input.shape == (3, 2) + elif case == "graph": + assert isinstance(item, _GraphDataManager) + assert isinstance(item.input, Graph) + assert torch.allclose(item.input.x, input_[index].x) + + +@pytest.mark.parametrize("case", ["tensor", "graph"]) +def test_create_batch(case): + + # Tensor case + if case == "tensor": + input_, equation = input_tensor, equation_tensor + + # Graph case + elif case == "graph": + input_, equation = input_graph_list, equation_graph + + # Define the condition + condition = Condition(input=input_, equation=equation) + + # Create batches using automatic batching or condition's collate_fn + idx = [0, 2] + data_to_collate = [condition.data[i] for i in idx] + batch_auto = condition.automatic_batching_collate_fn(data_to_collate) + batch_collate = condition.collate_fn(idx, condition) + + # Check that the automatic batch has been properly created + assert isinstance(batch_auto, (_BatchManager)) + assert hasattr(batch_auto, "input") + + # Check that the collate_fn batch has been properly created + assert isinstance(batch_collate, dict) + assert hasattr(batch_collate, "input") + + # Validate batch contents for tensor case + if case == "tensor": + + # Create expected input batch + expected_input = LabelTensor.stack([input_[i] for i in idx]) + + # Assert that the automatic batch input is correct + assert torch.allclose(batch_auto.input, expected_input) + assert batch_auto.input.shape == expected_input.shape + + # Assert that the collate_fn batch input is correct + assert torch.allclose(batch_collate.input, expected_input) + assert batch_collate.input.shape == expected_input.shape + + # Validate batch contents for graph case + elif case == "graph": + + # Create expected input batch + expected_input = [condition.data[i].input for i in idx] + + # Assert that the automatic batch input is correct + for i, graph in enumerate(expected_input): + assert torch.allclose(batch_auto.input[i].x, graph.x) + assert batch_auto.input.num_graphs == len(idx) + + # Assert that the collate_fn batch input is correct + for i, graph in enumerate(expected_input): + assert torch.allclose(batch_collate.input[i].x, graph.x) + assert batch_collate.input.num_graphs == len(idx) diff --git a/tests/test_condition/test_input_target_condition.py b/tests/test_condition/test_input_target_condition.py index 1f469f0cd..8352ee1c3 100644 --- a/tests/test_condition/test_input_target_condition.py +++ b/tests/test_condition/test_input_target_condition.py @@ -1,409 +1,379 @@ import torch import pytest +from pina.graph import RadiusGraph, Graph from pina import LabelTensor, Condition -from pina.graph import RadiusGraph -from pina._src.condition.batch_manager import _BatchManager +from pina.condition import ( + InputTargetCondition, + _BatchManager, + _TensorDataManager, + _GraphDataManager, +) -def _create_tensor_data(use_lt=False): +# Helper function to create tensor data +def _create_tensor_data(use_lt): + + # If LabelTensor is used, create tensors with labels if use_lt: input_tensor = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) target_tensor = LabelTensor(torch.rand((10, 2)), ["a", "b"]) return input_tensor, target_tensor + + # Standard torch.Tensor without labels input_tensor = torch.rand((10, 3)) target_tensor = torch.rand((10, 2)) + return input_tensor, target_tensor -def _create_graph_data(tensor_input=True, use_lt=False): +# Helper function to create graph data +def _create_graph_data(is_input, use_lt): + + # 
If LabelTensor is used, create graph data with LabelTensors if use_lt: x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) + tensor = LabelTensor(torch.rand(10, 20, 1), ["f"]) + + # Standard torch.Tensor without labels else: x = torch.rand(10, 20, 2) pos = torch.rand(10, 20, 2) - radius = 0.1 + tensor = torch.rand(10, 20, 1) + + # Create a list of Graphs graph = [ RadiusGraph( pos=pos[i], - radius=radius, - x=x[i] if not tensor_input else None, - y=x[i] if tensor_input else None, + radius=0.1, + x=x[i] if is_input else None, + y=x[i] if not is_input else None, ) for i in range(len(x)) ] + + return graph, tensor + + +# Helper function to check tensor types +def _assert_tensor_type(t, use_lt): if use_lt: - tensor = LabelTensor(torch.rand(10, 20, 1), ["f"]) + assert isinstance(t, LabelTensor) else: - tensor = torch.rand(10, 20, 1) - return graph, tensor + assert isinstance(t, torch.Tensor) and not isinstance(t, LabelTensor) -def test_init_tensor_input_tensor_target_condition_tensor(): - # Setup for standard torch.Tensor - input_tensor, target_tensor = _create_tensor_data(use_lt=False) - condition = Condition(input=input_tensor, target=target_tensor) - - # Numerical assertions - assert torch.allclose( - condition.input, input_tensor - ), "Standard input tensor equality failed" - assert torch.allclose( - condition.target, target_tensor - ), "Standard target tensor equality failed" - - # Type assertions - assert isinstance(condition.input, torch.Tensor) - assert not isinstance(condition.input, LabelTensor) - assert isinstance(condition.target, torch.Tensor) - assert not isinstance(condition.target, LabelTensor) - - -def test_init_tensor_input_tensor_target_condition_label_tensor(): - # Setup for LabelTensor - input_tensor, target_tensor = _create_tensor_data(use_lt=True) - condition = Condition(input=input_tensor, target=target_tensor) - - # Type and Label assertions for Input - assert isinstance( - condition.input, LabelTensor - ), "Input did not preserve LabelTensor type" - assert condition.input.labels == [ - "x", - "y", - "z", - ], "Input labels were lost or corrupted" - - # Type and Label assertions for Target - assert isinstance( - condition.target, LabelTensor - ), "Target did not preserve LabelTensor type" - assert condition.target.labels == [ - "a", - "b", - ], "Target labels were lost or corrupted" - - # Numerical parity check still applies - assert torch.allclose(condition.input, input_tensor) - assert torch.allclose(condition.target, target_tensor) - - -def test_init_tensor_input_graph_target_condition_tensor(): - # Setup for standard torch.Tensor - target_graph, input_tensor = _create_graph_data(use_lt=False) - condition = Condition(input=input_tensor, target=target_graph) - - # Input assertions (Tensor) - assert isinstance(condition.input, torch.Tensor) - assert not isinstance(condition.input, LabelTensor) - assert torch.allclose(condition.input, input_tensor) - - # Target assertions (Graph List) - assert isinstance(condition.target, list) - for i, graph in enumerate(target_graph): - assert isinstance(condition.target[i].y, torch.Tensor) - assert not isinstance(condition.target[i].y, LabelTensor) - assert torch.allclose(condition.target[i].y, graph.y) - - -def test_init_tensor_input_graph_target_condition_label_tensor(): - # Setup for LabelTensor - target_graph, input_tensor = _create_graph_data(use_lt=True) - condition = Condition(input=input_tensor, target=target_graph) - - # Input assertions with label validation - assert 
isinstance(condition.input, LabelTensor) - assert condition.input.labels == ["f"] - assert torch.allclose(condition.input, input_tensor) - - # Target assertions with nested label validation - for i, graph in enumerate(target_graph): - target_y = condition.target[i].y - assert isinstance(target_y, LabelTensor) - assert target_y.labels == ["u", "v"] - assert torch.allclose(target_y, graph.y) - - -def test_init_graph_input_tensor_target_condition_tensor(): - # Setup for standard torch.Tensor (use_lt=False) - input_graph, target_tensor = _create_graph_data(False, use_lt=False) - condition = Condition(input=input_graph, target=target_tensor) - - # Input assertions: Check graph list integrity - assert isinstance(condition.input, list) - for i, original_graph in enumerate(input_graph): - assert torch.allclose(condition.input[i].x, original_graph.x) - assert isinstance(condition.input[i].x, torch.Tensor) - assert not isinstance(condition.input[i].x, LabelTensor) - - # Target assertions: Check raw tensor integrity - assert torch.allclose(condition.target, target_tensor) - assert isinstance(condition.target, torch.Tensor) - assert not isinstance(condition.target, LabelTensor) - - -def test_init_graph_input_tensor_target_condition_label_tensor(): - # Setup for LabelTensor (use_lt=True) - input_graph, target_tensor = _create_graph_data(False, use_lt=True) - condition = Condition(input=input_graph, target=target_tensor) - - # Input assertions: Check LabelTensor preservation in Graphs - for i, original_graph in enumerate(input_graph): - input_x = condition.input[i].x - assert isinstance(input_x, LabelTensor) - assert input_x.labels == original_graph.x.labels - assert torch.allclose(input_x, original_graph.x) - - # Target assertions: Check LabelTensor preservation in Target - assert isinstance(condition.target, LabelTensor) - assert condition.target.labels == ["f"] - assert torch.allclose(condition.target, target_tensor) - - -def test_wrong_init(): - input_tensor, target_tensor = _create_tensor_data() - with pytest.raises(ValueError): - Condition(input="invalid_input", target=target_tensor) - with pytest.raises(ValueError): - Condition(input=input_tensor, target="invalid_target") +# Helper function to check input graph +def _assert_graph_type(graph_list, use_lt, is_input): + + assert isinstance(graph_list, list) + for graph in graph_list: + value = graph.x if is_input else graph.y + _assert_tensor_type(value, use_lt) + + +@pytest.mark.parametrize("use_lt", [True, False]) +@pytest.mark.parametrize( + "case", [["tensor", "tensor"], ["tensor", "graph"], ["graph", "tensor"]] +) +def test_constructor(use_lt, case): + + # Tensor - tensor + if case == ["tensor", "tensor"]: + + # Define the condition + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + + # Assert correct types + assert isinstance(condition, InputTargetCondition) + _assert_tensor_type(condition.input, use_lt) + _assert_tensor_type(condition.target, use_lt) + + # Assert numerical parity + assert torch.allclose(condition.input, input_tensor) + assert torch.allclose(condition.target, target_tensor) + + # Assert labels if LabelTensor is used + if use_lt: + assert condition.input.labels == ["x", "y", "z"] + assert condition.target.labels == ["a", "b"] + + # Tensor - graph + elif case == ["tensor", "graph"]: + + # Define the condition + target_graph, input_tensor = _create_graph_data( + is_input=False, use_lt=use_lt + ) + condition = Condition(input=input_tensor, 
target=target_graph) + + # Assert correct types + assert isinstance(condition, InputTargetCondition) + _assert_tensor_type(condition.input, use_lt) + _assert_graph_type(condition.target, use_lt, is_input=False) + + # Assert numerical parity + assert torch.allclose(condition.input, input_tensor) + for i, graph in enumerate(target_graph): + assert torch.allclose(condition.target[i].y, graph.y) + + # Assert labels if LabelTensor is used + if use_lt: + assert condition.input.labels == ["f"] + for i in range(len(target_graph)): + assert condition.target[i].y.labels == ["u", "v"] + assert condition.target[i].pos.labels == ["x", "y"] + + # Graph - tensor + elif case == ["graph", "tensor"]: + + # Define the condition + input_graph, target_tensor = _create_graph_data( + is_input=True, use_lt=use_lt + ) + condition = Condition(input=input_graph, target=target_tensor) + + # Assert correct types + assert isinstance(condition, InputTargetCondition) + _assert_graph_type(condition.input, use_lt, is_input=True) + _assert_tensor_type(condition.target, use_lt) + + # Assert numerical parity + assert torch.allclose(condition.target, target_tensor) + for i, graph in enumerate(input_graph): + assert torch.allclose(condition.input[i].x, graph.x) + + # Assert labels if LabelTensor is used + if use_lt: + assert condition.target.labels == ["f"] + for i in range(len(input_graph)): + assert condition.input[i].x.labels == ["u", "v"] + assert condition.input[i].pos.labels == ["x", "y"] + + # Prepare for invalid input tests + input_ = input_tensor if case[0] == "tensor" else input_graph + target_ = target_tensor if case[1] == "tensor" else target_graph + + # Should fail if the input is neither a tensor nor a graph with pytest.raises(ValueError): - Condition(input=[input_tensor], target=target_tensor) + Condition(input="invalid_input", target=target_) + + # Should fail if the target is neither a tensor nor a graph with pytest.raises(ValueError): - Condition(input=input_tensor, target=[target_tensor]) + Condition(input=input_, target="invalid_target") + + # Should fail if the input is a list of tensors + if case[0] == "tensor": + with pytest.raises(ValueError): + Condition(input=[input_], target=target_) + # Should fail if the target is a list of tensors + if case[1] == "tensor": + with pytest.raises(ValueError): + Condition(input=input_, target=[target_]) -def test_getitem_tensor_input_tensor_target_condition_tensor(): - # Setup for standard torch.Tensor - input_tensor, target_tensor = _create_tensor_data(use_lt=False) - condition = Condition(input=input_tensor, target=target_tensor) - # We test a single index to verify __getitem__ logic - index = 0 - item = condition[index] +@pytest.mark.parametrize("use_lt", [True, False]) +@pytest.mark.parametrize( + "case", [["tensor", "tensor"], ["tensor", "graph"], ["graph", "tensor"]] +) +def test_get_item(use_lt, case): - # Numerical and Type Assertions - assert torch.allclose(item.input, input_tensor[index]) - assert isinstance(item.input, torch.Tensor) - assert not isinstance(item.input, LabelTensor) + # Tensor - tensor + if case == ["tensor", "tensor"]: - assert torch.allclose(item.target, target_tensor[index]) - assert isinstance(item.target, torch.Tensor) - assert not isinstance(item.target, LabelTensor) + # Define the condition + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + # Extract item using __getitem__ + index = 0 + item = condition[index] -def 
test_getitem_tensor_input_tensor_target_condition_label_tensor(): - # Setup for LabelTensor - input_tensor, target_tensor = _create_tensor_data(use_lt=True) - condition = Condition(input=input_tensor, target=target_tensor) + # Assert correct types + assert isinstance(item, _TensorDataManager) + _assert_tensor_type(item.input, use_lt) + _assert_tensor_type(item.target, use_lt) - index = 0 - item = condition[index] + # Assert numerical parity + assert torch.allclose(item.input, input_tensor[index]) + assert torch.allclose(item.target, target_tensor[index]) - # Verify Input LabelTensor preservation - assert isinstance(item.input, LabelTensor) - assert item.input.labels == input_tensor.labels - assert torch.allclose(item.input, input_tensor[index]) + # Tensor - graph + elif case == ["tensor", "graph"]: - # Verify Target LabelTensor preservation - assert isinstance(item.target, LabelTensor) - assert item.target.labels == target_tensor.labels - assert torch.allclose(item.target, target_tensor[index]) + # Define the condition + target_graph, input_tensor = _create_graph_data( + is_input=False, use_lt=use_lt + ) + condition = Condition(input=input_tensor, target=target_graph) + # Extract item using __getitem__ + index = 0 + item = condition[index] -@pytest.mark.parametrize("use_lt", [True, False]) -def test_getitem_graph_input_tensor_target_condition(use_lt): - input_graph, target_tensor = _create_graph_data(False, use_lt=use_lt) - condition = Condition(input=input_graph, target=target_tensor) - assert len(condition) == len(input_graph) - for i in range(len(input_graph)): - item = condition[i] - assert torch.allclose( - item.input.x, input_graph[i].x - ), "GraphInputTensorTargetCondition __getitem__ input failed" - assert torch.allclose( - item.target, target_tensor[i] - ), "GraphInputTensorTargetCondition __getitem__ target failed" - if use_lt: - assert isinstance( - item.input.x, LabelTensor - ), "GraphInputTensorTargetCondition __getitem__ input type failed" - assert ( - item.input.x.labels == input_graph[i].x.labels - ), "GraphInputTensorTargetCondition __getitem__ input labels failed" - assert isinstance( - item.target, LabelTensor - ), "GraphInputTensorTargetCondition __getitem__ target type failed" - assert item.target.labels == [ - "f" - ], "GraphInputTensorTargetCondition __getitem__ target labels failed" - - -def test_getitem_tensor_input_graph_target_condition_tensor(): - # Setup for standard torch.Tensor - target_graph, input_tensor = _create_graph_data(use_lt=False) - condition = Condition(input=input_tensor, target=target_graph) - - # Check first item indexing - idx = 0 - item = condition[idx] + # Assert correct types + assert isinstance(item, _GraphDataManager) + _assert_tensor_type(item.input, use_lt) + assert isinstance(item.target, Graph) + _assert_tensor_type(item.target.y, use_lt) + + # Assert numerical parity + assert torch.allclose(item.input, input_tensor[index]) + assert torch.allclose(item.target.y, target_graph[index].y) + + # Graph - tensor + elif case == ["graph", "tensor"]: + + # Define the condition + input_graph, target_tensor = _create_graph_data( + is_input=True, use_lt=use_lt + ) + condition = Condition(input=input_graph, target=target_tensor) - # Input assertions (Tensor) - assert torch.allclose(item.input, input_tensor[idx]) - assert isinstance(item.input, torch.Tensor) - assert not isinstance(item.input, LabelTensor) - - # Target assertions (Graph Data) - assert torch.allclose(item.target.y, target_graph[idx].y) - assert isinstance(item.target.y, torch.Tensor) 
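The batch tests below compare the two collation paths: automatic_batching_collate_fn, which receives already-fetched items, and collate_fn, which works directly from indices. A minimal sketch of the equivalence being tested, using only calls that appear in these tests:

import torch
from pina import Condition

condition = Condition(input=torch.rand(10, 3), target=torch.rand(10, 2))
idx = [0, 2]

# Path 1: collate pre-fetched items
batch_auto = condition.automatic_batching_collate_fn(
    [condition.data[i] for i in idx]
)

# Path 2: collate directly from indices
batch_collate = condition.collate_fn(idx, condition)
assert torch.allclose(batch_auto.input, batch_collate.input)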
- assert not isinstance(item.target.y, LabelTensor) - - -def test_getitem_tensor_input_graph_target_condition_label_tensor(): - # Setup for LabelTensor - target_graph, input_tensor = _create_graph_data(use_lt=True) - condition = Condition(input=input_tensor, target=target_graph) - - idx = 0 - item = condition[idx] - - # Input LabelTensor validation - assert isinstance(item.input, LabelTensor) - assert item.input.labels == input_tensor.labels - assert torch.allclose(item.input, input_tensor[idx]) - - # Target Graph LabelTensor validation - target_y = item.target.y - assert isinstance(target_y, LabelTensor) - assert target_y.labels == ["u", "v"] - assert torch.allclose(target_y, target_graph[idx].y) - - -def test_getitems_tensor_input_tensor_target_condition_tensor(): - # Setup for standard torch.Tensor - input_tensor, target_tensor = _create_tensor_data(use_lt=False) - condition = Condition(input=input_tensor, target=target_tensor) - - indices = [1, 3, 5, 7] - items = condition[indices] - - # Verify values by comparing against manually stacked slices - expected_input = torch.stack([input_tensor[i] for i in indices]) - expected_target = torch.stack([target_tensor[i] for i in indices]) - - assert torch.allclose(items.input, expected_input) - assert torch.allclose(items.target, expected_target) - - # Ensure types remain standard torch.Tensor - assert isinstance(items.input, torch.Tensor) - assert not isinstance(items.input, LabelTensor) - assert isinstance(items.target, torch.Tensor) - - -def test_getitems_tensor_input_tensor_target_condition_label_tensor(): - # Setup for LabelTensor - input_tensor, target_tensor = _create_tensor_data(use_lt=True) - condition = Condition(input=input_tensor, target=target_tensor) - - indices = [1, 3, 5, 7] - items = condition[indices] - - # Assertions for Input LabelTensor - assert isinstance(items.input, LabelTensor) - assert items.input.labels == ["x", "y", "z"] - assert torch.allclose(items.input, input_tensor[indices]) - - # Assertions for Target LabelTensor - assert isinstance(items.target, LabelTensor) - assert items.target.labels == ["a", "b"] - assert torch.allclose(items.target, target_tensor[indices]) - - -def test_getitems_tensor_input_graph_target_condition_tensor(): - # Setup for standard torch.Tensor - target_graph, input_tensor = _create_graph_data(True, use_lt=False) - condition = Condition(input=input_tensor, target=target_graph) - - indices = [0, 2, 4] - items = condition[indices] - - # 1. Verify Input Batch (Tensor) - expected_input = torch.stack([input_tensor[i] for i in indices]) - assert torch.allclose(items.input, expected_input) - assert isinstance(items.input, torch.Tensor) - assert not isinstance(items.input, LabelTensor) - - # 2. Verify Target Batch (Graph List) - assert len(items.target) == len(indices) - for i, original_idx in enumerate(indices): - assert torch.allclose(items.target[i].y, target_graph[original_idx].y) - assert isinstance(items.target[i].y, torch.Tensor) - - -def test_getitems_tensor_input_graph_target_condition_label_tensor(): - # Setup for LabelTensor - target_graph, input_tensor = _create_graph_data(True, use_lt=True) - condition = Condition(input=input_tensor, target=target_graph) - - indices = [0, 2, 4] - items = condition[indices] - - # 1. Verify Input LabelTensor preservation - assert isinstance(items.input, LabelTensor) - assert items.input.labels == ["f"] - # Verify values still match - assert torch.allclose(items.input, input_tensor[indices]) - - # 2. 
Verify Target Graphs LabelTensor preservation - assert len(items.target) == len(indices) - for i, original_idx in enumerate(indices): - target_y = items.target[i].y - assert isinstance(target_y, LabelTensor) - assert target_y.labels == ["u", "v"] - # Verify numerical parity - assert torch.allclose(target_y, target_graph[original_idx].y) - - -def test_create_batch_tensor(): - input_tensor, target_tensor = _create_tensor_data() - condition = Condition(input=input_tensor, target=target_tensor) - idx = [0, 2, 4, 6] - data_to_collate = [condition.data[i] for i in idx] - batch = condition.automatic_batching_collate_fn(data_to_collate) - assert isinstance(batch, _BatchManager) - assert hasattr(batch, "input") - assert hasattr(batch, "target") - expected_input = torch.stack([input_tensor[i] for i in idx]) - expected_target = torch.stack([target_tensor[i] for i in idx]) - assert torch.allclose(batch.input, expected_input) - assert torch.allclose(batch.target, expected_target) - - batch = condition.collate_fn(idx, condition) - # assert isinstance(batch, _BatchManager) - assert hasattr(batch, "input") - assert hasattr(batch, "target") - expected_input = torch.stack([input_tensor[i] for i in idx]) - expected_target = torch.stack([target_tensor[i] for i in idx]) - assert torch.allclose(batch.input, expected_input) - assert torch.allclose(batch.target, expected_target) - - -def test_create_batch_graph(): - input_graph, target_tensor = _create_graph_data(False) - condition = Condition(input=input_graph, target=target_tensor) - idx = [1, 3, 5] - data_to_collate = [condition.data[i] for i in idx] - batch = condition.automatic_batching_collate_fn(data_to_collate) - assert isinstance(batch, _BatchManager) - assert hasattr(batch, "input") - assert hasattr(batch, "target") - expected_target = torch.cat([target_tensor[i] for i in idx]) - print(expected_target.shape, batch.target.shape) - assert torch.allclose(batch.target, expected_target) - assert batch.input.num_graphs == len(idx) - - batch = condition.collate_fn(idx, condition) - assert isinstance(batch, _BatchManager) - assert hasattr(batch, "input") - assert hasattr(batch, "target") - assert torch.allclose(batch.target, expected_target) - assert batch.input.num_graphs == len(idx) + # Extract item using __getitem__ + index = 0 + item = condition[index] + + # Assert correct types + assert isinstance(item, _GraphDataManager) + assert isinstance(item.input, Graph) + _assert_tensor_type(item.input.x, use_lt) + _assert_tensor_type(item.target, use_lt) + + # Assert numerical parity + assert torch.allclose(item.target, target_tensor[index]) + assert torch.allclose(item.input.x, input_graph[index].x) + + +@pytest.mark.parametrize("use_lt", [True, False]) +@pytest.mark.parametrize( + "case", [["tensor", "tensor"], ["tensor", "graph"], ["graph", "tensor"]] +) +def test_create_batch(use_lt, case): + + # Tensor - tensor + if case == ["tensor", "tensor"]: + + # Define the condition + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + + # Create batches using automatic batching or condition's collate_fn + idx = [0, 2] + data_to_collate = [condition.data[i] for i in idx] + batch_auto = condition.automatic_batching_collate_fn(data_to_collate) + batch_collate = condition.collate_fn(idx, condition) + + # Check that the automatic batch has been properly created + assert isinstance(batch_auto, _BatchManager) + assert hasattr(batch_auto, "input") + assert hasattr(batch_auto, "target") + + # Check 
that the collate_fn batch has been properly created + assert isinstance(batch_collate, dict) + assert hasattr(batch_collate, "input") + assert hasattr(batch_collate, "target") + + # Create expected input and target batches + expected_input = torch.stack([input_tensor[i] for i in idx]) + expected_target = torch.stack([target_tensor[i] for i in idx]) + + # Assert that the automatic batch input and target are correct + assert torch.allclose(batch_auto.input, expected_input) + assert torch.allclose(batch_auto.target, expected_target) + assert batch_auto.input.shape == expected_input.shape + assert batch_auto.target.shape == expected_target.shape + + # Assert that the collate_fn batch input and target are correct + assert torch.allclose(batch_collate.input, expected_input) + assert torch.allclose(batch_collate.target, expected_target) + assert batch_collate.input.shape == expected_input.shape + assert batch_collate.target.shape == expected_target.shape + + # Tensor - graph + elif case == ["tensor", "graph"]: + + # Define the condition + target_graph, input_tensor = _create_graph_data( + is_input=False, use_lt=use_lt + ) + condition = Condition(input=input_tensor, target=target_graph) + + # Create batches using automatic batching or condition's collate_fn + idx = [0, 2] + data_to_collate = [condition.data[i] for i in idx] + batch_auto = condition.automatic_batching_collate_fn(data_to_collate) + batch_collate = condition.collate_fn(idx, condition) + + # Check that the automatic batch has been properly created + assert isinstance(batch_auto, _BatchManager) + assert hasattr(batch_auto, "input") + assert hasattr(batch_auto, "target") + + # Check that the collate_fn batch has been properly created + assert isinstance(batch_collate, dict) + assert hasattr(batch_collate, "input") + assert hasattr(batch_collate, "target") + + # Create expected input and target batches + expected_input = torch.cat([input_tensor[i] for i in idx]) + expected_target = [target_graph[i] for i in idx] + + # Assert that the automatic batch input and target are correct + assert torch.allclose(batch_auto.input, expected_input) + for i, graph in enumerate(expected_target): + assert torch.allclose(batch_auto.target[i].y, graph.y) + assert batch_auto.input.shape == expected_input.shape + assert batch_auto.target.num_graphs == len(idx) + + # Assert that the collate_fn batch input and target are correct + assert torch.allclose(batch_collate.input, expected_input) + for i, graph in enumerate(expected_target): + assert torch.allclose(batch_collate.target[i].y, graph.y) + assert batch_collate.input.shape == expected_input.shape + assert batch_collate.target.num_graphs == len(idx) + + # Graph - tensor + elif case == ["graph", "tensor"]: + + # Define the condition + input_graph, target_tensor = _create_graph_data( + is_input=True, use_lt=use_lt + ) + condition = Condition(input=input_graph, target=target_tensor) + + # Create batches using automatic batching or condition's collate_fn + idx = [0, 2] + data_to_collate = [condition.data[i] for i in idx] + batch_auto = condition.automatic_batching_collate_fn(data_to_collate) + batch_collate = condition.collate_fn(idx, condition) + + # Check that the automatic batch has been properly created + assert isinstance(batch_auto, _BatchManager) + assert hasattr(batch_auto, "input") + assert hasattr(batch_auto, "target") + + # Check that the collate_fn batch has been properly created + assert isinstance(batch_collate, dict) + assert hasattr(batch_collate, "input") + assert hasattr(batch_collate, 
"target") + + # Create expected input and target batches + expected_input = [input_graph[i] for i in idx] + expected_target = torch.cat([target_tensor[i] for i in idx]) + + # Assert that the automatic batch input and target are correct + for i, graph in enumerate(expected_input): + assert torch.allclose(batch_auto.input[i].x, graph.x) + assert torch.allclose(batch_auto.target, expected_target) + assert batch_auto.input.num_graphs == len(idx) + assert batch_auto.target.shape == expected_target.shape + + # Assert that the collate_fn batch input and target are correct + for i, graph in enumerate(expected_input): + assert torch.allclose(batch_collate.input[i].x, graph.x) + assert torch.allclose(batch_collate.target, expected_target) + assert batch_collate.input.num_graphs == len(idx) + assert batch_collate.target.shape == expected_target.shape From 2845c5af12975dafbcc0fdfd0808ee760fa03c00 Mon Sep 17 00:00:00 2001 From: GiovanniCanali Date: Thu, 23 Apr 2026 10:50:26 +0200 Subject: [PATCH 3/3] move data managers to data/manager submodule --- docs/source/_rst/_code.rst | 18 +-- docs/source/_rst/condition/batch_manager.rst | 9 -- docs/source/_rst/condition/data_manager.rst | 9 -- .../_rst/condition/data_manager_interface.rst | 9 -- .../_rst/condition/graph_data_manager.rst | 9 -- .../_rst/condition/tensor_data_manager.rst | 9 -- .../_rst/data/manager/batch_manager.rst | 9 ++ .../source/_rst/data/manager/data_manager.rst | 9 ++ .../data/manager/data_manager_interface.rst | 9 ++ .../_rst/data/manager/graph_data_manager.rst | 9 ++ .../_rst/data/manager/tensor_data_manager.rst | 9 ++ pina/_src/condition/data_condition.py | 2 +- .../condition/input_equation_condition.py | 2 +- pina/_src/condition/input_target_condition.py | 2 +- .../manager}/batch_manager.py | 0 .../manager}/data_manager.py | 14 +- .../manager}/data_manager_interface.py | 2 +- .../manager}/graph_data_manager.py | 4 +- .../manager}/tensor_data_manager.py | 4 +- pina/condition/__init__.py | 10 -- pina/data/manager.py | 15 ++ tests/test_condition/test_data_condition.py | 6 +- .../test_input_equation_condition.py | 5 +- .../test_input_target_condition.py | 6 +- .../test_data_module.py} | 0 tests/test_data/test_graph_data_manager.py | 115 +++++++++++++++ tests/test_data/test_tensor_data_manager.py | 54 +++++++ tests/test_data_manager.py | 137 ------------------ 28 files changed, 262 insertions(+), 224 deletions(-) delete mode 100644 docs/source/_rst/condition/batch_manager.rst delete mode 100644 docs/source/_rst/condition/data_manager.rst delete mode 100644 docs/source/_rst/condition/data_manager_interface.rst delete mode 100644 docs/source/_rst/condition/graph_data_manager.rst delete mode 100644 docs/source/_rst/condition/tensor_data_manager.rst create mode 100644 docs/source/_rst/data/manager/batch_manager.rst create mode 100644 docs/source/_rst/data/manager/data_manager.rst create mode 100644 docs/source/_rst/data/manager/data_manager_interface.rst create mode 100644 docs/source/_rst/data/manager/graph_data_manager.rst create mode 100644 docs/source/_rst/data/manager/tensor_data_manager.rst rename pina/_src/{condition => data/manager}/batch_manager.py (100%) rename pina/_src/{condition => data/manager}/data_manager.py (73%) rename pina/_src/{condition => data/manager}/data_manager_interface.py (96%) rename pina/_src/{condition => data/manager}/graph_data_manager.py (98%) rename pina/_src/{condition => data/manager}/tensor_data_manager.py (95%) create mode 100644 pina/data/manager.py rename tests/{test_datamodule.py => 
test_data/test_data_module.py} (100%) create mode 100644 tests/test_data/test_graph_data_manager.py create mode 100644 tests/test_data/test_tensor_data_manager.py delete mode 100644 tests/test_data_manager.py diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 7433ab5a1..704298020 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -65,11 +65,11 @@ Batch and Data Managers .. toctree:: :titlesonly: - Batch Manager - Data Manager Interface - Data Manager - Graph Data Manager - Tensor Data Manager + Batch Manager + Data Manager Interface + Data Manager + Graph Data Manager + Tensor Data Manager Solvers -------------- @@ -80,8 +80,8 @@ Solvers SolverInterface SingleSolverInterface MultiSolverInterface - SupervisedSolverInterface - DeepEnsembleSolverInterface + SupervisedSolverInterface + DeepEnsembleSolverInterface PINNInterface PINN GradientPINN @@ -89,9 +89,9 @@ Solvers CompetitivePINN SelfAdaptivePINN RBAPINN - DeepEnsemblePINN + DeepEnsemblePINN SupervisedSolver - DeepEnsembleSupervisedSolver + DeepEnsembleSupervisedSolver ReducedOrderModelSolver GAROM AutoregressiveSolverInterface diff --git a/docs/source/_rst/condition/batch_manager.rst b/docs/source/_rst/condition/batch_manager.rst deleted file mode 100644 index f651260bf..000000000 --- a/docs/source/_rst/condition/batch_manager.rst +++ /dev/null @@ -1,9 +0,0 @@ -Batch Manager -====================== -.. currentmodule:: pina.condition.batch_manager - -.. automodule:: pina._src.condition.batch_manager - -.. autoclass:: pina._src.condition.batch_manager._BatchManager - :members: - :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/condition/data_manager.rst b/docs/source/_rst/condition/data_manager.rst deleted file mode 100644 index 66e177854..000000000 --- a/docs/source/_rst/condition/data_manager.rst +++ /dev/null @@ -1,9 +0,0 @@ -Data Manager -====================== -.. currentmodule:: pina.condition.data_manager - -.. automodule:: pina._src.condition.data_manager - -.. autoclass:: pina._src.condition.data_manager._DataManager - :members: - :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/condition/data_manager_interface.rst b/docs/source/_rst/condition/data_manager_interface.rst deleted file mode 100644 index b1adac823..000000000 --- a/docs/source/_rst/condition/data_manager_interface.rst +++ /dev/null @@ -1,9 +0,0 @@ -Data Manager Interface -========================= -.. currentmodule:: pina.condition.data_manager_interface - -.. automodule:: pina._src.condition.data_manager_interface - -.. autoclass:: pina._src.condition.data_manager_interface._DataManagerInterface - :members: - :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/condition/graph_data_manager.rst b/docs/source/_rst/condition/graph_data_manager.rst deleted file mode 100644 index b8b6ba39e..000000000 --- a/docs/source/_rst/condition/graph_data_manager.rst +++ /dev/null @@ -1,9 +0,0 @@ -Graph Data Manager -====================== -.. currentmodule:: pina.condition.graph_data_manager - -.. automodule:: pina._src.condition.graph_data_manager - -.. 
autoclass:: pina._src.condition.graph_data_manager._GraphDataManager - :members: - :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/condition/tensor_data_manager.rst b/docs/source/_rst/condition/tensor_data_manager.rst deleted file mode 100644 index e45e86c8c..000000000 --- a/docs/source/_rst/condition/tensor_data_manager.rst +++ /dev/null @@ -1,9 +0,0 @@ -Tensor Data Manager -====================== -.. currentmodule:: pina.condition.tensor_data_manager - -.. automodule:: pina._src.condition.tensor_data_manager - -.. autoclass:: pina._src.condition.tensor_data_manager._TensorDataManager - :members: - :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/data/manager/batch_manager.rst b/docs/source/_rst/data/manager/batch_manager.rst new file mode 100644 index 000000000..5d7c36650 --- /dev/null +++ b/docs/source/_rst/data/manager/batch_manager.rst @@ -0,0 +1,9 @@ +Batch Manager +====================== +.. currentmodule:: pina.data.manager.batch_manager + +.. automodule:: pina._src.data.manager.batch_manager + +.. autoclass:: pina._src.data.manager.batch_manager._BatchManager + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/data/manager/data_manager.rst b/docs/source/_rst/data/manager/data_manager.rst new file mode 100644 index 000000000..9b32b8242 --- /dev/null +++ b/docs/source/_rst/data/manager/data_manager.rst @@ -0,0 +1,9 @@ +Data Manager +====================== +.. currentmodule:: pina.data.manager.data_manager + +.. automodule:: pina._src.data.manager.data_manager + +.. autoclass:: pina._src.data.manager.data_manager._DataManager + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/data/manager/data_manager_interface.rst b/docs/source/_rst/data/manager/data_manager_interface.rst new file mode 100644 index 000000000..e4a502abf --- /dev/null +++ b/docs/source/_rst/data/manager/data_manager_interface.rst @@ -0,0 +1,9 @@ +Data Manager Interface +========================= +.. currentmodule:: pina.data.manager.data_manager_interface + +.. automodule:: pina._src.data.manager.data_manager_interface + +.. autoclass:: pina._src.data.manager.data_manager_interface._DataManagerInterface + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/data/manager/graph_data_manager.rst b/docs/source/_rst/data/manager/graph_data_manager.rst new file mode 100644 index 000000000..bbbf23a52 --- /dev/null +++ b/docs/source/_rst/data/manager/graph_data_manager.rst @@ -0,0 +1,9 @@ +Graph Data Manager +====================== +.. currentmodule:: pina.data.manager.graph_data_manager + +.. automodule:: pina._src.data.manager.graph_data_manager + +.. autoclass:: pina._src.data.manager.graph_data_manager._GraphDataManager + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/data/manager/tensor_data_manager.rst b/docs/source/_rst/data/manager/tensor_data_manager.rst new file mode 100644 index 000000000..f8bb06028 --- /dev/null +++ b/docs/source/_rst/data/manager/tensor_data_manager.rst @@ -0,0 +1,9 @@ +Tensor Data Manager +====================== +.. currentmodule:: pina.data.manager.tensor_data_manager + +.. automodule:: pina._src.data.manager.tensor_data_manager + +.. 
autoclass:: pina._src.data.manager.tensor_data_manager._TensorDataManager + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/_src/condition/data_condition.py b/pina/_src/condition/data_condition.py index da34f838b..28a32aa0e 100644 --- a/pina/_src/condition/data_condition.py +++ b/pina/_src/condition/data_condition.py @@ -5,7 +5,7 @@ from pina._src.condition.base_condition import BaseCondition from pina._src.core.label_tensor import LabelTensor from pina._src.core.graph import Graph -from pina._src.condition.data_manager import _DataManager +from pina._src.data.manager.data_manager import _DataManager from pina._src.core.utils import check_consistency diff --git a/pina/_src/condition/input_equation_condition.py b/pina/_src/condition/input_equation_condition.py index 26958fb08..40f1cd5df 100644 --- a/pina/_src/condition/input_equation_condition.py +++ b/pina/_src/condition/input_equation_condition.py @@ -4,7 +4,7 @@ from pina._src.core.label_tensor import LabelTensor from pina._src.core.graph import Graph from pina._src.equation.base_equation import BaseEquation -from pina._src.condition.data_manager import _DataManager +from pina._src.data.manager.data_manager import _DataManager from pina._src.core.utils import check_consistency diff --git a/pina/_src/condition/input_target_condition.py b/pina/_src/condition/input_target_condition.py index 4b641e528..74841b961 100644 --- a/pina/_src/condition/input_target_condition.py +++ b/pina/_src/condition/input_target_condition.py @@ -5,7 +5,7 @@ from pina._src.core.label_tensor import LabelTensor from pina._src.core.graph import Graph from pina._src.condition.base_condition import BaseCondition -from pina._src.condition.data_manager import _DataManager +from pina._src.data.manager.data_manager import _DataManager from pina._src.core.utils import check_consistency diff --git a/pina/_src/condition/batch_manager.py b/pina/_src/data/manager/batch_manager.py similarity index 100% rename from pina/_src/condition/batch_manager.py rename to pina/_src/data/manager/batch_manager.py diff --git a/pina/_src/condition/data_manager.py b/pina/_src/data/manager/data_manager.py similarity index 73% rename from pina/_src/condition/data_manager.py rename to pina/_src/data/manager/data_manager.py index 723a4f059..3fd976d1d 100644 --- a/pina/_src/condition/data_manager.py +++ b/pina/_src/data/manager/data_manager.py @@ -3,8 +3,8 @@ import torch from pina._src.core.label_tensor import LabelTensor from pina._src.equation.base_equation import BaseEquation -from pina._src.condition.graph_data_manager import _GraphDataManager -from pina._src.condition.tensor_data_manager import _TensorDataManager +from pina._src.data.manager.graph_data_manager import _GraphDataManager +from pina._src.data.manager.tensor_data_manager import _TensorDataManager class _DataManager: @@ -12,9 +12,9 @@ class _DataManager: Factory class for data manager implementations. This class dispatches object creation to either - :class:`~pina.condition.tensor_data_manager._TensorDataManager` or - :class:`~pina.condition.graph_data_manager._GraphDataManager` depending on - the types of the provided keyword arguments. + :class:`~pina.data.manager.tensor_data_manager._TensorDataManager` or + :class:`~pina.data.manager.graph_data_manager._GraphDataManager` depending + on the types of the provided keyword arguments. 
""" def __new__(cls, **kwargs): @@ -25,9 +25,9 @@ def __new__(cls, **kwargs): If all values in ``kwargs`` are instances of :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, or :class:`~pina.equation.base_equation.BaseEquation`, an instance of - :class:`~pina.condition.tensor_data_manager._TensorDataManager` is + :class:`~pina.data.manager.tensor_data_manager._TensorDataManager` is created. Otherwise, an instance of - :class:`~pina.condition.graph_data_manager._GraphDataManager` is + :class:`~pina.data.manager.graph_data_manager._GraphDataManager` is created. :param dict kwargs: The keyword arguments for the data manager. diff --git a/pina/_src/condition/data_manager_interface.py b/pina/_src/data/manager/data_manager_interface.py similarity index 96% rename from pina/_src/condition/data_manager_interface.py rename to pina/_src/data/manager/data_manager_interface.py index 2e51dd3a1..41b841e39 100644 --- a/pina/_src/condition/data_manager_interface.py +++ b/pina/_src/data/manager/data_manager_interface.py @@ -1,4 +1,4 @@ -"""Module for the Tensor-Data Manager interface.""" +"""Module for the Data Manager interface.""" from abc import ABCMeta, abstractmethod diff --git a/pina/_src/condition/graph_data_manager.py b/pina/_src/data/manager/graph_data_manager.py similarity index 98% rename from pina/_src/condition/graph_data_manager.py rename to pina/_src/data/manager/graph_data_manager.py index b05ac5c7a..660c75f83 100644 --- a/pina/_src/condition/graph_data_manager.py +++ b/pina/_src/data/manager/graph_data_manager.py @@ -5,8 +5,8 @@ from torch_geometric.data.batch import Batch from pina._src.core.label_tensor import LabelTensor from pina._src.core.graph import Graph, LabelBatch -from pina._src.condition.batch_manager import _BatchManager -from pina._src.condition.data_manager_interface import _DataManagerInterface +from pina._src.data.manager.batch_manager import _BatchManager +from pina._src.data.manager.data_manager_interface import _DataManagerInterface class _GraphDataManager(_DataManagerInterface): diff --git a/pina/_src/condition/tensor_data_manager.py b/pina/_src/data/manager/tensor_data_manager.py similarity index 95% rename from pina/_src/condition/tensor_data_manager.py rename to pina/_src/data/manager/tensor_data_manager.py index a1ec0b023..2e530c40f 100644 --- a/pina/_src/condition/tensor_data_manager.py +++ b/pina/_src/data/manager/tensor_data_manager.py @@ -2,8 +2,8 @@ import torch from pina._src.core.label_tensor import LabelTensor -from pina._src.condition.batch_manager import _BatchManager -from pina._src.condition.data_manager_interface import _DataManagerInterface +from pina._src.data.manager.batch_manager import _BatchManager +from pina._src.data.manager.data_manager_interface import _DataManagerInterface class _TensorDataManager(_DataManagerInterface): diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 64b72901f..460ce5d32 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -14,11 +14,6 @@ "InputTargetCondition", "InputEquationCondition", "DataCondition", - "_DataManagerInterface", - "_DataManager", - "_GraphDataManager", - "_TensorDataManager", - "_BatchManager", ] from pina._src.condition.condition_interface import ConditionInterface @@ -30,8 +25,3 @@ from pina._src.condition.input_target_condition import InputTargetCondition from pina._src.condition.input_equation_condition import InputEquationCondition from pina._src.condition.data_condition import DataCondition -from pina._src.condition.batch_manager import 
_BatchManager -from pina._src.condition.data_manager_interface import _DataManagerInterface -from pina._src.condition.data_manager import _DataManager -from pina._src.condition.tensor_data_manager import _TensorDataManager -from pina._src.condition.graph_data_manager import _GraphDataManager diff --git a/pina/data/manager.py b/pina/data/manager.py new file mode 100644 index 000000000..1441cee12 --- /dev/null +++ b/pina/data/manager.py @@ -0,0 +1,15 @@ +"""Module for condition data management.""" + +__all__ = [ + "_BatchManager", + "_DataManagerInterface", + "_DataManager", + "_TensorDataManager", + "_GraphDataManager", +] + +from pina._src.data.manager.batch_manager import _BatchManager +from pina._src.data.manager.data_manager import _DataManager +from pina._src.data.manager.tensor_data_manager import _TensorDataManager +from pina._src.data.manager.graph_data_manager import _GraphDataManager +from pina._src.data.manager.data_manager_interface import _DataManagerInterface diff --git a/tests/test_condition/test_data_condition.py b/tests/test_condition/test_data_condition.py index 5aa6abaae..5676e9f63 100644 --- a/tests/test_condition/test_data_condition.py +++ b/tests/test_condition/test_data_condition.py @@ -2,11 +2,11 @@ import pytest from pina.graph import RadiusGraph, Graph from pina import LabelTensor, Condition -from pina.condition import ( - DataCondition, - _BatchManager, +from pina.condition import DataCondition +from pina.data.manager import ( _TensorDataManager, _GraphDataManager, + _BatchManager, ) diff --git a/tests/test_condition/test_input_equation_condition.py b/tests/test_condition/test_input_equation_condition.py index a6366e86a..1d3b8e08a 100644 --- a/tests/test_condition/test_input_equation_condition.py +++ b/tests/test_condition/test_input_equation_condition.py @@ -3,13 +3,14 @@ from pina.equation import Equation from pina import LabelTensor, Condition from pina.graph import RadiusGraph, Graph -from pina.condition import ( - InputEquationCondition, +from pina.condition import InputEquationCondition +from pina.data.manager import ( _TensorDataManager, _GraphDataManager, _BatchManager, ) + # Generate input and equation data for testing - tensor case input_tensor = LabelTensor(torch.rand((10, 2)), ["x", "y"]) equation_tensor = Equation(lambda pts: pts["x"] ** 2 + pts["y"] ** 2 - 1) diff --git a/tests/test_condition/test_input_target_condition.py b/tests/test_condition/test_input_target_condition.py index 8352ee1c3..903c21b70 100644 --- a/tests/test_condition/test_input_target_condition.py +++ b/tests/test_condition/test_input_target_condition.py @@ -2,11 +2,11 @@ import pytest from pina.graph import RadiusGraph, Graph from pina import LabelTensor, Condition -from pina.condition import ( - InputTargetCondition, - _BatchManager, +from pina.condition import InputTargetCondition +from pina.data.manager import ( _TensorDataManager, _GraphDataManager, + _BatchManager, ) diff --git a/tests/test_datamodule.py b/tests/test_data/test_data_module.py similarity index 100% rename from tests/test_datamodule.py rename to tests/test_data/test_data_module.py diff --git a/tests/test_data/test_graph_data_manager.py b/tests/test_data/test_graph_data_manager.py new file mode 100644 index 000000000..dd0bb47d8 --- /dev/null +++ b/tests/test_data/test_graph_data_manager.py @@ -0,0 +1,115 @@ +import torch +import pytest +from pina import LabelTensor +from pina.graph import Graph +from pina.data.manager import _DataManager, _GraphDataManager, _BatchManager + + +# Define data for testing 
+standard_graph = [
+    Graph(
+        x=torch.rand((10, 3)),
+        pos=torch.rand((10, 2)),
+        edge_index=torch.randint(0, 10, (2, 20)),
+    )
+    for _ in range(3)
+]
+label_graph = [
+    Graph(
+        x=LabelTensor(torch.rand((10, 3)), labels=["a", "b", "c"]),
+        pos=LabelTensor(torch.rand((10, 2)), labels=["x", "y"]),
+        edge_index=torch.randint(0, 10, (2, 20)),
+    )
+    for _ in range(3)
+]
+target_ = torch.rand((3, 10, 1))
+label_target = LabelTensor(target_, labels=["target"])
+
+
+@pytest.mark.parametrize("case", ["standard", "labeled"])
+def test_constructor(case):
+
+    # Define data for testing
+    if case == "standard":
+        graph = standard_graph
+        target = target_
+        exp_type = torch.Tensor
+    else:
+        graph = label_graph
+        target = label_target
+        exp_type = LabelTensor
+
+    # Create data manager
+    data_manager = _DataManager(graph=graph, target=target)
+
+    # Check that the data manager is an instance of _GraphDataManager
+    assert isinstance(data_manager, _GraphDataManager)
+
+    # Check that the attributes are set correctly
+    assert hasattr(data_manager, "graph_key")
+    assert hasattr(data_manager, "graph")
+    assert hasattr(data_manager, "target")
+    assert data_manager.graph_key == "graph"
+
+    # Check that the graph length is correct
+    assert len(data_manager.graph) == len(graph)
+
+    # Check that the attributes have the correct types
+    assert isinstance(data_manager.target, exp_type)
+    assert isinstance(data_manager.graph, list)
+    for g in data_manager.graph:
+        assert isinstance(g, Graph)
+
+    # Check that the values of the attributes are correct
+    assert torch.equal(data_manager.target, target)
+    for i in range(len(graph)):
+        assert torch.equal(data_manager.graph[i].x, graph[i].x)
+        assert torch.equal(data_manager.graph[i].pos, graph[i].pos)
+        assert torch.equal(
+            data_manager.graph[i].edge_index, graph[i].edge_index
+        )
+        assert torch.equal(data_manager.graph[i].target, target[i])
+
+
+@pytest.mark.parametrize("case", ["standard", "labeled"])
+def test_create_batch(case):
+
+    # Define data for testing
+    if case == "standard":
+        graph = standard_graph
+        target = target_
+        exp_type = torch.Tensor
+    else:
+        graph = label_graph
+        target = label_target
+        exp_type = LabelTensor
+
+    # Create data manager
+    data_manager = _DataManager(graph=graph, target=target)
+
+    # Batch over indices
+    idx = [0, 2]
+    batch = _GraphDataManager.create_batch([data_manager[i] for i in idx])
+
+    # Check that the batch is an instance of _BatchManager
+    assert isinstance(batch, _BatchManager)
+
+    # Check that the attributes are set correctly
+    assert hasattr(batch, "graph")
+    assert hasattr(batch, "target")
+
+    # Check that the number of graphs is correct
+    assert batch.graph.num_graphs == len(idx)
+
+    # Check that the attributes have the correct types
+    assert isinstance(batch.target, exp_type)
+    assert isinstance(batch.graph, Graph)
+
+    # Check that the values of the attributes are correct
+    assert torch.equal(batch.target, torch.cat([target[i] for i in idx], dim=0))
+    assert torch.equal(
+        batch.graph.x, torch.cat([graph[i].x for i in idx], dim=0)
+    )
+    assert torch.equal(
+        batch.graph.pos, torch.cat([graph[i].pos for i in idx], dim=0)
+    )
diff --git a/tests/test_data/test_tensor_data_manager.py b/tests/test_data/test_tensor_data_manager.py
new file mode 100644
index 000000000..7624e5971
--- /dev/null
+++ b/tests/test_data/test_tensor_data_manager.py
@@ -0,0 +1,54 @@
+import torch
+from pina import LabelTensor
+from pina.data.manager import _DataManager, _TensorDataManager, _BatchManager
+
+
+# Define data for testing 
+standard_tensor = torch.rand((10, 3))
+label_tensor = LabelTensor(standard_tensor, labels=["a", "b", "c"])
+
+
+def test_constructor():
+
+    # Create data manager
+    data_manager = _DataManager(standard=standard_tensor, labeled=label_tensor)
+
+    # Check that the data manager is an instance of _TensorDataManager
+    assert isinstance(data_manager, _TensorDataManager)
+
+    # Check that the attributes are set correctly
+    assert hasattr(data_manager, "standard")
+    assert hasattr(data_manager, "labeled")
+
+    # Check that the attributes have the correct types
+    assert isinstance(data_manager.standard, torch.Tensor)
+    assert isinstance(data_manager.labeled, LabelTensor)
+
+    # Check that the values of the attributes are correct
+    assert torch.equal(data_manager.standard, standard_tensor)
+    assert torch.equal(data_manager.labeled, label_tensor)
+
+
+def test_create_batch():
+
+    # Create data manager
+    data_manager = _DataManager(standard=standard_tensor, labeled=label_tensor)
+
+    # Batch over indices
+    idx = [0, 2]
+    batch = _TensorDataManager.create_batch([data_manager[i] for i in idx])
+
+    # Check that the batch is an instance of _BatchManager
+    assert isinstance(batch, _BatchManager)
+
+    # Check that the attributes are set correctly
+    assert hasattr(batch, "standard")
+    assert hasattr(batch, "labeled")
+
+    # Check that the attributes have the correct types
+    assert isinstance(batch.standard, torch.Tensor)
+    assert isinstance(batch.labeled, LabelTensor)
+
+    # Check that the values of the attributes are correct
+    assert torch.equal(batch.standard, standard_tensor[idx])
+    assert torch.equal(batch.labeled, label_tensor[idx])
diff --git a/tests/test_data_manager.py b/tests/test_data_manager.py
deleted file mode 100644
index 9bab62b57..000000000
--- a/tests/test_data_manager.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import torch
-from pina._src.condition.data_manager import (
-    _DataManager,
-    _TensorDataManager,
-    _GraphDataManager,
-)
-from pina.graph import Graph
-from pina.equation import Equation
-
-
-def test_tensor_data_manager_init():
-    pippo = torch.rand((10, 5))
-    pluto = torch.rand((10, 7))
-    paperino = torch.rand((10, 11))
-    data_manager = _DataManager(pippo=pippo, pluto=pluto, paperino=paperino)
-    assert isinstance(data_manager, _TensorDataManager)
-    assert hasattr(data_manager, "pippo")
-    assert hasattr(data_manager, "pluto")
-    assert hasattr(data_manager, "paperino")
-    assert torch.equal(data_manager.pippo, pippo)
-    assert torch.equal(data_manager.pluto, pluto)
-    assert torch.equal(data_manager.paperino, paperino)
-
-    paperino = Equation(lambda x: x**2)
-    data_manager3 = _DataManager(pippo=pippo, pluto=pluto, paperino=paperino)
-    assert isinstance(data_manager3, _TensorDataManager)
-    assert hasattr(data_manager3, "pippo")
-    assert hasattr(data_manager3, "pluto")
-    assert hasattr(data_manager3, "paperino")
-    assert torch.equal(data_manager3.pippo, pippo)
-    assert torch.equal(data_manager3.pluto, pluto)
-    assert isinstance(data_manager3.paperino, Equation)
-
-
-def test_graph_data_manager_init():
-    x = [torch.rand((10, 5)) for _ in range(3)]
-    pos = [torch.rand((10, 3)) for _ in range(3)]
-    edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)]
-    graph = [
-        Graph(x=x_, pos=pos_, edge_index=edge_index_)
-        for x_, pos_, edge_index_ in zip(x, pos, edge_index)
-    ]
-    target = torch.rand((3, 10, 1))
-    data_manager = _DataManager(graph=graph, target=target)
-    assert hasattr(data_manager, "graph_key")
-    assert data_manager.graph_key == "graph"
-    assert hasattr(data_manager, "graph")
-    assert 
len(data_manager.data) == 3 - for i in range(3): - g = data_manager.graph[i] - assert torch.equal(g.x, x[i]) - assert torch.equal(g.pos, pos[i]) - assert torch.equal(g.edge_index, edge_index[i]) - assert torch.equal(g.target, target[i]) - - -def test_graph_data_manager_getattribute(): - x = [torch.rand((10, 5)) for _ in range(3)] - pos = [torch.rand((10, 3)) for _ in range(3)] - edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] - graph = [ - Graph(x=x_, pos=pos_, edge_index=edge_index_) - for x_, pos_, edge_index_ in zip(x, pos, edge_index) - ] - target = torch.rand((3, 10, 1)) - data_manager = _DataManager(graph=graph, target=target) - target_retrieved = data_manager.target - assert torch.equal(target_retrieved, target) - - -def test_graph_data_manager_getitem(): - x = [torch.rand((10, 5)) for _ in range(3)] - pos = [torch.rand((10, 3)) for _ in range(3)] - edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] - graph = [ - Graph(x=x_, pos=pos_, edge_index=edge_index_) - for x_, pos_, edge_index_ in zip(x, pos, edge_index) - ] - target = torch.rand((3, 10, 1)) - data_manager = _DataManager(graph=graph, target=target) - item = data_manager[1] - assert isinstance(item, _DataManager) - assert hasattr(item, "graph_key") - assert item.graph_key == "graph" - assert hasattr(item, "graph") - assert torch.equal(item.graph.x, x[1]) - assert torch.equal(item.graph.pos, pos[1]) - assert torch.equal(item.graph.edge_index, edge_index[1]) - assert torch.equal(item.target, target[1].unsqueeze(0)) - - -def test_graph_data_create_batch(): - x = [torch.rand((10, 5)) for _ in range(3)] - pos = [torch.rand((10, 3)) for _ in range(3)] - edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] - graph = [ - Graph(x=x_, pos=pos_, edge_index=edge_index_) - for x_, pos_, edge_index_ in zip(x, pos, edge_index) - ] - target = torch.rand((3, 10, 1)) - data_manager = _DataManager(graph=graph, target=target) - item1 = data_manager[0] - item2 = data_manager[1] - batch_data = _GraphDataManager.create_batch([item1, item2]) - assert hasattr(batch_data, "graph") - assert hasattr(batch_data, "target") - batched_graphs = batch_data.graph - batched_target = batch_data.target - assert batched_graphs.num_graphs == 2 - assert batched_target.shape == (20, 1) - assert torch.equal(batched_target, torch.cat([target[0], target[1]], dim=0)) - mps_data = batch_data.to("mps") - assert mps_data.graph.num_graphs == 2 - assert torch.equal(mps_data.target, batched_target.to("mps")) - assert torch.equal(mps_data.graph.x, batched_graphs.x.to("mps")) - - -def test_tensor_data_create_batch(): - pippo = torch.rand((10, 5)) - pluto = torch.rand((10, 7)) - paperino = torch.rand((10, 11)) - data_manager = _DataManager(pippo=pippo, pluto=pluto, paperino=paperino) - item1 = data_manager[0] - item2 = data_manager[1] - batch_data = _TensorDataManager.create_batch([item1, item2]) - assert hasattr(batch_data, "pippo") - assert hasattr(batch_data, "pluto") - assert hasattr(batch_data, "paperino") - assert torch.equal( - batch_data.pippo, torch.stack([pippo[0], pippo[1]], dim=0) - ) - assert torch.equal( - batch_data.pluto, torch.stack([pluto[0], pluto[1]], dim=0) - ) - assert torch.equal( - batch_data.paperino, torch.stack([paperino[0], paperino[1]], dim=0) - )
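
A minimal usage sketch of the relocated managers, assuming the dispatch rule stated in the `_DataManager` docstring and the indexing/collation behavior exercised by the tests in this patch. The managers are private (underscore-prefixed) helpers rather than a supported public API; the keyword names and tensor shapes below are arbitrary illustrative choices, and the expected batch shape follows the stacking behavior shown in the deleted tests/test_data_manager.py.

import torch

from pina.graph import Graph
from pina.data.manager import (
    _BatchManager,
    _DataManager,
    _GraphDataManager,
    _TensorDataManager,
)

# All-tensor keyword arguments dispatch to _TensorDataManager. The keyword
# names are arbitrary, exactly as in the tests above.
manager = _DataManager(input=torch.rand((8, 2)), target=torch.rand((8, 1)))
assert isinstance(manager, _TensorDataManager)

# Indexed items are collated back into a _BatchManager by create_batch;
# plain tensors are stacked along a new leading batch dimension, so the
# batched input is expected to have shape (2, 2) here.
batch = _TensorDataManager.create_batch([manager[i] for i in (0, 2)])
assert isinstance(batch, _BatchManager)
assert batch.input.shape == (2, 2)

# A list of Graph objects makes the factory return _GraphDataManager instead.
graphs = [
    Graph(
        x=torch.rand((10, 3)),
        pos=torch.rand((10, 2)),
        edge_index=torch.randint(0, 10, (2, 20)),
    )
    for _ in range(3)
]
graph_manager = _DataManager(graph=graphs, target=torch.rand((3, 10, 1)))
assert isinstance(graph_manager, _GraphDataManager)

Keeping `_DataManager` as a `__new__`-based factory lets the condition classes stay agnostic to whether their data is tensor- or graph-shaped, which is why this patch can move the managers under pina/data/manager by changing only import paths at the condition call sites.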