Source code for medkit.audio.segmentation.pa_speaker_detector

"""
This module needs extra-dependencies not installed as core dependencies of medkit.
To install them, use `pip install medkit-lib[pa-speaker-detector]`.
"""

__all__ = ["PASpeakerDetector"]

from pathlib import Path
from typing import Iterator, List, Optional, Union

# When pyannote and spacy are both installed, a conflict might occur between the
# ujson library used by pandas (a pyannote dependency) and the ujson library used
# by srsrly (a spacy dependency), especially in docker environments.
# srsly seems to end up using the ujson library from pandas, which is older and does not
# support the escape_forward_slashes parameters, instead of its own.
# The bug seems to only happen when pandas is imported from pyannote, not if
# we import pandas manually first.
# So as a workaround, we always import pandas before importing something from pyannote
import pandas  # noqa: F401
from pyannote.audio import Pipeline
from pyannote.audio.pipelines import SpeakerDiarization
import torch

from medkit.core import Attribute
from medkit.core.audio import SegmentationOperation, Segment, Span


[docs]class PASpeakerDetector(SegmentationOperation): """Speaker diarization operation relying on `pyannote.audio` Each input segment will be split into several sub-segments corresponding to speech turn, and an attribute will be attached to each of these sub-segments indicating the speaker of the turn. `PASpeakerDetector` uses the `SpeakerDiarization` pipeline from `pyannote.audio`, which performs the following steps: - perform multi-speaker VAD with a `PyanNet` segmentation model and extract \ voiced segments ; - compute embeddings for each voiced segment with a \ embeddings model (typically speechbrain ECAPA-TDNN) ; - group voice segments by speakers using a clustering algorithm such as agglomerative clustering, HMM, etc. """ def __init__( self, model: Union[str, Path], output_label: str, min_nb_speakers: Optional[int] = None, max_nb_speakers: Optional[int] = None, min_duration: float = 0.1, device: int = -1, segmentation_batch_size: int = 1, embedding_batch_size: int = 1, hf_auth_token: Optional[str] = None, uid: Optional[str] = None, ): """ Parameters ---------- model: Name (on the HuggingFace models hub) or path of a pretrained pipeline. When a path, should point to the .yaml file containing the pipeline configuration. output_label: Label of generated turn segments. min_nb_speakers: Minimum number of speakers expected to be found. max_nb_speakers: Maximum number of speakers expected to be found. min_duration: Minimum duration of speech segments, in seconds (short segments will be discarded). device: Device to use for pytorch models. Follows the Hugging Face convention (`-1` for cpu and device number for gpu, for instance `0` for "cuda:0"). segmentation_batch_size: Number of input segments in batches processed by segmentation model. embedding_batch_size: Number of pre-segmented audios in batches processed by embedding model. hf_auth_token: HuggingFace Authentication token (to access private models on the hub) uid: Identifier of the detector. """ # Pass all arguments to super (remove self and confidential hf_auth_token) init_args = locals() init_args.pop("self") init_args.pop("hf_auth_token") super().__init__(**init_args) self.output_label = output_label self.min_nb_speakers = min_nb_speakers self.max_nb_speakers = max_nb_speakers self.min_duration = min_duration torch_device = torch.device("cpu" if device < 0 else f"cuda:{device}") self._pipeline = Pipeline.from_pretrained(model, use_auth_token=hf_auth_token) if self._pipeline is None: raise Exception(f"Could not instantiate pretrained pipeline with '{model}'") if not isinstance(self._pipeline, SpeakerDiarization): raise Exception( f"'{model}' does not correspond to a SpeakerDiarization pipeline. Got" f" object of type {type(self._pipeline)}" ) self._pipeline.to(torch_device) self._pipeline.segmentation_batch_size = segmentation_batch_size self._pipeline.embedding_batch_size = embedding_batch_size
[docs] def run(self, segments: List[Segment]) -> List[Segment]: """Return all turn segments detected for all input `segments`. Parameters ---------- segments: Audio segments on which to perform diarization. Returns ------- List[~medkit.core.audio.Segment]: Segments detected as containing speech activity (with speaker attributes) """ return [ turn_seg for seg in segments for turn_seg in self._detect_turns_in_segment(seg) ]
def _detect_turns_in_segment(self, segment: Segment) -> Iterator[Segment]: audio = segment.audio file = { "waveform": torch.from_numpy(audio.read()), "sample_rate": audio.sample_rate, } diarization = self._pipeline.apply( file, min_speakers=self.min_nb_speakers, max_speakers=self.max_nb_speakers, ) for turn, _, speaker in diarization.itertracks(yield_label=True): if turn.duration < self.min_duration: continue # trim original audio to turn start/end points turn_audio = audio.trim_duration(turn.start, turn.end) turn_span = Span( start=segment.span.start + turn.start, end=segment.span.start + turn.end, ) speaker_attr = Attribute(label="speaker", value=speaker) turn_segment = Segment( label=self.output_label, span=turn_span, audio=turn_audio, attrs=[speaker_attr], ) if self._prov_tracer is not None: self._prov_tracer.add_prov(turn_segment, self.description, [segment]) self._prov_tracer.add_prov(speaker_attr, self.description, [segment]) yield turn_segment