Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,24 @@ Conditions
.. toctree::
:titlesonly:

ConditionInterface <condition/condition_interface.rst>
Condition Interface <condition/condition_interface.rst>
Base Condition <condition/base_condition.rst>
Condition <condition/condition.rst>
DataCondition <condition/data_condition.rst>
DomainEquationCondition <condition/domain_equation_condition.rst>
InputEquationCondition <condition/input_equation_condition.rst>
InputTargetCondition <condition/input_target_condition.rst>
Data Condition <condition/data_condition.rst>
Domain Equation Condition <condition/domain_equation_condition.rst>
Input Equation Condition <condition/input_equation_condition.rst>
Input Target Condition <condition/input_target_condition.rst>

Batch and Data Managers
--------------------------
.. toctree::
:titlesonly:

Batch Manager <data/manager/batch_manager.rst>
Data Manager Interface <data/manager/data_manager_interface.rst>
Data Manager <data/manager/data_manager.rst>
Graph Data Manager <data/manager/graph_data_manager.rst>
Tensor Data Manager <data/manager/tensor_data_manager.rst>

Solvers
--------------
Expand All @@ -68,18 +80,18 @@ Solvers
SolverInterface <solver/solver_interface.rst>
SingleSolverInterface <solver/single_solver_interface.rst>
MultiSolverInterface <solver/multi_solver_interface.rst>
SupervisedSolverInterface <solver/supervised_solver/supervised_solver_interface>
DeepEnsembleSolverInterface <solver/ensemble_solver/ensemble_solver_interface>
SupervisedSolverInterface <solver/supervised_solver/supervised_solver_interface.rst>
DeepEnsembleSolverInterface <solver/ensemble_solver/ensemble_solver_interface.rst>
PINNInterface <solver/physics_informed_solver/pinn_interface.rst>
PINN <solver/physics_informed_solver/pinn.rst>
GradientPINN <solver/physics_informed_solver/gradient_pinn.rst>
CausalPINN <solver/physics_informed_solver/causal_pinn.rst>
CompetitivePINN <solver/physics_informed_solver/competitive_pinn.rst>
SelfAdaptivePINN <solver/physics_informed_solver/self_adaptive_pinn.rst>
RBAPINN <solver/physics_informed_solver/rba_pinn.rst>
DeepEnsemblePINN <solver/ensemble_solver/ensemble_pinn>
DeepEnsemblePINN <solver/ensemble_solver/ensemble_pinn.rst>
SupervisedSolver <solver/supervised_solver/supervised.rst>
DeepEnsembleSupervisedSolver <solver/ensemble_solver/ensemble_supervised>
DeepEnsembleSupervisedSolver <solver/ensemble_solver/ensemble_supervised.rst>
ReducedOrderModelSolver <solver/supervised_solver/reduced_order_model.rst>
GAROM <solver/garom.rst>
AutoregressiveSolverInterface <solver/autoregressive_solver/autoregressive_solver_interface.rst>
Expand Down Expand Up @@ -203,7 +215,7 @@ Equations and Differential Operators
Differential Operators <operator.rst>


Equations Zoo
Equation Zoo
---------------------------------------

.. toctree::
Expand Down Expand Up @@ -234,7 +246,7 @@ Problems
SpatialProblem <problem/spatial_problem.rst>
TimeDependentProblem <problem/time_dependent_problem.rst>

Problems Zoo
Problem Zoo
--------------

.. toctree::
Expand Down
9 changes: 9 additions & 0 deletions docs/source/_rst/condition/base_condition.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Base Condition
================
.. currentmodule:: pina.condition.base_condition

.. automodule:: pina._src.condition.base_condition

.. autoclass:: pina._src.condition.base_condition.BaseCondition
:members:
:show-inheritance:
6 changes: 4 additions & 2 deletions docs/source/_rst/condition/condition.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
Conditions
Condition
=============
.. currentmodule:: pina.condition.condition

.. automodule:: pina._src.condition.condition

.. autoclass:: pina._src.condition.condition.Condition
:members:
:show-inheritance:
:show-inheritance:
4 changes: 3 additions & 1 deletion docs/source/_rst/condition/condition_interface.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
ConditionInterface
Condition Interface
======================
.. currentmodule:: pina.condition.condition_interface

.. automodule:: pina._src.condition.condition_interface

.. autoclass:: pina._src.condition.condition_interface.ConditionInterface
:members:
:show-inheritance:
4 changes: 3 additions & 1 deletion docs/source/_rst/condition/data_condition.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
Data Conditions
Data Condition
==================
.. currentmodule:: pina.condition.data_condition

.. automodule:: pina._src.condition.data_condition

.. autoclass:: pina._src.condition.data_condition.DataCondition
:members:
:show-inheritance:
2 changes: 2 additions & 0 deletions docs/source/_rst/condition/domain_equation_condition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ Domain Equation Condition
===========================
.. currentmodule:: pina.condition.domain_equation_condition

.. automodule:: pina._src.condition.domain_equation_condition

.. autoclass:: pina._src.condition.domain_equation_condition.DomainEquationCondition
:members:
:show-inheritance:
2 changes: 2 additions & 0 deletions docs/source/_rst/condition/input_equation_condition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ Input Equation Condition
===========================
.. currentmodule:: pina.condition.input_equation_condition

.. automodule:: pina._src.condition.input_equation_condition

.. autoclass:: pina._src.condition.input_equation_condition.InputEquationCondition
:members:
:show-inheritance:
2 changes: 2 additions & 0 deletions docs/source/_rst/condition/input_target_condition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ Input Target Condition
===========================
.. currentmodule:: pina.condition.input_target_condition

.. automodule:: pina._src.condition.input_target_condition

.. autoclass:: pina._src.condition.input_target_condition.InputTargetCondition
:members:
:show-inheritance:
9 changes: 9 additions & 0 deletions docs/source/_rst/data/manager/batch_manager.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Batch Manager
======================
.. currentmodule:: pina.data.manager.batch_manager

.. automodule:: pina._src.data.manager.batch_manager

.. autoclass:: pina._src.data.manager.batch_manager._BatchManager
:members:
:show-inheritance:
9 changes: 9 additions & 0 deletions docs/source/_rst/data/manager/data_manager.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Data Manager
======================
.. currentmodule:: pina.data.manager.data_manager

.. automodule:: pina._src.data.manager.data_manager

.. autoclass:: pina._src.data.manager.data_manager._DataManager
:members:
:show-inheritance:
9 changes: 9 additions & 0 deletions docs/source/_rst/data/manager/data_manager_interface.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Data Manager Interface
=========================
.. currentmodule:: pina.data.manager.data_manager_interface

.. automodule:: pina._src.data.manager.data_manager_interface

.. autoclass:: pina._src.data.manager.data_manager_interface._DataManagerInterface
:members:
:show-inheritance:
9 changes: 9 additions & 0 deletions docs/source/_rst/data/manager/graph_data_manager.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Graph Data Manager
======================
.. currentmodule:: pina.data.manager.graph_data_manager

.. automodule:: pina._src.data.manager.graph_data_manager

.. autoclass:: pina._src.data.manager.graph_data_manager._GraphDataManager
:members:
:show-inheritance:
9 changes: 9 additions & 0 deletions docs/source/_rst/data/manager/tensor_data_manager.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Tensor Data Manager
======================
.. currentmodule:: pina.data.manager.tensor_data_manager

.. automodule:: pina._src.data.manager.tensor_data_manager

.. autoclass:: pina._src.data.manager.tensor_data_manager._TensorDataManager
:members:
:show-inheritance:
153 changes: 153 additions & 0 deletions pina/_src/condition/base_condition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""Module for the Base Condition class."""

from functools import partial
import torch
from torch_geometric.data import Batch
from torch.utils.data import DataLoader
from pina._src.condition.condition_interface import ConditionInterface
from pina._src.core.graph import LabelBatch
from pina._src.core.label_tensor import LabelTensor
from pina._src.core.utils import check_consistency
from pina._src.data.dummy_dataloader import DummyDataloader
from pina._src.problem.problem_interface import ProblemInterface


class BaseCondition(ConditionInterface):
"""
Base class for all conditions, implementing common functionality.

All specific condition types should inherit from this class and implement
the abstract methods of
:class:`~pina.condition.condition_interface.ConditionInterface`.

This class is not meant to be instantiated directly.
"""

# Available collate functions for automatic batching
collate_fn_dict = {
"tensor": torch.stack,
"label_tensor": LabelTensor.stack,
"graph": LabelBatch.from_data_list,
"data": Batch.from_data_list,
}

def __init__(self, **kwargs):
"""
Initialization of the :class:`BaseCondition` class.

:param dict kwargs: The keyword arguments representing the data to be
stored in the condition.
"""
super().__init__()
self.data = self.store_data(**kwargs)
self.has_custom_dataloader_fn = False

def __len__(self):
"""
Return the number of data points in the condition.

:return: The number of data points.
:rtype: int
"""
return len(self.data)

def __getitem__(self, idx):
"""
Return the data point at the specified index.

:param int idx: The index of the data point to retrieve.
:return: The data point at the specified index.
:rtype: Any
"""
return self.data[idx]

def create_dataloader(
self, dataset, batch_size, automatic_batching, **kwargs
):
"""
Create the DataLoader for the condition.

:param Dataset dataset: The dataset for the DataLoader.
:param int batch_size: The batch size for the DataLoader.
:param bool automatic_batching: Whether to use automatic batching.
:param dict kwargs: Additional keyword arguments for the DataLoader.
:return: The DataLoader for the condition.
:rtype: torch.utils.data.DataLoader
"""
# If batching the entire dataset, return a DummyDataloader
if batch_size == len(dataset):
return DummyDataloader(dataset)

# Otherwise, return a regular DataLoader with the appropriate collate
return DataLoader(
dataset=dataset,
collate_fn=(
partial(self.collate_fn, condition=self)
if not automatic_batching
else self.automatic_batching_collate_fn
),
batch_size=batch_size,
**kwargs,
)

def switch_dataloader_fn(self, create_dataloader_fn):
"""
Switch the dataloader function for the condition.

:param Callable create_dataloader_fn: The new dataloader function to use
for the condition.
:return: The new dataloader function for the condition.
:rtype: Callable
"""
self.has_custom_dataloader_fn = True
self.create_dataloader = create_dataloader_fn

@classmethod
def automatic_batching_collate_fn(cls, batch):
"""
Collate function for automatic batching to be used in the DataLoader.

:param list batch: A list of items from the dataset.
:return: A collated batch.
:rtype: dict
"""
# If the batch is empty, return an empty dictionary
if not batch:
return {}

# Otherwise, collate the batch using the appropriate collate function
instance_class = batch[0].__class__
return instance_class.create_batch(batch)

@staticmethod
def collate_fn(batch, condition):
"""
Collate function for custom batching to be used in the DataLoader.

:param list batch: A list of items from the dataset.
:param BaseCondition condition: The condition instance.
:return: A collated batch.
:rtype: dict
"""
return condition.data[batch].to_batch()

@property
def problem(self):
"""
The problem associated with this condition.

:return: The problem associated with this condition.
:rtype: BaseProblem
"""
return self._problem

@problem.setter
def problem(self, value):
"""
Set the problem associated with this condition.

:param BaseProblem value: The problem to associate with this condition.
:raises ValueError: If the problem is not an instance of BaseProblem.
"""
check_consistency(value, ProblemInterface)
self._problem = value
Loading
Loading