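"""Training utilities (``medkit.training.utils``).

Provides ``BatchData``, a batch container with column and row access, and
``MetricsComputer``, the protocol for computing metrics during training.
"""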
from __future__ import annotations
__all__ = ["BatchData", "MetricsComputer"]
from typing import Any, runtime_checkable
import torch
from typing_extensions import Protocol, Self
class BatchData(dict):
    """A dict of columns that supports both column and row access.

    Each key maps to a column of values (e.g. a list or a ``torch.Tensor``).
    Indexing with a ``str`` returns that column; any other index is applied
    to every column, so ``batch[i]`` returns a plain dict holding row ``i``.
    """

    def __getitem__(
        self, index: int | str
    ) -> dict[str, list[Any] | torch.Tensor] | list[Any] | torch.Tensor:
        """Return a column (``str`` key) or a row (any other index).

        Parameters
        ----------
        index:
            A column name, or a position that is applied to every column.

        Returns
        -------
        list | torch.Tensor | dict
            The column stored under ``index`` when ``index`` is a ``str``;
            otherwise a new plain dict mapping each key to ``values[index]``.

        Raises
        ------
        KeyError
            If ``index`` is a ``str`` that is not a column name.
            For row access, errors raised by a column's own indexing
            (e.g. ``IndexError``) propagate unchanged.
        """
        if isinstance(index, str):
            # Delegate to dict's own lookup: no copy of the mapping and no
            # recursion back into this override.
            return super().__getitem__(index)
        return {key: values[index] for key, values in self.items()}

    def to_device(self, device: torch.device | str) -> Self:
        """Return a copy whose tensors live on ``device``.

        Parameters
        ----------
        device:
            The device tensors are moved to; anything accepted by
            ``torch.Tensor.to`` (a ``torch.device`` or a string such as
            ``"cpu"``).

        Returns
        -------
        BatchData
            A new object with every ``torch.Tensor`` value moved to
            ``device``. Non-tensor values are shared with this batch,
            not copied.
        """
        return BatchData(
            {
                key: value.to(device) if isinstance(value, torch.Tensor) else value
                for key, value in self.items()
            }
        )
@runtime_checkable
class MetricsComputer(Protocol):
    """Base protocol for objects that compute metrics during training.

    An implementation provides two steps: ``prepare_batch`` extracts, from
    one batch, the data the metrics need, and ``compute`` turns such data
    (``all_data``) into named metric values.
    """

    def prepare_batch(
        self,
        model_output: BatchData,
        input_batch: BatchData,
    ) -> dict[str, list[Any]]:
        """Prepare one batch of data for metric computation.

        Parameters
        ----------
        model_output : BatchData
            Data produced by a forward pass of the model.
        input_batch : BatchData
            The preprocessed input batch.

        Returns
        -------
        dict[str, list[Any]]
            The data required to compute the metrics.
        """
        ...

    def compute(self, all_data: dict[str, list[Any]]) -> dict[str, float]:
        """Compute the metrics from ``all_data``.

        Parameters
        ----------
        all_data : dict[str, list[Any]]
            The data the metrics are computed from, e.g. a dictionary with
            a list of 'references' and a list of 'predictions'.

        Returns
        -------
        dict[str, float]
            The metric results, keyed by metric name.
        """
        ...