medkit.training.utils
=====================

.. py:module:: medkit.training.utils


Classes
-------

.. autoapisummary::

   medkit.training.utils.BatchData
   medkit.training.utils.MetricsComputer


Module Contents
---------------

.. py:class:: BatchData

   Bases: :py:obj:`dict`


   
   A ``BatchData`` object packs data so that it supports both column access (by key) and row access (by integer index).

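   As an illustration, here is a hypothetical stand-in, ``ToyBatchData``, rather than medkit's own class. It assumes that a string key returns a whole column, as in a regular ``dict``, and that an integer returns one row built from the element at that position in every column.

   .. code-block:: python

      from typing import Any


      class ToyBatchData(dict):
          """Hypothetical stand-in for BatchData: a dict of equally long columns."""

          def __getitem__(self, index: Any) -> Any:
              # String keys: ordinary dict lookup (column access).
              if isinstance(index, str):
                  return super().__getitem__(index)
              # Integer indices: one value per column at that position (row access).
              return {key: values[index] for key, values in self.items()}


      batch = ToyBatchData(text=["fever", "cough"], label=[1, 0])
      print(batch["label"])  # [1, 0]
      print(batch[1])        # {'text': 'cough', 'label': 0}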

   .. py:method:: __getitem__(index: int) -> dict[str, list[Any] | torch.Tensor]

      
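      Return the row at position ``index``: a dictionary that maps each key to the element at ``index`` in the corresponding column.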



   .. py:method:: to_device(device: torch.device) -> typing_extensions.Self

      
      Return a new ``BatchData`` object whose tensors are on the specified ``device``.


      :Parameters:

          **device: torch.device**
              The device on which the tensors will be allocated.



      :Returns:

          BatchData
              A new object whose tensors are on ``device``.

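      A minimal sketch of this behaviour as a standalone function. The name ``move_tensors_to_device`` is hypothetical, and the sketch assumes that non-tensor values are carried over unchanged. It only relies on ``torch.Tensor.to`` and uses the CPU device, so it runs without a GPU.

      .. code-block:: python

         import torch


         def move_tensors_to_device(batch: dict, device: torch.device) -> dict:
             """Hypothetical helper: return a new dict whose tensors are on ``device``."""
             moved = type(batch)()  # fresh object of the same type; ``batch`` is not modified
             for key, value in batch.items():
                 # Only tensors are moved; other values (e.g. lists of strings) are reused as-is.
                 moved[key] = value.to(device) if isinstance(value, torch.Tensor) else value
             return moved


         batch = {"input_ids": torch.tensor([[1, 2, 3]]), "text": ["fever"]}
         on_cpu = move_tensors_to_device(batch, torch.device("cpu"))
         print(on_cpu["input_ids"].device)  # cpu

      On a machine with a GPU, ``torch.device("cuda")`` would be passed instead.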


.. py:class:: MetricsComputer

   Bases: :py:obj:`typing_extensions.Protocol`


   
   A ``MetricsComputer`` is the base protocol for computing metrics during training.

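   As an illustration only, here is a hypothetical ``AccuracyComputer`` that follows the protocol structurally: it defines both methods and needs no inheritance. To stay self-contained it uses plain dicts of lists instead of ``BatchData`` and tensors, and it assumes the model output holds per-class scores under a ``"logits"`` key and the input batch holds gold labels under a ``"labels"`` key. The class name and both keys are made up for this example.

   .. code-block:: python

      from typing import Any


      class AccuracyComputer:
          """Hypothetical metrics computer following the MetricsComputer protocol."""

          def prepare_batch(
              self, model_output: dict[str, Any], input_batch: dict[str, Any]
          ) -> dict[str, list[Any]]:
              # Turn per-class scores into predicted class indices (argmax per row).
              predictions = [
                  max(range(len(scores)), key=scores.__getitem__)
                  for scores in model_output["logits"]
              ]
              return {"predictions": predictions, "references": list(input_batch["labels"])}

          def compute(self, all_data: dict[str, list[Any]]) -> dict[str, float]:
              # Fraction of positions where the prediction equals the reference.
              pairs = zip(all_data["predictions"], all_data["references"])
              n_correct = sum(pred == ref for pred, ref in pairs)
              return {"accuracy": n_correct / len(all_data["references"])}


      computer = AccuracyComputer()
      prepared = computer.prepare_batch(
          model_output={"logits": [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]},
          input_batch={"labels": [1, 0, 0]},
      )
      print(prepared)                    # {'predictions': [1, 0, 1], 'references': [1, 0, 0]}
      print(computer.compute(prepared))  # {'accuracy': 0.6666666666666666}

   Because ``BatchData`` is a ``dict`` subclass, annotating the parameters as ``dict`` keeps these methods compatible with the protocol's signatures.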

   .. py:method:: prepare_batch(model_output: BatchData, input_batch: BatchData) -> dict[str, list[Any]]

      
      Prepare a batch of data to compute the metrics.


      :Parameters:

          **model_output: BatchData**
              Output data after a model forward pass.

          **input_batch: BatchData**
              Preprocessed input batch corresponding to ``model_output``.



      :Returns:

          dict[str, list[Any]]
              A dictionary with the data required to compute the metrics.



   .. py:method:: compute(all_data: dict[str, list[Any]]) -> dict[str, float]

      
      Compute metrics from ``all_data``.


      :Parameters:

          **all_data: dict[str, list[Any]]**
              A dictionary with the data needed to compute the metrics, for
              example a list of ``references`` and a list of ``predictions``.



      :Returns:

          dict[str, float]
              A dictionary mapping each metric name to its value.

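      Because ``compute`` takes all the data in a single call, a caller presumably gathers the per-batch dictionaries returned by ``prepare_batch`` before calling it. The sketch below shows one way to concatenate them key by key; ``merge_prepared_batches`` is a hypothetical name and the ``references``/``predictions`` values are made up.

      .. code-block:: python

         from collections import defaultdict
         from typing import Any


         def merge_prepared_batches(batches: list[dict[str, list[Any]]]) -> dict[str, list[Any]]:
             """Hypothetical helper: concatenate per-batch lists into one ``all_data`` dict."""
             all_data: defaultdict[str, list[Any]] = defaultdict(list)
             for prepared in batches:
                 # Append this batch's values to the running list for each key.
                 for key, values in prepared.items():
                     all_data[key].extend(values)
             return dict(all_data)


         all_data = merge_prepared_batches([
             {"references": [1, 0], "predictions": [1, 1]},
             {"references": [0], "predictions": [0]},
         ])
         print(all_data)  # {'references': [1, 0, 0], 'predictions': [1, 1, 0]}

      Computing metrics once on the merged lists, instead of averaging per-batch scores, avoids skewed results when batches differ in size and is necessary for metrics such as F1 that do not decompose over batches.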
