medkit.training.trainer
=======================

.. py:module:: medkit.training.trainer


Classes
-------

.. autoapisummary::

   medkit.training.trainer.Trainer


Module Contents
---------------

.. py:class:: Trainer(component: medkit.training.trainable_component.TrainableComponent, config: medkit.training.trainer_config.TrainerConfig, train_data: Any, eval_data: Any, metrics_computer: medkit.training.utils.MetricsComputer | None = None, lr_scheduler_builder: Callable[[torch.optim.Optimizer], Any] | None = None, callback: medkit.training.callbacks.TrainerCallback | None = None)

   
   Class facilitating training and evaluation of PyTorch models to generate annotations.


   :Parameters:

       **component:**
           The component to train. It must implement the `TrainableComponent` protocol.

       **config:**
           A `TrainerConfig` with the parameters for training. Its `output_dir` parameter defines the
           path where checkpoints are saved.

       **train_data:**
           The data to use for training. This should be a corpus of medkit objects. The data could be,
           for instance, a `torch.utils.data.Dataset` that returns medkit objects for training.

       **eval_data:**
           The data to use for evaluation (not for testing). This should be a corpus of medkit objects.
           The data can be a `torch.utils.data.Dataset` that returns medkit objects for evaluation.

       **metrics_computer:**
           Optional `MetricsComputer` object used to compute custom metrics during evaluation.
           By default, these metrics are computed only during evaluation; setting `do_metrics_in_training`
           in `config` also enables them during training.

       **lr_scheduler_builder:**
           Optional function that builds a `lr_scheduler` to adjust the learning rate after an epoch. It must
           take an Optimizer and return a `lr_scheduler`. If not provided, the learning rate does not change
           during training.

       **callback:**
           Optional callback to customize training.
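
   A minimal sketch of an ``lr_scheduler_builder``, assuming an exponential decay is wanted. The name ``build_lr_scheduler`` and the value ``gamma=0.5`` are illustrative and not part of medkit; only the contract stated above (take an optimizer, return a scheduler) comes from this documentation.

   .. code-block:: python

      import torch


      def build_lr_scheduler(optimizer: torch.optim.Optimizer):
          """Return a scheduler that halves the learning rate at each step (illustrative)."""
          return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)


      # Stand-alone check with a dummy parameter; no medkit object is involved here.
      optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
      scheduler = build_lr_scheduler(optimizer)
      optimizer.step()  # PyTorch expects optimizer.step() before scheduler.step()
      scheduler.step()
      new_lr = optimizer.param_groups[0]["lr"]

   Such a function would then be passed as ``Trainer(..., lr_scheduler_builder=build_lr_scheduler)``.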














   ..
       !! processed by numpydoc !!

   .. py:attribute:: output_dir


   .. py:attribute:: component


   .. py:attribute:: batch_size


   .. py:attribute:: dataloader_drop_last
      :value: False



   .. py:attribute:: dataloader_nb_workers


   .. py:attribute:: dataloader_pin_memory
      :value: False



   .. py:attribute:: device


   .. py:attribute:: train_dataloader


   .. py:attribute:: eval_dataloader


   .. py:attribute:: nb_training_epochs


   .. py:attribute:: config


   .. py:attribute:: optimizer


   .. py:attribute:: lr_scheduler
      :value: None



   .. py:attribute:: metrics_computer
      :value: None



   .. py:attribute:: callback
      :value: None



   .. py:method:: get_dataloader(data: dict, shuffle: bool) -> torch.utils.data.DataLoader

      
      Return a DataLoader with transformations defined in the component to train.


      :Parameters:

          **data** : dict
              Training data

          **shuffle** : bool
              Whether to shuffle the data (True) or sample it sequentially (False)



      :Returns:

          torch.utils.data.DataLoader
              The corresponding instance of a DataLoader











      ..
          !! processed by numpydoc !!


   .. py:method:: training_epoch() -> dict[str, float]

      
      Perform an epoch using the training data.

      When metrics in training are enabled in the config (`do_metrics_in_training` is True),
      the additional metrics are prepared per batch.




      :Returns:

          dict of str to float
              A dictionary containing the training metrics
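
      Per-batch preparation can be pictured as collecting one set of values per batch and reducing them once the epoch ends. This is a minimal sketch in plain Python under that assumption, not medkit's implementation; ``aggregate_epoch_metrics`` and the metric name ``loss`` are hypothetical.

      .. code-block:: python

         from collections import defaultdict


         def aggregate_epoch_metrics(batch_metrics: list[dict[str, float]]) -> dict[str, float]:
             """Average each metric over the batches of one epoch."""
             values = defaultdict(list)
             for metrics in batch_metrics:
                 for name, value in metrics.items():
                     values[name].append(value)
             return {name: sum(vals) / len(vals) for name, vals in values.items()}


         # Made-up per-batch values, for illustration only.
         epoch_metrics = aggregate_epoch_metrics([{"loss": 1.0}, {"loss": 0.5}])

      A plain mean weights every batch equally, which is only exact when all batches have the same size.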











      ..
          !! processed by numpydoc !!


   .. py:method:: evaluation_epoch(eval_dataloader) -> dict[str, float]

      
      Perform an epoch using the evaluation data.

      The additional metrics are prepared per batch.

      :Parameters:

          **eval_dataloader** : torch.utils.data.DataLoader
              The evaluation dataset as a PyTorch DataLoader



      :Returns:

          dict of str to float
              A dictionary containing the evaluation metrics











      ..
          !! processed by numpydoc !!


   .. py:method:: make_forward_pass(inputs: medkit.training.utils.BatchData, eval_mode: bool) -> tuple[medkit.training.utils.BatchData, torch.Tensor]

      
      Run the forward pass safely, on the same device as the component.
















      ..
          !! processed by numpydoc !!


   .. py:method:: update_learning_rate(eval_metrics: dict[str, float]) -> None

      
      Call the learning rate scheduler if defined.
















      ..
          !! processed by numpydoc !!


   .. py:method:: train() -> list[dict]

      
      Call the training and evaluation loop.





      :Returns:

          list of dict of str to float
              The list of computed metrics per epoch
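
      A minimal sketch of how the returned list might be inspected to find the best epoch. The layout of each dict is not specified here, so the helper takes a ``get_value`` accessor; the helper name ``best_epoch``, the key ``eval_loss`` and the sample values are made up for illustration.

      .. code-block:: python

         from typing import Callable


         def best_epoch(
             history: list[dict],
             get_value: Callable[[dict], float],
             minimize: bool = True,
         ) -> int:
             """Return the 0-based index of the epoch with the best tracked value."""
             pick = min if minimize else max
             return pick(range(len(history)), key=lambda i: get_value(history[i]))


         # Hypothetical history; in practice this would be the result of trainer.train().
         history = [{"eval_loss": 0.9}, {"eval_loss": 0.4}, {"eval_loss": 0.6}]
         best = best_epoch(history, lambda metrics: metrics["eval_loss"])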











      ..
          !! processed by numpydoc !!


   .. py:method:: save(epoch: int) -> str

      
      Save a checkpoint.

      Checkpoints include trainer configuration, model weights, optimizer and scheduler.

      :Parameters:

          **epoch** : int
              Epoch corresponding to the current training state
              (will be included in the checkpoint name)



      :Returns:

          str
              Path to the saved checkpoint











      ..
          !! processed by numpydoc !!


