:py:mod:`medkit.training.utils`
===============================

.. py:module:: medkit.training.utils


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   medkit.training.utils.BatchData
   medkit.training.utils.MetricsComputer




.. py:class:: BatchData


   Bases: :py:obj:`dict`

   
   A BatchData packs data and allows both column access (by key) and row access (by integer index).
















   .. py:method:: __getitem__(index: int) -> dict[str, list[Any] | torch.Tensor]

      
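      Return the row at position `index` as a dictionary mapping each key to its value for that row. Equivalent to ``batch[index]``.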

















   .. py:method:: to_device(device: torch.device) -> typing_extensions.Self

      
      Ensure that Tensors in the BatchData object are on the specified `device`.


      :Parameters:

          **device: torch.device**
              The device on which the tensors will be placed.

      :Returns:

          BatchData
              A new BatchData object with its tensors on the specified device.













      ..
          !! processed by numpydoc !!
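
The row and column access described for ``BatchData`` can be illustrated with a minimal stand-in. ``BatchSketch`` is a hypothetical name and is not the medkit class. The sketch assumes that a string key returns a whole column and an integer index returns one row, and it uses plain lists instead of tensors so that it runs without torch.

```python
from typing import Any


class BatchSketch(dict):
    """Hypothetical stand-in for BatchData: a dict of equally sized columns."""

    def __getitem__(self, index: Any) -> Any:
        # A string key behaves like a normal dict lookup: column access.
        if isinstance(index, str):
            return super().__getitem__(index)
        # An integer selects the same position in every column: row access.
        return {key: values[index] for key, values in self.items()}


batch = BatchSketch(input_ids=[[1, 2], [3, 4]], labels=[0, 1])
column = batch["labels"]  # the whole "labels" column
row = batch[1]            # the second item, one value per key
```

Keeping one container for both views lets collation code fill columns while per-item code reads rows, without converting between two structures.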

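The behaviour documented for ``to_device`` (return a new object whose tensors are on the requested device) can be sketched roughly as follows. To stay runnable without torch, ``FakeTensor`` stands in for ``torch.Tensor`` and a string stands in for ``torch.device``; these names and ``to_device_sketch`` are assumptions made for illustration. How medkit treats non-tensor values is not stated above, so this sketch simply passes them through.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor with a device attribute."""

    data: tuple
    device: str = "cpu"

    def to(self, device: str) -> "FakeTensor":
        # Return a tensor on the requested device, leaving self untouched.
        return replace(self, device=device)


def to_device_sketch(batch: dict[str, Any], device: str) -> dict[str, Any]:
    """Return a new dict whose tensor values are on `device`."""
    moved = {}
    for key, value in batch.items():
        # Only tensors are moved; other values are passed through unchanged.
        moved[key] = value.to(device) if isinstance(value, FakeTensor) else value
    return moved


batch = {"input_ids": FakeTensor((1, 2, 3)), "texts": ["a", "b", "c"]}
on_gpu = to_device_sketch(batch, "cuda:0")
```

Because a new mapping is returned, the original batch remains usable on its original device.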

.. py:class:: MetricsComputer


   Bases: :py:obj:`typing_extensions.Protocol`

   
   A MetricsComputer is the base protocol for computing metrics during training.
















   .. py:method:: prepare_batch(model_output: BatchData, input_batch: BatchData) -> dict[str, list[Any]]

      
      Prepare a batch of data to compute the metrics.


      :Parameters:

          **model_output: BatchData**
              Output data after a model forward pass.

          **input_batch: BatchData**
              Preprocessed input batch.

      :Returns:

          dict[str, list[Any]]
              A dictionary with the data required to compute the metrics.














   .. py:method:: compute(all_data: dict[str, list[Any]]) -> dict[str, float]

      
      Compute metrics using `all_data`.


      :Parameters:

          **all_data: dict[str, list[Any]]**
              A dictionary with the data needed to compute the metrics,
              e.g. a list of 'references' and a list of 'predictions'.

      :Returns:

          dict[str, float]
              A dictionary mapping each metric name to its value.













      ..
          !! processed by numpydoc !!
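
As an illustration of the protocol, the sketch below shows a hypothetical accuracy computer. The class name ``AccuracySketch``, the ``logits`` and ``labels`` keys, and the accumulation loop are assumptions made for this example, not medkit API. Plain dicts of lists stand in for ``BatchData`` so that the snippet runs without torch. It assumes that the values returned by ``prepare_batch`` for each batch are concatenated per key before ``compute`` is called once.

```python
from collections import defaultdict
from typing import Any


class AccuracySketch:
    """Hypothetical class with the two methods MetricsComputer requires."""

    def prepare_batch(self, model_output: dict, input_batch: dict) -> dict[str, list[Any]]:
        # Turn per-item scores into predicted class indices (argmax).
        predictions = [scores.index(max(scores)) for scores in model_output["logits"]]
        return {"predictions": predictions, "references": list(input_batch["labels"])}

    def compute(self, all_data: dict[str, list[Any]]) -> dict[str, float]:
        pairs = list(zip(all_data["predictions"], all_data["references"]))
        correct = sum(pred == ref for pred, ref in pairs)
        return {"accuracy": correct / len(pairs) if pairs else 0.0}


# Two toy batches as (model_output, input_batch) pairs.
batches = [
    ({"logits": [[0.1, 0.9], [0.8, 0.2]]}, {"labels": [1, 0]}),
    ({"logits": [[0.3, 0.7], [0.6, 0.4]]}, {"labels": [0, 0]}),
]

computer = AccuracySketch()
all_data: dict[str, list[Any]] = defaultdict(list)
for model_output, input_batch in batches:
    # Assumed accumulation: extend each key's list with this batch's values.
    for key, values in computer.prepare_batch(model_output, input_batch).items():
        all_data[key].extend(values)

metrics = computer.compute(dict(all_data))
```

Splitting the work this way keeps ``prepare_batch`` cheap enough to run on every batch, while ``compute`` sees the full set of predictions and references at once.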


