Source code for medkit.audio.transcription.hf_transcriber_function

"""
This module needs extra-dependencies not installed as core dependencies of medkit.
To install them, use `pip install medkit-lib[hf-transcriber_function]`.
"""

__all__ = ["HFTranscriberFunction"]

from pathlib import Path
from typing import List, Optional, Union

import transformers
from transformers import AutomaticSpeechRecognitionPipeline

from medkit.core.audio import AudioBuffer
from medkit.audio.transcription.doc_transcriber import TranscriberFunctionDescription


[docs]class HFTranscriberFunction: """Transcriber function based on a Hugging Face transformers model. To be used within a :class:`~medkit.audio.transcription.doc_transcriber.DocTranscriber` """ def __init__( self, model: str = "facebook/s2t-large-librispeech-asr", add_trailing_dot: bool = True, capitalize: bool = True, device: int = -1, batch_size: int = 1, cache_dir: Optional[Union[str, Path]] = None, ): """ Parameters ---------- model: Name of the ASR model on the Hugging Face models hub. Must be a model compatible with the `AutomaticSpeechRecognitionPipeline` transformers class. add_trailing_dot: If `True`, a dot will be added at the end of each transcription text. capitalize: It `True`, the first letter of each transcription text will be uppercased and the rest lowercased. device: Device to use for pytorch models. Follows the Hugging Face convention (`-1` for cpu and device number for gpu, for instance `0` for "cuda:0") batch_size: Size of batches processed by ASR pipeline. cache_dir: Directory where to store downloaded models. If not set, the default HuggingFace cache dir is used. """ self.model_name = model self.add_trailing_dot = add_trailing_dot self.capitalize = capitalize self.device = device task = transformers.pipelines.get_task(self.model_name) if not task == "automatic-speech-recognition": raise ValueError( f"Model {self.model_name} is not associated to a speech" " recognition task and cannot be use with HFTranscriberFunction" ) self._pipeline = transformers.pipeline( task=task, model=self.model_name, feature_extractor=self.model_name, pipeline_class=AutomaticSpeechRecognitionPipeline, device=self.device, batch_size=batch_size, model_kwargs={"cache_dir": cache_dir}, ) @property def description(self) -> TranscriberFunctionDescription: config = dict( model_name=self.model_name, add_trailing_dot=self.add_trailing_dot, capitalize=self.capitalize, device=self.device, ) return TranscriberFunctionDescription( name=self.__class__.__name__, config=config ) def transcribe(self, audios: List[AudioBuffer]) -> List[str]: audio_dicts_gen = ( { "raw": audio.read().reshape((-1,)), "sampling_rate": audio.sample_rate, } for audio in audios ) text_dicts = self._pipeline(audio_dicts_gen) texts_gen = (text_dict["text"] for text_dict in text_dicts) if self.capitalize and self.add_trailing_dot: texts = [t.capitalize() + "." for t in texts_gen] elif self.capitalize: texts = [t.capitalize() for t in texts_gen] elif self.add_trailing_dot: texts = [t + "." for t in texts_gen] else: texts = list(texts_gen) return texts