# Source code for medkit.training.utils

from __future__ import annotations

__all__ = ["BatchData", "MetricsComputer"]

from typing import Any, Dict, List, Union
from typing_extensions import Protocol, runtime_checkable

import torch


class BatchData(dict):
    """A BatchData pack data allowing both column and row access

    Indexing with a string key returns the stored column as-is. Indexing with
    anything else (e.g. an integer position) applies that index to every
    column and returns the results as a plain dictionary keyed by column name.
    """

    def __getitem__(
        self, index: Union[int, slice, str]
    ) -> Union[Dict[str, Any], List[Any], torch.Tensor]:
        """Return a column (string key) or a row (any other index)

        Parameters
        ----------
        index:
            A string selects the column stored under that key. Any other
            value is forwarded to each column's own indexing operation.

        Returns
        -------
        Union[Dict[str, Any], List[Any], torch.Tensor]
            The stored column when `index` is a string; otherwise a plain
            dictionary mapping each key to `column[index]`.

        Raises
        ------
        KeyError
            If `index` is a string that is not a key of this object.
        """
        if isinstance(index, str):
            # Direct dict lookup: same result and same KeyError as indexing a
            # copy of the mapping, without rebuilding it on every access.
            return super().__getitem__(index)
        return {key: values[index] for key, values in self.items()}

    def to_device(self, device: torch.device) -> BatchData:
        """
        Ensure that Tensors in the BatchData object are on the specified `device`

        Parameters
        ----------
        device:
            A `torch.device` object representing the device on which tensors
            will be allocated.

        Returns
        -------
        BatchData
            A new object with the tensors on the proper device. Values that
            are not tensors are carried over unchanged (same objects).
        """
        moved = BatchData()
        for key, value in self.items():
            # Only tensors know how to move; everything else is kept as-is.
            moved[key] = value.to(device) if isinstance(value, torch.Tensor) else value
        return moved
@runtime_checkable
class MetricsComputer(Protocol):
    """A MetricsComputer is the base protocol to compute metrics in training"""

    def prepare_batch(
        self, model_output: BatchData, input_batch: BatchData
    ) -> Dict[str, List[Any]]:
        """Prepare a batch of data to compute the metrics

        Parameters
        ----------
        model_output: BatchData
            Output data after a model forward pass.
        input_batch: BatchData
            Preprocessed input batch

        Returns
        -------
        Dict[str, List[Any]]
            A dictionary with the required data to calculate the metric
        """
        # Protocol stub: implementers provide the actual behaviour.
        ...

    def compute(self, all_data: Dict[str, List[Any]]) -> Dict[str, float]:
        """Compute metrics using 'all_data'

        Parameters
        ----------
        all_data: Dict[str, List[Any]]
            A dictionary to compute the metrics. i.e. A dictionary with a list
            of 'references' and a list of 'predictions'.

        Returns
        -------
        Dict[str, float]
            A dictionary with the results
        """
        # Protocol stub: implementers provide the actual behaviour.
        ...