Source code for medkit.audio.transcription.sb_transcriber_function

"""
This module needs extra dependencies that are not installed as core dependencies of medkit.
To install them, use `pip install medkit-lib[sb-transcriber_function]`.
"""

__all__ = ["SBTranscriberFunction"]

from pathlib import Path
from typing import List, Optional, Union

import speechbrain as sb

from medkit.core.audio import AudioBuffer
import medkit.core.utils
from medkit.audio.transcription.doc_transcriber import TranscriberFunctionDescription


class SBTranscriberFunction:
    """Transcriber function based on a SpeechBrain model.

    To be used within a
    :class:`~medkit.audio.transcription.doc_transcriber.DocTranscriber`
    """

    def __init__(
        self,
        model: Union[str, Path],
        needs_decoder: bool,
        add_trailing_dot: bool = True,
        capitalize: bool = True,
        cache_dir: Optional[Union[str, Path]] = None,
        device: int = -1,
        batch_size: int = 1,
    ):
        """
        Parameters
        ----------
        model:
            Name of the model on the Hugging Face models hub, or local path.
        needs_decoder:
            Whether the model should be used with the speechbrain
            `EncoderDecoderASR` class or the `EncoderASR` class. If unsure,
            check the code snippets on the model card on the hub.
        add_trailing_dot:
            If `True`, a dot will be added at the end of each transcription
            text.
        capitalize:
            If `True`, the first letter of each transcription text will be
            uppercased and the rest lowercased.
        cache_dir:
            Directory where to store the downloaded model. If `None`,
            speechbrain will use "pretrained_models/" and "model_checkpoints/"
            directories in the current working directory.
        device:
            Device to use for pytorch models. Follows the Hugging Face
            convention (`-1` for cpu and device number for gpu, for instance
            `0` for "cuda:0")
        batch_size:
            Number of segments in batches processed by the model.
        """
        if cache_dir is not None:
            cache_dir = Path(cache_dir)

        self.model_name = model
        # Must be stored: the `description` property reports it in its config.
        self.has_decoder = needs_decoder
        self.add_trailing_dot = add_trailing_dot
        self.capitalize = capitalize
        self.cache_dir = cache_dir
        self.device = device
        self.batch_size = batch_size

        # Translate the Hugging Face style device index into a torch device name.
        self._torch_device = "cpu" if self.device < 0 else f"cuda:{self.device}"

        asr_class = (
            sb.pretrained.EncoderDecoderASR
            if needs_decoder
            else sb.pretrained.EncoderASR
        )
        self._asr = asr_class.from_hparams(
            source=model,
            savedir=cache_dir,
            run_opts={"device": self._torch_device},
        )

        # Sample rate that incoming audio buffers are checked against in
        # `transcribe()`.
        self._sample_rate = self._asr.audio_normalizer.sample_rate

    @property
    def description(self) -> TranscriberFunctionDescription:
        """Description holding the class name and the init-time configuration."""
        config = {
            "model_name": self.model_name,
            "has_decoder": self.has_decoder,
            "add_trailing_dot": self.add_trailing_dot,
            "capitalize": self.capitalize,
            "cache_dir": self.cache_dir,
            "device": self.device,
            "batch_size": self.batch_size,
        }
        return TranscriberFunctionDescription(
            name=self.__class__.__name__, config=config
        )

    def transcribe(self, audios: List[AudioBuffer]) -> List[str]:
        """Return one transcription text per audio buffer, in input order.

        Parameters
        ----------
        audios:
            Mono audio buffers whose sample rate equals the one read from the
            model's audio normalizer.

        Returns
        -------
        List[str]
            Transcription texts, processed in batches of `batch_size` buffers.

        Raises
        ------
        ValueError
            If a buffer has a sample rate different from the model's, or has
            more than one channel.
        """
        if not all(a.sample_rate == self._sample_rate for a in audios):
            raise ValueError(
                "SBTranscriberFunction received audio buffers with incompatible sample"
                f" rates (model expected {self._sample_rate} Hz)"
            )
        if not all(a.nb_channels == 1 for a in audios):
            raise ValueError("SBTranscriberFunction only supports mono audio buffers")

        texts = []
        for batched_audios in medkit.core.utils.batch_list(audios, self.batch_size):
            texts.extend(self._transcribe_audios(batched_audios))
        return texts

    def _transcribe_audios(self, audios: List[AudioBuffer]) -> List[str]:
        """Transcribe a single batch of audio buffers and post-process the texts."""
        # group audios in batch of same length with padding
        # (each signal is flattened to 1D before being handed to PaddedBatch)
        padded_batch = sb.dataio.batch.PaddedBatch(
            [{"wav": a.read().reshape((-1,))} for a in audios]
        )
        padded_batch.to(self._torch_device)

        # Second returned value is not needed here (presumably the predicted
        # tokens -- confirm against the speechbrain API).
        texts, _ = self._asr.transcribe_batch(
            padded_batch.wav.data, padded_batch.wav.lengths
        )

        if self.capitalize:
            # str.capitalize() uppercases the first character and lowercases
            # the rest.
            texts = [t.capitalize() for t in texts]
        if self.add_trailing_dot:
            texts = [t + "." for t in texts]
        return texts