from __future__ import annotations
__all__ = ["TrainerCallback", "DefaultPrinterCallback"]
import logging
from typing import Dict
from tqdm import tqdm
from medkit.training.trainer_config import TrainerConfig
class TrainerCallback:
    """Base class for trainer callbacks.

    Each hook below is a no-op that returns ``None``. Subclasses override only
    the events they want to react to; the remaining hooks stay silent.
    """

    def on_train_begin(self, config: TrainerConfig):
        """Hook invoked at the beginning of training.

        Args:
            config: the trainer configuration used for this run.
        """

    def on_train_end(self):
        """Hook invoked at the end of training."""

    def on_epoch_begin(self, epoch: int):
        """Hook invoked at the beginning of an epoch.

        Args:
            epoch: index of the epoch that is starting.
        """

    def on_epoch_end(self, metrics: Dict[str, float], epoch: int, epoch_time: float):
        """Hook invoked at the end of an epoch.

        Args:
            metrics: metrics collected for the epoch.
            epoch: index of the epoch that just ended.
            epoch_time: time spent on the epoch.
        """
        # NOTE(review): DefaultPrinterCallback in this module names the third
        # parameter ``epoch_duration``; a caller passing it by keyword will only
        # match one of the two signatures -- confirm which name the trainer uses.

    def on_step_begin(self, step_idx: int, nb_batches: int, phase: str):
        """Hook invoked at the beginning of a training step.

        Args:
            step_idx: index of the step (batch) within the current phase.
            nb_batches: total number of batches in the current phase.
            phase: label of the current phase.
        """

    def on_step_end(self, step_idx: int, nb_batches: int, phase: str):
        """Hook invoked at the end of a training step.

        Args:
            step_idx: index of the step (batch) within the current phase.
            nb_batches: total number of batches in the current phase.
            phase: label of the current phase.
        """

    def on_save(self, checkpoint_dir: str):
        """Hook invoked when a checkpoint is saved.

        Args:
            checkpoint_dir: directory the checkpoint is written to.
        """
class DefaultPrinterCallback(TrainerCallback):
    """Default implementation of :class:`~.training.TrainerCallback`

    Logs the training configuration, per-epoch metrics, checkpoint saves and
    the end of training through a logger named after this class, and displays
    a ``tqdm`` progress bar over the batches of each phase.
    """

    def __init__(self):
        # The logger is looked up by name, so it is shared by every instance.
        self.logger = logging.getLogger(__class__.__name__)
        self.logger.setLevel(logging.INFO)

        # define handler and formatter
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        console_handler.setFormatter(formatter)

        # Ensure a single handler for the shared logger, otherwise each new
        # instance would add one more handler and duplicate every message.
        # Iterate over a copy: removeHandler() mutates logger.handlers, and
        # removing from a list while iterating it skips elements.
        for handler in list(self.logger.handlers):
            self.logger.removeHandler(handler)
        self.logger.addHandler(console_handler)

        # Progress bar of the phase currently running, None between phases.
        self._progress_bar = None

    def on_train_begin(self, config):
        """Log the main training settings read from ``config``.

        Args:
            config: trainer configuration; ``nb_training_epochs``,
                ``batch_size`` and ``gradient_accumulation_steps`` are read.
        """
        message = (
            "Running training:\n"
            + f" Num epochs: {config.nb_training_epochs}\n"
            + f" Train batch size: {config.batch_size}\n"
            + f" Gradient accum steps: {config.gradient_accumulation_steps}\n"
        )
        self.logger.info(message)

    @staticmethod
    def _format_metrics(title, metrics):
        """Render a titled section with one ``name:value`` line per metric."""
        lines = "\n ".join(
            f"{metric_key}:{value:8.3f}" for metric_key, value in metrics.items()
        )
        return f"{title}:\n {lines}\n"

    def on_epoch_end(self, metrics, epoch, epoch_duration):
        """Log the epoch duration followed by the train and eval metrics.

        Args:
            metrics: mapping with optional ``"train"`` and ``"eval"`` entries,
                each a mapping of metric name to numeric value; a missing
                entry is simply not logged.
            epoch: index of the epoch that just ended.
            epoch_duration: duration of the epoch, in seconds.
        """
        # NOTE(review): the base class names this parameter ``epoch_time``;
        # keyword callers can only match one of the two names -- confirm which
        # one the trainer passes.
        message = f"Epoch {epoch} ended (duration: {epoch_duration:.2f}s)\n"
        train_metrics = metrics.get("train", None)
        if train_metrics is not None:
            message += self._format_metrics("Training metrics", train_metrics)
        eval_metrics = metrics.get("eval", None)
        if eval_metrics is not None:
            message += self._format_metrics("Evaluation metrics", eval_metrics)
        self.logger.info(message)

    def on_train_end(self):
        """Log that training has completed."""
        self.logger.info("Training is completed")

    def on_save(self, checkpoint_dir):
        """Log the directory a checkpoint is being saved to.

        Args:
            checkpoint_dir: destination directory of the checkpoint.
        """
        self.logger.info(f"Saving checkpoint in {checkpoint_dir}")

    def on_step_begin(self, step_idx: int, nb_batches: int, phase: str):
        """Open a progress bar labelled ``phase`` on the first step of a phase.

        Args:
            step_idx: index of the step within the phase; only 0 has an effect.
            nb_batches: number of batches in the phase (progress bar total).
            phase: label displayed in front of the progress bar.
        """
        if step_idx != 0:
            return
        # A bar still open here means the previous phase stopped before its
        # last step (e.g. interrupted by an exception). Close it instead of
        # asserting, so a new phase can start cleanly (asserts vanish under -O
        # and the stale bar would then leak).
        if self._progress_bar is not None:
            self._progress_bar.close()
        self._progress_bar = tqdm(total=nb_batches)
        self._progress_bar.set_description(phase)

    def on_step_end(self, step_idx: int, nb_batches: int, phase: str):
        """Advance the progress bar and close it after the phase's last step.

        Args:
            step_idx: index of the step within the phase.
            nb_batches: number of batches in the phase.
            phase: label of the current phase (unused here).
        """
        if self._progress_bar is None:
            # No bar is open (on_step_begin never saw step 0 for this phase).
            return
        self._progress_bar.update()
        if step_idx + 1 == nb_batches:
            self._progress_bar.close()
            self._progress_bar = None