:py:mod:`medkit.training.trainable_component`
=============================================

.. py:module:: medkit.training.trainable_component


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   medkit.training.trainable_component.TrainableComponent




.. py:class:: TrainableComponent


   Bases: :py:obj:`typing_extensions.Protocol`

   
   Base protocol that a component must implement to be trainable in medkit.

   ..
       !! processed by numpydoc !!
   .. py:property:: device
      :type: torch.device


   .. py:method:: configure_optimizer(lr: float) -> torch.optim.Optimizer

      
      Create the optimizer for the model, using the learning rate `lr`.

      ..
          !! processed by numpydoc !!

   .. py:method:: preprocess(data_item: Any) -> dict[str, Any]

      
      Run preprocessing on the input data.

      Preprocess the input data item and return a dictionary with
      everything needed for the forward pass.

      This method preprocesses a single input item; `self.collate` must then be
      used to assemble batches so that `self.forward` can run properly.
      The returned dictionary should include `labels` so that a loss can be computed.

      ..
          !! processed by numpydoc !!

   .. py:method:: collate(batch: list[dict[str, Any]]) -> medkit.training.utils.BatchData

      
      Collate a list of data processed by `preprocess` to form a batch.
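
      As a rough illustration of how `preprocess` and `collate` fit together,
      here is a minimal pure-Python sketch. The item layout (`text` and `label`
      keys), the whitespace tokenization, and the use of a plain `dict` in place
      of `BatchData` are all assumptions made for the example, not part of medkit.

      ```python
      from typing import Any


      def preprocess(data_item: dict[str, Any]) -> dict[str, Any]:
          # Hypothetical preprocessing: lowercase and split the text,
          # and carry the label along so that a loss can be computed later.
          tokens = data_item["text"].lower().split()
          return {"tokens": tokens, "labels": data_item["label"]}


      def collate(batch: list[dict[str, Any]]) -> dict[str, list[Any]]:
          # Turn a list of per-item dicts into a single dict of lists.
          # A plain dict stands in for BatchData here (assumption).
          return {key: [item[key] for item in batch] for key in batch[0]}


      items = [
          {"text": "No fever", "label": 0},
          {"text": "Severe chest pain", "label": 1},
      ]
      batch = collate([preprocess(item) for item in items])
      ```

      Each key produced by `preprocess` becomes one entry of the batch, holding
      the values of all items in order.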

      ..
          !! processed by numpydoc !!

   .. py:method:: forward(input_batch: medkit.training.utils.BatchData, return_loss: bool, eval_mode: bool) -> tuple[medkit.training.utils.BatchData, torch.Tensor | None]

      
      Perform the forward pass on a batch.

      Perform the forward pass on a batch and return the corresponding
      output, along with the loss if `return_loss` is True.

      Before running the model, this method must set it to training
      or evaluation mode depending on `eval_mode`. PyTorch models provide two
      methods for this: `model.train()` and `model.eval()`.

      :Parameters:

          **input_batch** : BatchData
              Input batch

          **return_loss** : bool
              Whether to return the computed loss as well

          **eval_mode** : bool
              Whether to set the model to training (False) or evaluation mode (True)

      :Returns:

          **output** : BatchData
              Output after forward pass completion

          **loss** : torch.Tensor, optional
              Loss after forward pass completion, if `return_loss` was set to True.

      ..
          !! processed by numpydoc !!

   .. py:method:: save(path: str | pathlib.Path)

      
      Save model to disk.

      ..
          !! processed by numpydoc !!

   .. py:method:: load(path: str | pathlib.Path)

      
      Load weights from disk.

      ..
          !! processed by numpydoc !!
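
The protocol documented above can be sketched as a small class. This is only an illustration under stated assumptions: `TinyClassifier`, its `(features, label)` item format, the `inputs`/`logits` keys, and the use of a plain `dict` in place of `BatchData` are hypothetical, and the training step at the end is a hand-written loop, not medkit's own trainer.

```python
from pathlib import Path
from typing import Any, Optional

import torch


class TinyClassifier:
    """Hypothetical component exposing the TrainableComponent methods."""

    def __init__(self, nb_features: int = 4, nb_classes: int = 2):
        # The protocol declares `device` as a property; a plain attribute
        # is used here for brevity.
        self.device = torch.device("cpu")
        self._model = torch.nn.Linear(nb_features, nb_classes).to(self.device)

    def configure_optimizer(self, lr: float) -> torch.optim.Optimizer:
        return torch.optim.SGD(self._model.parameters(), lr=lr)

    def preprocess(self, data_item: Any) -> dict[str, Any]:
        # Assumption: each item is a (features, label) pair.
        features, label = data_item
        return {
            "inputs": torch.tensor(features, dtype=torch.float32),
            "labels": torch.tensor(label, dtype=torch.long),
        }

    def collate(self, batch: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        # A plain dict stands in for BatchData (assumption).
        return {
            "inputs": torch.stack([item["inputs"] for item in batch]),
            "labels": torch.stack([item["labels"] for item in batch]),
        }

    def forward(
        self, input_batch: dict[str, torch.Tensor], return_loss: bool, eval_mode: bool
    ) -> tuple[dict[str, torch.Tensor], Optional[torch.Tensor]]:
        # Set the mode before running the model, as the protocol requires.
        if eval_mode:
            self._model.eval()
        else:
            self._model.train()
        logits = self._model(input_batch["inputs"].to(self.device))
        loss = None
        if return_loss:
            labels = input_batch["labels"].to(self.device)
            loss = torch.nn.functional.cross_entropy(logits, labels)
        return {"logits": logits}, loss

    def save(self, path):
        torch.save(self._model.state_dict(), Path(path))

    def load(self, path):
        self._model.load_state_dict(torch.load(Path(path)))


# One manual training step on two toy items.
component = TinyClassifier()
optimizer = component.configure_optimizer(lr=0.1)
items = [([0.0, 1.0, 0.0, 1.0], 0), ([1.0, 0.0, 1.0, 0.0], 1)]
batch = component.collate([component.preprocess(item) for item in items])
output, loss = component.forward(batch, return_loss=True, eval_mode=False)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Here `preprocess` handles one item at a time, `collate` stacks the results into tensors with a leading batch dimension, and `forward` switches between `model.train()` and `model.eval()` before computing the output and, optionally, the loss.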


