# Source code for medkit.training.trainer

from __future__ import annotations

# Public API of this module: only the `Trainer` class is exported by `import *`
__all__ = ["Trainer"]

import datetime
import random
import shutil
import time
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

import numpy as np
import torch
import yaml
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset

from medkit.training.callbacks import DefaultPrinterCallback, TrainerCallback

if TYPE_CHECKING:
    from medkit.training.trainable_component import TrainableComponent
    from medkit.training.trainer_config import TrainerConfig
    from medkit.training.utils import BatchData, MetricsComputer

# Checkpoint constants: names of the files written by `Trainer.save` inside
# each checkpoint directory.
OPTIMIZER_NAME = "optimizer.pt"  # optimizer state dict
SCHEDULER_NAME = "scheduler.pt"  # lr scheduler state dict (only written when a scheduler is set)
CONFIG_NAME = "trainer_config.yml"  # trainer configuration dumped as YAML


def set_seed(seed: int = 0):
    """Seed the random number generators used during training.

    The same value is given, in order, to Python's `random` module, to NumPy's
    global generator and to PyTorch, so that repeated runs draw the same
    pseudo-random numbers.

    Parameters
    ----------
    seed : int
        Value used to seed every generator.
    """
    for seed_generator in (random.seed, np.random.seed, torch.manual_seed):
        seed_generator(seed)


class _TrainerDataset(Dataset):
    """Dataset wrapper applying the component's ``preprocess`` to each item on access.

    Items are fetched from the wrapped dataset by index and passed through
    ``component.preprocess`` before being returned. This class is inspired by
    the ``PipelineDataset`` class from the Hugging Face transformers library.
    """

    def __init__(self, dataset, component: TrainableComponent):
        # keep references only; preprocessing happens lazily in __getitem__
        self.component = component
        self.dataset = dataset

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return the preprocessed version of the i-th item of the wrapped dataset."""
        return self.component.preprocess(self.dataset[i])


class Trainer:
    """A trainer is a base training/eval loop for a TrainableComponent that uses
    PyTorch models to create medkit annotations.

    The trainer builds the train/eval dataloaders from the component's
    ``preprocess`` and ``collate`` methods, runs the epochs, optionally updates
    the learning rate after each epoch and saves checkpoints in the output
    directory given by the config.
    """

    def __init__(
        self,
        component: TrainableComponent,
        config: TrainerConfig,
        train_data: Any,
        eval_data: Any,
        metrics_computer: MetricsComputer | None = None,
        lr_scheduler_builder: Callable[[torch.optim.Optimizer], Any] | None = None,
        callback: TrainerCallback | None = None,
    ):
        """Parameters
        ----------
        component:
            The component to train, the component must implement the
            `TrainableComponent` protocol.
        config:
            A `TrainerConfig` with the parameters for training, the parameter
            `output_dir` defines the path of the checkpoints.
        train_data:
            The data to use for training. This should be a corpus of medkit
            objects. The data could be, for instance, a
            `torch.utils.data.Dataset` that returns medkit objects for training.
        eval_data:
            The data to use for evaluation, this is not for testing. This should
            be a corpus of medkit objects. The data can be a
            `torch.utils.data.Dataset` that returns medkit objects for
            evaluation.
        metrics_computer:
            Optional `MetricsComputer` object that will be used to compute
            custom metrics during eval. By default, only evaluation metrics will
            be computed, `do_metrics_in_training` in `config` allows metrics in
            training.
        lr_scheduler_builder:
            Optional function that builds a `lr_scheduler` to adjust the
            learning rate after an epoch. Must take an Optimizer and return a
            `lr_scheduler`. If not provided, the learning rate does not change
            during training.
        callback:
            Optional callback to customize training. A `DefaultPrinterCallback`
            is used when not provided.
        """
        # enable deterministic operation
        if config.seed is not None:
            set_seed(config.seed)

        self.output_dir = Path(config.output_dir)
        # parents=True so that a nested output directory can be created in one go
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.component = component
        self.batch_size = config.batch_size
        self.dataloader_drop_last = False
        self.dataloader_nb_workers = config.dataloader_nb_workers
        self.dataloader_pin_memory = False

        self.device = self.component.device

        self.train_dataloader = self.get_dataloader(train_data, shuffle=True)
        self.eval_dataloader = self.get_dataloader(eval_data, shuffle=False)
        self.nb_training_epochs = config.nb_training_epochs

        self.config = config

        self.optimizer = component.configure_optimizer(self.config.learning_rate)
        self.lr_scheduler = None if lr_scheduler_builder is None else lr_scheduler_builder(self.optimizer)

        self.metrics_computer = metrics_computer

        if callback is None:
            callback = DefaultPrinterCallback()
        self.callback = callback

    def get_dataloader(self, data: Any, shuffle: bool) -> DataLoader:
        """Return a DataLoader with transformations defined in the component to train.

        Parameters
        ----------
        data:
            Indexable collection of items; each item is preprocessed with
            ``component.preprocess`` when accessed.
        shuffle:
            Whether the dataloader shuffles the data.

        Returns
        -------
        DataLoader
            Dataloader batching items with ``component.collate``.
        """
        dataset = _TrainerDataset(data, self.component)
        collate_fn = self.component.collate
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=self.dataloader_drop_last,
            num_workers=self.dataloader_nb_workers,
            pin_memory=self.dataloader_pin_memory,
        )

    def training_epoch(self) -> dict[str, float]:
        """Perform an epoch using the training data.

        When the config enabled metrics in training ('do_metrics_in_training'
        is True), the additional metrics are prepared per batch.

        Returns
        -------
        dict of str to float
            Metrics of the epoch. ``"loss"`` is the mean loss per batch; custom
            metrics are added when enabled and a metrics computer is set.
        """
        config = self.config
        nb_batches = len(self.train_dataloader)
        accumulation_steps = config.gradient_accumulation_steps
        compute_metrics = config.do_metrics_in_training and self.metrics_computer is not None

        total_loss_epoch = 0.0
        metrics = {}
        data_for_metrics = defaultdict(list)

        for step, input_batch in enumerate(self.train_dataloader):
            self.callback.on_step_begin(step, nb_batches=nb_batches, phase="train")

            model_output, loss = self.make_forward_pass(input_batch, eval_mode=False)

            # Record the loss before it is scaled for gradient accumulation,
            # otherwise the reported training loss would be divided by the
            # number of accumulation steps and not be comparable to the eval loss
            total_loss_epoch += loss.item()

            if accumulation_steps > 1:
                loss = loss / accumulation_steps

            loss.backward()

            # update the weights once enough gradients are accumulated, or at
            # the last batch so that no accumulated gradient is lost
            if (step + 1) % accumulation_steps == 0 or step + 1 == nb_batches:
                self.optimizer.step()
                self.optimizer.zero_grad()

            if compute_metrics:
                prepared_batch = self.metrics_computer.prepare_batch(model_output, input_batch)
                for key, values in prepared_batch.items():
                    data_for_metrics[key].extend(values)

            self.callback.on_step_end(step, nb_batches=nb_batches, phase="train")

        total_loss_epoch /= nb_batches
        metrics["loss"] = total_loss_epoch

        if compute_metrics:
            metrics.update(self.metrics_computer.compute(dict(data_for_metrics)))
        return metrics

    def evaluation_epoch(self, eval_dataloader) -> dict[str, float]:
        """Perform an epoch using the evaluation data.

        The additional metrics are prepared per batch.

        Parameters
        ----------
        eval_dataloader:
            Dataloader providing the batches to evaluate.

        Returns
        -------
        dict of str to float
            Metrics of the epoch. ``"loss"`` is the mean loss per batch; custom
            metrics are added when a metrics computer is set.
        """
        # Use the length of the dataloader actually iterated (and not of
        # `self.eval_dataloader`) so the mean loss is right for any dataloader
        nb_batches = len(eval_dataloader)

        total_loss_epoch = 0.0
        metrics = {}
        data_for_metrics = defaultdict(list)

        with torch.no_grad():
            for step, input_batch in enumerate(eval_dataloader):
                self.callback.on_step_begin(step, nb_batches=nb_batches, phase="eval")

                model_output, loss = self.make_forward_pass(input_batch, eval_mode=True)
                total_loss_epoch += loss.item()

                if self.metrics_computer is not None:
                    prepared_batch = self.metrics_computer.prepare_batch(model_output, input_batch)
                    for key, values in prepared_batch.items():
                        data_for_metrics[key].extend(values)

                self.callback.on_step_end(step, nb_batches=nb_batches, phase="eval")

        total_loss_epoch /= nb_batches
        metrics["loss"] = total_loss_epoch

        if self.metrics_computer is not None:
            metrics.update(self.metrics_computer.compute(dict(data_for_metrics)))
        return metrics

    def make_forward_pass(self, inputs: BatchData, eval_mode: bool) -> tuple[BatchData, torch.Tensor]:
        """Run forward safely, same device as the component.

        Parameters
        ----------
        inputs:
            Batch to give to the component, moved to the trainer device first.
        eval_mode:
            Forwarded as-is to ``component.forward``.

        Returns
        -------
        tuple
            The model output and the loss returned by the component.

        Raises
        ------
        ValueError
            If the component returned no loss.
        """
        inputs = inputs.to_device(self.device)
        model_output, loss = self.component.forward(inputs, return_loss=True, eval_mode=eval_mode)

        if loss is None:
            msg = "The component did not return a 'loss' from the input."
            raise ValueError(msg)

        return model_output, loss

    def update_learning_rate(self, eval_metrics: dict[str, float]):
        """Call the learning rate scheduler if defined.

        Parameters
        ----------
        eval_metrics:
            Evaluation metrics of the epoch. Only read when the scheduler is a
            `ReduceLROnPlateau`, which needs the value of the metric named by
            `metric_to_track_lr` in the config.

        Raises
        ------
        ValueError
            If the scheduler is a `ReduceLROnPlateau` and the tracked metric is
            missing from `eval_metrics`.
        """
        if self.lr_scheduler is None:
            return

        if isinstance(self.lr_scheduler, ReduceLROnPlateau):
            name_metric_to_track_lr = self.config.metric_to_track_lr
            eval_metric = eval_metrics.get(name_metric_to_track_lr)
            if eval_metric is None:
                msg = (
                    "Learning scheduler needs an eval metric to update the learning"
                    f" rate, '{name_metric_to_track_lr}' was not found"
                )
                raise ValueError(msg)
            self.lr_scheduler.step(eval_metric)
        else:
            self.lr_scheduler.step()

    def train(self) -> list[dict]:
        """Main training method. Call the training / eval loop.

        Returns
        -------
        list of dict
            One entry per epoch, holding the training metrics under ``"train"``
            and the evaluation metrics under ``"eval"``.

        Raises
        ------
        ValueError
            If a checkpoint was saved and the metric named by
            `checkpoint_metric` in the config is missing from the evaluation
            metrics.
        """
        self.callback.on_train_begin(config=self.config)

        log_history = []
        best_checkpoint_dir = None
        best_checkpoint_metric = None

        for epoch in range(1, self.nb_training_epochs + 1):
            epoch_start_time = time.time()
            self.callback.on_epoch_begin(epoch=epoch)

            train_metrics = self.training_epoch()
            eval_metrics = self.evaluation_epoch(self.eval_dataloader)

            self.update_learning_rate(eval_metrics)

            metrics = {"train": train_metrics, "eval": eval_metrics}
            log_history.append(metrics)

            self.callback.on_epoch_end(
                metrics=metrics,
                epoch=epoch,
                epoch_duration=time.time() - epoch_start_time,
            )

            # save checkpoint every N epochs if N != 0, or at last epoch
            checkpoint_period = self.config.checkpoint_period
            should_save = epoch == self.nb_training_epochs or (
                checkpoint_period != 0 and epoch % checkpoint_period == 0
            )
            if not should_save:
                continue

            # save last checkpoint
            last_checkpoint_dir = self.save(epoch)

            # track best checkpoint, and remove former best checkpoint if last
            # checkpoint is the new best
            last_checkpoint_metric = metrics["eval"].get(self.config.checkpoint_metric)
            if last_checkpoint_metric is None:
                msg = f"Checkpoint metric '{self.config.checkpoint_metric}' not found"
                raise ValueError(msg)

            if best_checkpoint_dir is None:
                best_checkpoint_dir = last_checkpoint_dir
                best_checkpoint_metric = last_checkpoint_metric
            elif (self.config.minimize_checkpoint_metric and last_checkpoint_metric < best_checkpoint_metric) or (
                not self.config.minimize_checkpoint_metric and last_checkpoint_metric > best_checkpoint_metric
            ):
                shutil.rmtree(best_checkpoint_dir)
                best_checkpoint_dir = last_checkpoint_dir
                best_checkpoint_metric = last_checkpoint_metric

        self.callback.on_train_end()

        return log_history

    def save(self, epoch: int) -> str:
        """Save a checkpoint (trainer configuration, model weights, optimizer and scheduler)

        Parameters
        ----------
        epoch : int
            Epoch corresponding of the current training state (will be included
            in the checkpoint name)

        Returns
        -------
        str
            Path of the checkpoint saved
        """
        # NOTE(review): the name contains ':', which is not a valid file name
        # character on Windows; kept as is to not change checkpoint naming
        current_date = datetime.datetime.now().strftime("%d-%m-%Y_%H:%M")
        name = f"checkpoint_{epoch:03d}_{current_date}"

        checkpoint_dir = Path(self.output_dir) / name
        self.callback.on_save(checkpoint_dir=str(checkpoint_dir))

        checkpoint_dir.mkdir()

        # save config; the file is opened explicitly as UTF-8 because the dump
        # allows non-ASCII characters, which the platform default encoding may
        # not be able to represent
        config_path = checkpoint_dir / CONFIG_NAME
        with config_path.open(mode="w", encoding="utf-8") as fp:
            yaml.safe_dump(
                self.config.to_dict(),
                fp,
                encoding="utf-8",
                allow_unicode=True,
                sort_keys=False,
            )

        torch.save(self.optimizer.state_dict(), checkpoint_dir / OPTIMIZER_NAME)

        if self.lr_scheduler is not None:
            torch.save(self.lr_scheduler.state_dict(), checkpoint_dir / SCHEDULER_NAME)

        self.component.save(checkpoint_dir)

        return str(checkpoint_dir)