Source code for medkit.audio.transcription.hf_transcriber

"""This module needs extra-dependencies not installed as core dependencies of medkit.
To install them, use `pip install medkit-lib[hf-transcriber]`.
"""
from __future__ import annotations

__all__ = ["HFTranscriber"]


from typing import TYPE_CHECKING

import transformers
from transformers import AutomaticSpeechRecognitionPipeline

from medkit.core import Attribute, Operation

if TYPE_CHECKING:
    from pathlib import Path

    from medkit.core.audio import AudioBuffer, Segment


class HFTranscriber(Operation):
    """Transcriber operation based on a Hugging Face transformers model.

    For each segment given as input, a transcription attribute will be created
    with the transcribed text as value. If needed, a text document can later be
    created from all the transcriptions of an audio document using
    :func:`~medkit.audio.transcription.TranscribedTextDocument.from_audio_doc <TranscribedTextDocument.from_audio_doc>`
    """

    def __init__(
        self,
        model: str = "facebook/s2t-large-librispeech-asr",
        output_label: str = "transcribed_text",
        language: str | None = None,
        add_trailing_dot: bool = True,
        capitalize: bool = True,
        device: int = -1,
        batch_size: int = 1,
        hf_auth_token: str | None = None,
        cache_dir: str | Path | None = None,
        uid: str | None = None,
    ):
        """Parameters
        ----------
        model : str, default="facebook/s2t-large-librispeech-asr"
            Name of the ASR model on the Hugging Face models hub. Must be a
            model compatible with the `AutomaticSpeechRecognitionPipeline`
            transformers class.
        output_label : str, default="transcribed_text"
            Label of the attribute containing the transcribed text that will be
            attached to the input segments
        language : str, optional
            Optional output language to be forced on the model (useful for some
            multilingual models such as Whisper)
        add_trailing_dot : bool, default=True
            If `True`, a dot will be added at the end of each transcription
            text.
        capitalize : bool, default=True
            If `True`, the first letter of each transcription text will be
            uppercased and the rest lowercased.
        device : int, default=-1
            Device to use for pytorch models. Follows the Hugging Face
            convention (`-1` for cpu and device number for gpu, for instance
            `0` for "cuda:0")
        batch_size : int, default=1
            Size of batches processed by ASR pipeline.
        hf_auth_token : str, optional
            HuggingFace Authentication token (to access private models on the
            hub)
        cache_dir : str or Path, optional
            Directory where to store downloaded models. If not set, the default
            HuggingFace cache dir is used.
        uid : str, optional
            Identifier of the transcriber.

        Raises
        ------
        ValueError
            If the task returned by `transformers.pipelines.get_task` for
            `model` is not "automatic-speech-recognition".
        """
        # Forward the init arguments to the base Operation. `language` is
        # included because it changes the transcribed text, so it must be
        # forwarded like the other output-affecting arguments.
        # NOTE(review): `hf_auth_token` is intentionally not forwarded,
        # presumably so the secret is not stored with the operation's init
        # arguments -- confirm against Operation.
        super().__init__(
            model=model,
            output_label=output_label,
            language=language,
            add_trailing_dot=add_trailing_dot,
            capitalize=capitalize,
            device=device,
            batch_size=batch_size,
            cache_dir=cache_dir,
            uid=uid,
        )

        self.model_name = model
        self.output_label = output_label
        self.add_trailing_dot = add_trailing_dot
        self.capitalize = capitalize
        self.device = device

        # fail early if the model on the hub is not a speech recognition model
        task = transformers.pipelines.get_task(self.model_name, token=hf_auth_token)
        if task != "automatic-speech-recognition":
            msg = (
                f"Model {self.model_name} is not associated to a speech"
                " recognition task and cannot be use with HFTranscriber"
            )
            raise ValueError(msg)

        self._pipeline = transformers.pipeline(
            task=task,
            model=self.model_name,
            feature_extractor=self.model_name,
            pipeline_class=AutomaticSpeechRecognitionPipeline,
            device=self.device,
            batch_size=batch_size,
            token=hf_auth_token,
            model_kwargs={"cache_dir": cache_dir},
        )

        if language is not None:
            # force the decoder prompt so that the model transcribes in the
            # requested language (relies on the tokenizer providing
            # `get_decoder_prompt_ids`)
            self._pipeline.model.config.forced_decoder_ids = self._pipeline.tokenizer.get_decoder_prompt_ids(
                language=language, task="transcribe"
            )

    def run(self, segments: list[Segment]) -> None:
        """Add a transcription attribute to each segment with a text value
        containing the transcribed text.

        Parameters
        ----------
        segments : list of Segment
            List of segments to transcribe
        """
        audios = [s.audio for s in segments]
        texts = self._transcribe_audios(audios)

        for segment, text in zip(segments, texts):
            attr = Attribute(label=self.output_label, value=text)
            segment.attrs.add(attr)

            # record that the attribute was derived from the segment by this
            # operation, when provenance tracing is enabled
            if self._prov_tracer is not None:
                self._prov_tracer.add_prov(attr, self.description, [segment])

    def _transcribe_audios(self, audios: list[AudioBuffer]) -> list[str]:
        """Return the post-processed transcription of each audio buffer, in
        the same order as `audios`.
        """
        # generate iterator of all audio dicts to pass to the transformers
        # pipeline (which will handle the batching)
        audio_dicts_gen = (
            {
                "raw": audio.read().reshape((-1,)),
                "sampling_rate": audio.sample_rate,
            }
            for audio in audios
        )
        text_dicts = self._pipeline(audio_dicts_gen)

        # post-process transcribed texts
        return [self._post_process(text_dict["text"]) for text_dict in text_dicts]

    def _post_process(self, text: str) -> str:
        """Apply the optional capitalization and trailing dot to `text`."""
        if self.capitalize:
            text = text.capitalize()
        if self.add_trailing_dot:
            text += "."
        return text