16 changes: 9 additions & 7 deletions docs/source/_rst/_code.rst
@@ -174,10 +174,10 @@ Optimizers and Schedulers
.. toctree::
:titlesonly:

Optimizer <optim/optimizer_interface.rst>
Scheduler <optim/scheduler_interface.rst>
TorchOptimizer <optim/torch_optimizer.rst>
TorchScheduler <optim/torch_scheduler.rst>
Optimizer Interface <optim/optimizer_interface.rst>
Scheduler Interface <optim/scheduler_interface.rst>
Torch Optimizer <optim/torch_optimizer.rst>
Torch Scheduler <optim/torch_scheduler.rst>


Adaptive Functions
@@ -297,11 +297,13 @@ Callbacks

Switch Optimizer <callback/optim/switch_optimizer.rst>
Switch Scheduler <callback/optim/switch_scheduler.rst>
Normalizer Data <callback/processing/normalizer_data_callback.rst>
PINA Progress Bar <callback/processing/pina_progress_bar.rst>
Metric Tracker <callback/processing/metric_tracker.rst>
Refinement Interface <callback/refinement/refinement_interface.rst>
Base Refinement <callback/refinement/base_refinement.rst>
R3 Refinement <callback/refinement/r3_refinement.rst>
Data Normalizer <callback/processing/data_normalizer.rst>
Metric Tracker <callback/processing/metric_tracker.rst>
PINA Progress Bar <callback/processing/pina_progress_bar.rst>


Losses
---------
9 changes: 9 additions & 0 deletions docs/source/_rst/callback/processing/data_normalizer.rst
@@ -0,0 +1,9 @@
Data Normalizer
=======================
.. currentmodule:: pina.callback.processing.data_normalizer

.. automodule:: pina._src.callback.processing.data_normalizer

.. autoclass:: pina._src.callback.processing.data_normalizer.DataNormalizer
:members:
:show-inheritance:
8 changes: 5 additions & 3 deletions docs/source/_rst/callback/processing/metric_tracker.rst
@@ -1,8 +1,10 @@
Metric Tracker
==================
.. currentmodule:: pina.callback.processing.metric_tracker

.. automodule:: pina._src.callback.processing.metric_tracker
:show-inheritance:
.. autoclass:: MetricTracker

.. autoclass:: pina._src.callback.processing.metric_tracker.MetricTracker
:members:
:show-inheritance:
:show-inheritance:
:noindex:

This file was deleted.

7 changes: 4 additions & 3 deletions docs/source/_rst/callback/processing/pina_progress_bar.rst
@@ -1,8 +1,9 @@
PINA Progress Bar
==================
.. currentmodule:: pina.callback.processing.pina_progress_bar

.. automodule:: pina._src.callback.processing.pina_progress_bar
:show-inheritance:
.. autoclass:: PINAProgressBar

.. autoclass:: pina._src.callback.processing.pina_progress_bar.PINAProgressBar
:members:
:show-inheritance:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/callback/refinement/base_refinement.rst
@@ -0,0 +1,7 @@
Base Refinement
=======================

.. currentmodule:: pina.callback.refinement.base_refinement
.. autoclass:: pina._src.callback.refinement.base_refinement.BaseRefinement
:members:
:show-inheritance:
4 changes: 2 additions & 2 deletions docs/source/_rst/callback/refinement/r3_refinement.rst
@@ -1,7 +1,7 @@
Refinments callbacks
R3 Refinement
=======================

.. currentmodule:: pina.callback
.. currentmodule:: pina.callback.refinement.r3_refinement
.. autoclass:: pina._src.callback.refinement.r3_refinement.R3Refinement
:members:
:show-inheritance:
@@ -1,7 +1,7 @@
Refinement Interface
=======================

.. currentmodule:: pina.callback
.. currentmodule:: pina.callback.refinement.refinement_interface
.. autoclass:: pina._src.callback.refinement.refinement_interface.RefinementInterface
:members:
:show-inheritance:
6 changes: 3 additions & 3 deletions docs/source/_rst/optim/optimizer_interface.rst
@@ -1,7 +1,7 @@
Optimizer
============
Optimizer Interface
=====================
.. currentmodule:: pina.optim.optimizer_interface

.. autoclass:: pina._src.optim.optimizer_interface.Optimizer
.. autoclass:: pina._src.optim.optimizer_interface.OptimizerInterface
:members:
:show-inheritance:
6 changes: 3 additions & 3 deletions docs/source/_rst/optim/scheduler_interface.rst
@@ -1,7 +1,7 @@
Scheduler
=============
Scheduler Interface
=====================
.. currentmodule:: pina.optim.scheduler_interface

.. autoclass:: pina._src.optim.scheduler_interface.Scheduler
.. autoclass:: pina._src.optim.scheduler_interface.SchedulerInterface
:members:
:show-inheritance:
2 changes: 1 addition & 1 deletion docs/source/_rst/optim/torch_optimizer.rst
@@ -1,4 +1,4 @@
TorchOptimizer
Torch Optimizer
===============
.. currentmodule:: pina.optim.torch_optimizer

2 changes: 1 addition & 1 deletion docs/source/_rst/optim/torch_scheduler.rst
@@ -1,4 +1,4 @@
TorchScheduler
Torch Scheduler
===============
.. currentmodule:: pina.optim.torch_scheduler

42 changes: 23 additions & 19 deletions pina/_src/callback/optim/switch_optimizer.py
@@ -1,27 +1,36 @@
"""Module for the SwitchOptimizer callback."""

from lightning.pytorch.callbacks import Callback
from pina._src.optim.torch_optimizer import TorchOptimizer
from pina._src.core.utils import check_consistency
from pina._src.optim.optimizer_interface import OptimizerInterface
from pina._src.core.utils import check_consistency, check_positive_integer


class SwitchOptimizer(Callback):
"""
PINA Implementation of a Lightning Callback to switch optimizer during
training.
Lightning callback for dynamically replacing optimizers during training.

This callback enables switching to one or more new optimizers at a specified
epoch without restarting the training loop. It is particularly useful for
staged optimization strategies (e.g., coarse-to-fine training or optimizer
warm-up phases), where different optimizers are applied sequentially.

At the target epoch, the provided optimizers are hooked to the model
parameters and replace the current optimizers in both the PINA solver and
the Lightning trainer strategy.
"""

def __init__(self, new_optimizers, epoch_switch):
"""
This callback allows switching between different optimizers during
training, enabling the exploration of multiple optimization strategies
without interrupting the training process.
Initialization of the :class:`SwitchOptimizer` class.

:param new_optimizers: The optimizers to switch to. Can be a
single :class:`pina.optim.OptimizerInterface` instance or a list
of them for multi-model solvers.
:type new_optimizers: pina.optim.TorchOptimizer | list
:type new_optimizers: pina.optim.OptimizerInterface | list
:param int epoch_switch: The epoch at which the optimizer switch occurs.
:raises AssertionError: If ``epoch_switch`` is not a positive integer.
:raises ValueError: If any of the provided optimizers are not instances
of :class:`pina.optim.OptimizerInterface`.

Example:
>>> optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01)
@@ -31,19 +40,14 @@ def __init__(self, new_optimizers, epoch_switch):
"""
super().__init__()

# Check if epoch_switch is greater than 1
if epoch_switch < 1:
raise ValueError("epoch_switch must be greater than one.")
# Check consistency
check_positive_integer(epoch_switch, strict=True)
check_consistency(new_optimizers, OptimizerInterface)

# If new_optimizers is not a list, convert it to a list
if not isinstance(new_optimizers, list):
new_optimizers = [new_optimizers]

# Check consistency
check_consistency(epoch_switch, int)
for optimizer in new_optimizers:
check_consistency(optimizer, TorchOptimizer)

# Store the new optimizers and epoch switch
self._new_optimizers = new_optimizers
self._epoch_switch = epoch_switch
Expand All @@ -52,9 +56,9 @@ def on_train_epoch_start(self, trainer, __):
"""
Switch the optimizer at the start of the specified training epoch.

:param lightning.pytorch.Trainer trainer: The trainer object managing
the training process.
:param _: Placeholder argument (not used).
:param Trainer trainer: The trainer object managing the training
process.
:param __: Placeholder argument, not used.
"""
# Check if the current epoch matches the switch epoch
if trainer.current_epoch == self._epoch_switch:
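The epoch-check-and-replace logic that `SwitchOptimizer` implements can be sketched in isolation. `StubOptimizer`, `StubTrainer`, and `EpochSwitch` below are hypothetical stand-ins for PINA's optimizer wrappers, `lightning.pytorch.Trainer`, and the real callback; the sketch reproduces only the validation and the switch at the target epoch, not the hooking of optimizers to model parameters.

```python
# Hypothetical minimal sketch of the SwitchOptimizer pattern; names are
# stand-ins, not PINA's actual classes.

class StubOptimizer:
    def __init__(self, name):
        self.name = name


class StubTrainer:
    def __init__(self):
        self.current_epoch = 0
        self.optimizers = [StubOptimizer("Adam")]


class EpochSwitch:
    def __init__(self, new_optimizers, epoch_switch):
        # Mirror the callback's validation: a strictly positive integer epoch.
        if not isinstance(epoch_switch, int) or epoch_switch <= 0:
            raise ValueError("epoch_switch must be a positive integer")
        # Accept a single optimizer or a list of them.
        if not isinstance(new_optimizers, list):
            new_optimizers = [new_optimizers]
        self._new_optimizers = new_optimizers
        self._epoch_switch = epoch_switch

    def on_train_epoch_start(self, trainer):
        # Replace the trainer's optimizers only at the target epoch.
        if trainer.current_epoch == self._epoch_switch:
            trainer.optimizers = list(self._new_optimizers)


trainer = StubTrainer()
callback = EpochSwitch(StubOptimizer("LBFGS"), epoch_switch=3)
for epoch in range(5):
    trainer.current_epoch = epoch
    callback.on_train_epoch_start(trainer)

print(trainer.optimizers[0].name)  # LBFGS
```

From epoch 3 onward the trainer runs with the replacement optimizer, which is the staged-optimization use case (e.g. Adam warm-up followed by LBFGS) described in the docstring.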
34 changes: 16 additions & 18 deletions pina/_src/callback/optim/switch_scheduler.py
@@ -1,30 +1,31 @@
"""Module for the SwitchScheduler callback."""

from lightning.pytorch.callbacks import Callback
from pina._src.optim.torch_scheduler import TorchScheduler
from pina._src.optim.scheduler_interface import SchedulerInterface
from pina._src.core.utils import check_consistency, check_positive_integer


class SwitchScheduler(Callback):
"""
Callback to switch scheduler during training.
Lightning callback for dynamically replacing schedulers during training.

This callback enables switching to new scheduler(s) at a specified epoch
without interrupting the training loop. It is useful for staged training
strategies where different learning rate policies are applied sequentially.
"""

def __init__(self, new_schedulers, epoch_switch):
"""
This callback allows switching between different schedulers during
training, enabling the exploration of multiple optimization strategies
without interrupting the training process.
Initialization of the :class:`SwitchScheduler` class.

:param new_schedulers: The scheduler or list of schedulers to switch to.
Use a single scheduler for single-model solvers, or a list of
schedulers when working with multiple models.
:type new_schedulers: pina.optim.TorchScheduler |
list[pina.optim.TorchScheduler]
:type new_schedulers: SchedulerInterface | list[SchedulerInterface]
:param int epoch_switch: The epoch at which the scheduler switch occurs.
:raises AssertionError: If epoch_switch is less than 1.
:raises ValueError: If each scheduler in ``new_schedulers`` is not an
instance of :class:`pina.optim.TorchScheduler`.
:raises AssertionError: If ``epoch_switch`` is not a positive integer.
:raises ValueError: If any of the provided schedulers are not instances
of :class:`pina.optim.SchedulerInterface`.

Example:
>>> scheduler = TorchScheduler(
@@ -36,17 +37,14 @@ def __init__(self, new_schedulers, epoch_switch):
"""
super().__init__()

# Check if epoch_switch is greater than 1
check_positive_integer(epoch_switch - 1, strict=True)
# Check consistency
check_positive_integer(epoch_switch, strict=True)
check_consistency(new_schedulers, SchedulerInterface)

# If new_schedulers is not a list, convert it to a list
if not isinstance(new_schedulers, list):
new_schedulers = [new_schedulers]

# Check consistency
for scheduler in new_schedulers:
check_consistency(scheduler, TorchScheduler)

# Store the new schedulers and epoch switch
self._new_schedulers = new_schedulers
self._epoch_switch = epoch_switch
Expand All @@ -55,9 +53,9 @@ def on_train_epoch_start(self, trainer, __):
"""
Switch the scheduler at the start of the specified training epoch.

:param lightning.pytorch.Trainer trainer: The trainer object managing
:param Trainer trainer: The trainer object managing
the training process.
:param __: Placeholder argument (not used).
:param __: Placeholder argument, not used.
"""
# Check if the current epoch matches the switch epoch
if trainer.current_epoch == self._epoch_switch:
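The change from `check_positive_integer(epoch_switch - 1, strict=True)` to `check_positive_integer(epoch_switch, strict=True)` alters which values pass validation. A sketch of the difference, assuming `strict=True` means strictly greater than zero (this is a hypothetical re-implementation for illustration; the real `pina._src.core.utils.check_positive_integer` may differ):

```python
def check_positive_integer(value, strict=True):
    # Assumed semantics: strict=True requires value > 0, else value >= 0.
    # Hypothetical stand-in for pina._src.core.utils.check_positive_integer.
    assert isinstance(value, int) and not isinstance(value, bool), (
        f"{value!r} is not an integer"
    )
    assert (value > 0 if strict else value >= 0), (
        f"{value!r} fails the positivity check"
    )


def accepts(epoch_switch, offset):
    # offset=-1 reproduces the old call site; offset=0 the fixed one.
    try:
        check_positive_integer(epoch_switch + offset, strict=True)
        return True
    except AssertionError:
        return False


# (old_result, new_result) for each candidate epoch_switch value.
results = {n: (accepts(n, -1), accepts(n, 0)) for n in (1, 2)}
print(results)  # {1: (False, True), 2: (True, True)}
```

Under these assumed semantics, the old call rejected `epoch_switch=1` (switching at the first epoch) while the fixed call accepts any positive integer, matching the updated `:raises AssertionError:` documentation.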