medkit.training
===============

.. py:module:: medkit.training


Submodules
----------

.. toctree::
   :maxdepth: 1

   /reference/api/medkit/training/callbacks/index
   /reference/api/medkit/training/trainable_component/index
   /reference/api/medkit/training/trainer/index
   /reference/api/medkit/training/trainer_config/index
   /reference/api/medkit/training/utils/index


Classes
-------

.. autoapisummary::

   medkit.training.DefaultPrinterCallback
   medkit.training.TrainerCallback
   medkit.training.TrainableComponent
   medkit.training.Trainer
   medkit.training.TrainerConfig
   medkit.training.BatchData
   medkit.training.MetricsComputer


Package Contents
----------------

.. py:class:: DefaultPrinterCallback

   Bases: :py:obj:`TrainerCallback`


   
   Default implementation of :class:`~.training.TrainerCallback`.
















   ..
       !! processed by numpydoc !!

   .. py:method:: on_train_begin(config)

      
      Event called at the beginning of training.
















      ..
          !! processed by numpydoc !!


   .. py:method:: on_epoch_end(metrics, epoch, epoch_duration)

      
      Event called at the end of an epoch.
















      ..
          !! processed by numpydoc !!


   .. py:method:: on_train_end()

      
      Event called at the end of training.
















      ..
          !! processed by numpydoc !!


   .. py:method:: on_save(checkpoint_dir)

      
      Event called on saving a checkpoint.
















      ..
          !! processed by numpydoc !!


   .. py:method:: on_step_begin(step_idx: int, nb_batches: int, phase: str)

      
      Event called at the beginning of a step in training.
















      ..
          !! processed by numpydoc !!


   .. py:method:: on_step_end(step_idx: int, nb_batches: int, phase: str)

      
      Event called at the end of a step in training.
















      ..
          !! processed by numpydoc !!


.. py:class:: TrainerCallback

   
   A TrainerCallback is the base class for trainer callbacks.
















   ..
       !! processed by numpydoc !!

   .. py:method:: on_train_begin(config: medkit.training.trainer_config.TrainerConfig)

      
      Event called at the beginning of training.
















      ..
          !! processed by numpydoc !!


   .. py:method:: on_train_end()

      
      Event called at the end of training.
















      ..
          !! processed by numpydoc !!


   .. py:method:: on_epoch_begin(epoch: int)

      
      Event called at the beginning of an epoch.
















      ..
          !! processed by numpydoc !!


   .. py:method:: on_epoch_end(metrics: dict[str, float], epoch: int, epoch_time: float)

      
      Event called at the end of an epoch.
















      ..
          !! processed by numpydoc !!


   .. py:method:: on_step_begin(step_idx: int, nb_batches: int, phase: str)

      
      Event called at the beginning of a step in training.
















      ..
          !! processed by numpydoc !!


   .. py:method:: on_step_end(step_idx: int, nb_batches: int, phase: str)

      
      Event called at the end of a step in training.
















      ..
          !! processed by numpydoc !!


   .. py:method:: on_save(checkpoint_dir: str)

      
      Event called on saving a checkpoint.
















      ..
          !! processed by numpydoc !!
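
   A custom callback only needs to override the hooks it cares about. The following is a minimal sketch, not a reference implementation: `MetricsHistoryCallback` is a hypothetical name, the metric values are made up, and the fallback base class exists only so that the snippet also runs where medkit is not installed.

   .. code-block:: python

      try:
          from medkit.training import TrainerCallback
      except ImportError:
          # Fallback so the sketch runs without medkit; not part of the real API.
          class TrainerCallback:
              pass


      class MetricsHistoryCallback(TrainerCallback):
          """Hypothetical callback that records what the hooks receive."""

          def __init__(self):
              super().__init__()
              self.history = []      # one entry per finished epoch
              self.checkpoints = []  # checkpoint directories reported by on_save

          def on_epoch_end(self, metrics, epoch, epoch_time):
              # Store a copy so later changes by the caller do not alter the record.
              self.history.append(
                  {"epoch": epoch, "time": epoch_time, "metrics": dict(metrics)}
              )

          def on_save(self, checkpoint_dir):
              self.checkpoints.append(checkpoint_dir)


      callback = MetricsHistoryCallback()
      # Simulate the calls a trainer would make (values are illustrative).
      callback.on_epoch_end(metrics={"loss": 0.42}, epoch=1, epoch_time=3.5)
      callback.on_save(checkpoint_dir="checkpoints/epoch_1")
      print(callback.history)

   Such an object would then be passed to `Trainer` through its `callback` parameter.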


.. py:class:: TrainableComponent

   Bases: :py:obj:`typing_extensions.Protocol`


   
   TrainableComponent is the base protocol that a component must implement to be trainable in medkit.
















   ..
       !! processed by numpydoc !!

   .. py:property:: device
      :type: torch.device



   .. py:method:: configure_optimizer(lr: float) -> torch.optim.Optimizer

      
      Create optimizer using the learning rate.
















      ..
          !! processed by numpydoc !!


   .. py:method:: preprocess(data_item: Any) -> dict[str, Any]

      
      Run preprocessing on the input data.

      Preprocess the input data item and return a dictionary with
      everything needed for the forward pass.

      This method is intended to preprocess a single input; `self.collate` must be
      used to generate batches for `self.forward` to run properly.
      The preprocessed output should include `labels` so that a loss can be computed.















      ..
          !! processed by numpydoc !!


   .. py:method:: collate(batch: list[dict[str, Any]]) -> medkit.training.utils.BatchData

      
      Collate a list of data processed by `preprocess` to form a batch.
















      ..
          !! processed by numpydoc !!
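
      How `preprocess` and `collate` fit together can be sketched in plain Python. This is an illustration under assumptions: the `(text, label)` input format, the `tokens` key and the free functions are hypothetical, and a real component would return a `BatchData` (a `dict` subclass) from `collate` rather than a plain `dict`.

      .. code-block:: python

         from typing import Any


         def preprocess(data_item: tuple[str, int]) -> dict[str, Any]:
             """Turn one (text, label) pair into the fields needed by the forward pass."""
             text, label = data_item
             # `labels` is included so that a loss can be computed later.
             return {"tokens": text.split(), "labels": label}


         def collate(batch: list[dict[str, Any]]) -> dict[str, list[Any]]:
             """Regroup a list of per-item dicts (rows) into one dict of lists (columns)."""
             return {key: [item[key] for item in batch] for key in batch[0]}


         items = [("no fever", 0), ("severe chest pain", 1)]
         batch = collate([preprocess(item) for item in items])
         print(batch["labels"])  # [0, 1]
         print(batch["tokens"])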


   .. py:method:: forward(input_batch: medkit.training.utils.BatchData, return_loss: bool, eval_mode: bool) -> tuple[medkit.training.utils.BatchData, torch.Tensor | None]

      
      Perform the forward pass on a batch.

      Perform the forward pass on a batch and return the corresponding
      output as well as the loss if `return_loss` is True.

      Before running the forward pass, this method must set the model to training
      or evaluation mode depending on `eval_mode`. PyTorch models provide two
      methods to set the mode: `model.train()` and `model.eval()`.

      :Parameters:

          **input_batch** : BatchData
              Input batch

          **return_loss** : bool
              Whether to return the computed loss as well

          **eval_mode** : bool
              Whether to set the model to training (False) or evaluation mode (True)

      :Returns:

          **output** : BatchData
              Output after forward pass completion

          **loss** : torch.Tensor, optional
              Loss after forward pass completion, if `return_loss` was set to True.













      ..
          !! processed by numpydoc !!


   .. py:method:: save(path: str | pathlib.Path)

      
      Save model to disk.
















      ..
          !! processed by numpydoc !!


   .. py:method:: load(path: str | pathlib.Path)

      
      Load weights from disk.
















      ..
          !! processed by numpydoc !!


.. py:class:: Trainer(component: medkit.training.trainable_component.TrainableComponent, config: medkit.training.trainer_config.TrainerConfig, train_data: Any, eval_data: Any, metrics_computer: medkit.training.utils.MetricsComputer | None = None, lr_scheduler_builder: Callable[[torch.optim.Optimizer], Any] | None = None, callback: medkit.training.callbacks.TrainerCallback | None = None)

   
   Class facilitating training and evaluation of PyTorch models to generate annotations.


   :Parameters:

       **component:**
           The component to train. It must implement the `TrainableComponent` protocol.

       **config:**
           A `TrainerConfig` with the parameters for training. Its `output_dir` parameter defines the
           path where checkpoints are saved.

       **train_data:**
           The data to use for training. This should be a corpus of medkit objects. The data could be,
           for instance, a `torch.utils.data.Dataset` that returns medkit objects for training.

       **eval_data:**
           The data to use for evaluation during training (not for final testing). This should be a corpus
           of medkit objects. The data can be a `torch.utils.data.Dataset` that returns medkit objects for
           evaluation.

       **metrics_computer:**
           Optional `MetricsComputer` object that will be used to compute custom metrics during eval.
           By default, custom metrics are computed only during evaluation; set `do_metrics_in_training`
           in `config` to also compute them during training.

       **lr_scheduler_builder:**
           Optional function that builds a `lr_scheduler` to adjust the learning rate after an epoch. Must take
           an Optimizer and return a `lr_scheduler`. If not provided, the learning rate does not change during
           training.

       **callback:**
           Optional callback to customize training.














   ..
       !! processed by numpydoc !!

   .. py:method:: get_dataloader(data: dict, shuffle: bool) -> torch.utils.data.DataLoader

      
      Return a DataLoader with transformations defined in the component to train.


      :Parameters:

          **data** : dict
              Training data

          **shuffle** : bool
              Whether to use shuffled (True) or sequential (False) sampling

      :Returns:

          torch.utils.data.DataLoader
              The corresponding instance of a DataLoader













      ..
          !! processed by numpydoc !!


   .. py:method:: training_epoch() -> dict[str, float]

      
      Perform an epoch using the training data.

      When metrics in training are enabled in the config (`do_metrics_in_training` is True),
      the additional metrics are prepared per batch.


      :Returns:

          dict of str to float
              A dictionary containing the training metrics













      ..
          !! processed by numpydoc !!


   .. py:method:: evaluation_epoch(eval_dataloader) -> dict[str, float]

      
      Perform an epoch using the evaluation data.

      The additional metrics are prepared per batch.

      :Parameters:

          **eval_dataloader** : torch.utils.data.DataLoader
              The evaluation dataset as a PyTorch DataLoader

      :Returns:

          dict of str to float
              A dictionary containing the evaluation metrics













      ..
          !! processed by numpydoc !!


   .. py:method:: make_forward_pass(inputs: medkit.training.utils.BatchData, eval_mode: bool) -> tuple[medkit.training.utils.BatchData, torch.Tensor]

      
      Run the forward pass safely, on the same device as the component.
















      ..
          !! processed by numpydoc !!


   .. py:method:: update_learning_rate(eval_metrics: dict[str, float]) -> None

      
      Call the learning rate scheduler if defined.
















      ..
          !! processed by numpydoc !!


   .. py:method:: train() -> list[dict]

      
      Call the training and evaluation loop.



      :Returns:

          list of dict of str to float
              The list of computed metrics per epoch













      ..
          !! processed by numpydoc !!


   .. py:method:: save(epoch: int) -> str

      
      Save a checkpoint.

      Checkpoints include trainer configuration, model weights, optimizer and scheduler.

      :Parameters:

          **epoch** : int
              Epoch corresponding to the current training state
              (will be included in the checkpoint name)

      :Returns:

          str
              Path to the saved checkpoint













      ..
          !! processed by numpydoc !!


.. py:class:: TrainerConfig

   
   Trainer configuration.


   :Parameters:

       **output_dir:**
           The output directory where the checkpoint will be saved.

       **learning_rate:**
           The initial learning rate.

       **nb_training_epochs:**
           Total number of training/evaluation epochs to do.

       **dataloader_nb_workers:**
           Number of subprocesses to use for data loading. The default value is 0,
           meaning that the data will be loaded in the main process. If this config is for a
           HuggingFace model, do not change this value.

       **batch_size:**
           Number of samples per batch to load.

       **seed:**
           Random seed to use with PyTorch and numpy. It should be set to ensure
           reproducibility between experiments.

       **gradient_accumulation_steps:**
           Number of steps to accumulate gradient before performing an optimization step.

       **do_metrics_in_training:**
           By default, the custom metrics are computed only using `eval_data`. If set to
           True, the custom metrics are also computed using `train_data`.

       **metric_to_track_lr:**
           Name of the eval metric to be tracked for updating the learning rate.
           By default, eval `loss` is tracked.

       **checkpoint_period:**
           How often, in number of epochs, to save a checkpoint. Use 0 to
           save only the last checkpoint.

       **checkpoint_metric:**
           Name of the eval metric to be tracked for selecting the best checkpoint.
           By default, eval `loss` is tracked.

       **minimize_checkpoint_metric:**
           If `True`, the checkpoint with the lowest metric value will be selected
           as best, otherwise the checkpoint with the highest metric value.














   ..
       !! processed by numpydoc !!
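
   The effect of `checkpoint_metric` and `minimize_checkpoint_metric` can be illustrated with a small selection function. This is a sketch of the rule described above, not medkit's implementation: `select_best_epoch`, the 1-based epoch numbering, the tie-breaking behavior and the metric values are all assumptions made for the example.

   .. code-block:: python

      def select_best_epoch(
          eval_metrics_per_epoch,
          checkpoint_metric="loss",
          minimize_checkpoint_metric=True,
      ):
          """Return the 1-based epoch whose tracked eval metric is best."""
          values = [metrics[checkpoint_metric] for metrics in eval_metrics_per_epoch]
          best_value = min(values) if minimize_checkpoint_metric else max(values)
          # index() returns the first match, so ties go to the earliest epoch.
          return values.index(best_value) + 1


      history = [
          {"loss": 0.90, "accuracy": 0.61},
          {"loss": 0.55, "accuracy": 0.78},
          {"loss": 0.60, "accuracy": 0.80},
      ]
      print(select_best_epoch(history))  # 2
      print(select_best_epoch(history, "accuracy", minimize_checkpoint_metric=False))  # 3

   Tracking `accuracy` with `minimize_checkpoint_metric` left at `True` would pick the worst epoch, which is why the two settings should be changed together.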

   .. py:attribute:: output_dir
      :type:  str


   .. py:attribute:: learning_rate
      :type:  float
      :value: 1e-05



   .. py:attribute:: nb_training_epochs
      :type:  int
      :value: 3



   .. py:attribute:: dataloader_nb_workers
      :type:  int
      :value: 0



   .. py:attribute:: batch_size
      :type:  int
      :value: 1



   .. py:attribute:: seed
      :type:  int | None
      :value: None



   .. py:attribute:: gradient_accumulation_steps
      :type:  int
      :value: 1



   .. py:attribute:: do_metrics_in_training
      :type:  bool
      :value: False



   .. py:attribute:: metric_to_track_lr
      :type:  str
      :value: 'loss'



   .. py:attribute:: checkpoint_period
      :type:  int
      :value: 1



   .. py:attribute:: checkpoint_metric
      :type:  str
      :value: 'loss'



   .. py:attribute:: minimize_checkpoint_metric
      :type:  bool
      :value: True



   .. py:method:: to_dict() -> dict[str, Any]


.. py:class:: BatchData

   Bases: :py:obj:`dict`


   
   A BatchData packs data, allowing both column and row access.
















   ..
       !! processed by numpydoc !!

   .. py:method:: __getitem__(index: int) -> dict[str, list[Any] | torch.Tensor]

      
      x.__getitem__(y) <==> x[y]
















      ..
          !! processed by numpydoc !!
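
      The docstring above is the generic one inherited from `dict`. What "column and row access" could look like is sketched below with a stand-in class. `ToyBatch` is hypothetical and is not medkit's implementation; it only mirrors the documented signature, in which an `int` index returns one row.

      .. code-block:: python

         class ToyBatch(dict):
             """Hypothetical stand-in: a dict of columns that also supports row access."""

             def __getitem__(self, index):
                 if isinstance(index, int):
                     # Row access: pick element `index` from every column.
                     return {key: column[index] for key, column in self.items()}
                 # Column access: regular dict lookup by key.
                 return super().__getitem__(index)


         batch = ToyBatch(tokens=[["no", "fever"], ["chest", "pain"]], labels=[0, 1])
         print(batch["labels"])  # [0, 1]
         print(batch[1])         # {'tokens': ['chest', 'pain'], 'labels': 1}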


   .. py:method:: to_device(device: torch.device) -> typing_extensions.Self

      
      Ensure that Tensors in the BatchData object are on the specified `device`.


      :Parameters:

          **device:**
              A `torch.device` object representing the device on which tensors
              will be allocated.

      :Returns:

          BatchData
              A new object with the tensors on the proper device.













      ..
          !! processed by numpydoc !!


.. py:class:: MetricsComputer

   Bases: :py:obj:`typing_extensions.Protocol`


   
   A MetricsComputer is the base protocol to compute metrics in training.
















   ..
       !! processed by numpydoc !!

   .. py:method:: prepare_batch(model_output: BatchData, input_batch: BatchData) -> dict[str, list[Any]]

      
      Prepare a batch of data to compute the metrics.


      :Parameters:

          **model_output** : BatchData
              Output data after a model forward pass.

          **input_batch** : BatchData
              Preprocessed input batch

      :Returns:

          dict[str, list[Any]]
              A dictionary with the required data to calculate the metric













      ..
          !! processed by numpydoc !!


   .. py:method:: compute(all_data: dict[str, list[Any]]) -> dict[str, float]

      
      Compute metrics using 'all_data'.


      :Parameters:

          **all_data** : dict[str, list[Any]]
              A dictionary with the data needed to compute the metrics,
              e.g. a dictionary with a list of 'references' and a list of 'predictions'.

      :Returns:

          dict[str, float]
              A dictionary with the results













      ..
          !! processed by numpydoc !!
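
   Because `MetricsComputer` is a protocol, any object providing these two methods can be used. Below is a minimal sketch of an accuracy computer. The `logits` and `labels` keys and the class name are assumptions for illustration, plain lists stand in for tensors, and the accumulation of per-batch results is simulated by hand rather than taken from medkit.

   .. code-block:: python

      from typing import Any


      class AccuracyComputer:
          """Hypothetical metrics computer following the two-method protocol."""

          def prepare_batch(self, model_output: dict, input_batch: dict) -> dict[str, list[Any]]:
              # Predicted class = index of the highest score for each sample.
              predictions = [
                  max(range(len(scores)), key=scores.__getitem__)
                  for scores in model_output["logits"]
              ]
              return {"predictions": predictions, "references": list(input_batch["labels"])}

          def compute(self, all_data: dict[str, list[Any]]) -> dict[str, float]:
              pairs = list(zip(all_data["predictions"], all_data["references"]))
              correct = sum(prediction == reference for prediction, reference in pairs)
              return {"accuracy": correct / len(pairs)}


      computer = AccuracyComputer()
      batches = [
          ({"logits": [[0.1, 0.9], [0.8, 0.2]]}, {"labels": [1, 0]}),
          ({"logits": [[0.3, 0.7]]}, {"labels": [0]}),
      ]
      # Simulate the accumulation over batches that happens during an epoch.
      all_data = {"predictions": [], "references": []}
      for model_output, input_batch in batches:
          prepared = computer.prepare_batch(model_output, input_batch)
          for key, values in prepared.items():
              all_data[key].extend(values)
      result = computer.compute(all_data)
      print(result)

   Here two of the three predictions match their references, so the reported accuracy is two thirds.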


