"""
Speaker diarization operation based on `pyannote.audio`.

This module needs extra-dependencies not installed as core dependencies of medkit.
To install them, use `pip install medkit-lib[pa-speaker-detector]`.
"""
__all__ = ["PASpeakerDetector"]
from pathlib import Path
from typing import Iterator, List, Optional, Union
# When pyannote and spacy are both installed, a conflict might occur between the
# ujson library used by pandas (a pyannote dependency) and the ujson library used
# by srsly (a spacy dependency), especially in docker environments.
# srsly seems to end up using the ujson library from pandas instead of its own;
# that library is older and does not support the escape_forward_slashes parameter.
# The bug seems to only happen when pandas is imported from pyannote, not if
# we import pandas manually first.
# So as a workaround, we always import pandas before importing anything from pyannote.
import pandas # noqa: F401
from pyannote.audio import Pipeline
from pyannote.audio.pipelines import SpeakerDiarization
import torch
from medkit.core import Attribute
from medkit.core.audio import SegmentationOperation, Segment, Span
class PASpeakerDetector(SegmentationOperation):
    """Speaker diarization operation relying on `pyannote.audio`.

    Each input segment will be split into several sub-segments corresponding
    to speech turns, and an attribute will be attached to each of these
    sub-segments indicating the speaker of the turn.

    `PASpeakerDetector` uses the `SpeakerDiarization` pipeline from
    `pyannote.audio`, which performs the following steps:

    - perform multi-speaker VAD with a `PyanNet` segmentation model and extract
      voiced segments;
    - compute embeddings for each voiced segment with an embeddings model
      (typically speechbrain ECAPA-TDNN);
    - group voiced segments by speaker using a clustering algorithm such as
      agglomerative clustering, HMM, etc.
    """

    def __init__(
        self,
        model: Union[str, Path],
        output_label: str,
        min_nb_speakers: Optional[int] = None,
        max_nb_speakers: Optional[int] = None,
        min_duration: float = 0.1,
        device: int = -1,
        segmentation_batch_size: int = 1,
        embedding_batch_size: int = 1,
        hf_auth_token: Optional[str] = None,
        uid: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        model:
            Name (on the HuggingFace models hub) or path of a pretrained
            pipeline. When a path, should point to the .yaml file containing the
            pipeline configuration.
        output_label:
            Label of generated turn segments.
        min_nb_speakers:
            Minimum number of speakers expected to be found.
        max_nb_speakers:
            Maximum number of speakers expected to be found.
        min_duration:
            Minimum duration of speech segments, in seconds (short segments will
            be discarded).
        device:
            Device to use for pytorch models. Follows the Hugging Face
            convention (`-1` for cpu and device number for gpu, for instance `0`
            for "cuda:0").
        segmentation_batch_size:
            Number of input segments in batches processed by segmentation model.
        embedding_batch_size:
            Number of pre-segmented audios in batches processed by embedding model.
        hf_auth_token:
            HuggingFace Authentication token (to access private models on the
            hub)
        uid:
            Identifier of the detector.

        Raises
        ------
        ValueError
            If no pipeline could be instantiated from `model`.
        TypeError
            If `model` does not correspond to a `SpeakerDiarization` pipeline.
        """
        # Pass all arguments to super (remove self and confidential hf_auth_token)
        init_args = locals()
        init_args.pop("self")
        init_args.pop("hf_auth_token")
        # The zero-argument `super()` call below makes the compiler create an
        # implicit `__class__` closure cell, and `locals()` returns free
        # variables in function blocks. Drop it so the class object is not
        # forwarded to the parent constructor as a configuration value.
        init_args.pop("__class__", None)
        super().__init__(**init_args)

        self.output_label = output_label
        self.min_nb_speakers = min_nb_speakers
        self.max_nb_speakers = max_nb_speakers
        self.min_duration = min_duration

        # Hugging Face-style device index: negative -> CPU, otherwise CUDA device N
        torch_device = torch.device("cpu" if device < 0 else f"cuda:{device}")

        self._pipeline = Pipeline.from_pretrained(model, use_auth_token=hf_auth_token)
        if self._pipeline is None:
            raise ValueError(f"Could not instantiate pretrained pipeline with '{model}'")
        if not isinstance(self._pipeline, SpeakerDiarization):
            raise TypeError(
                f"'{model}' does not correspond to a SpeakerDiarization pipeline. Got"
                f" object of type {type(self._pipeline)}"
            )

        self._pipeline.to(torch_device)
        # Batch sizes are applied to the already-built pipeline
        self._pipeline.segmentation_batch_size = segmentation_batch_size
        self._pipeline.embedding_batch_size = embedding_batch_size

    def run(self, segments: List[Segment]) -> List[Segment]:
        """Return all turn segments detected for all input `segments`.

        Parameters
        ----------
        segments:
            Audio segments on which to perform diarization.

        Returns
        -------
        List[~medkit.core.audio.Segment]:
            Segments detected as containing speech activity (with speaker
            attributes)
        """
        return [
            turn_seg
            for seg in segments
            for turn_seg in self._detect_turns_in_segment(seg)
        ]

    def _detect_turns_in_segment(self, segment: Segment) -> Iterator[Segment]:
        """Diarize a single segment and yield one new segment per speech turn.

        Turn boundaries produced by the pipeline are relative to the start of
        `segment`'s audio: they are used as-is to trim that audio, and shifted
        by `segment.span.start` to build the span of each turn segment. Turns
        shorter than `self.min_duration` are skipped. Each yielded segment
        carries a "speaker" attribute holding the label assigned by the
        pipeline. When a provenance tracer is set, provenance is recorded for
        both the turn segment and its attribute.
        """
        audio = segment.audio
        # In-memory input for the pyannote pipeline.
        # NOTE(review): assumes `audio.read()` returns a (channel, time) array
        # with a dtype accepted by the pyannote models — TODO confirm.
        file = {
            "waveform": torch.from_numpy(audio.read()),
            "sample_rate": audio.sample_rate,
        }
        diarization = self._pipeline.apply(
            file,
            min_speakers=self.min_nb_speakers,
            max_speakers=self.max_nb_speakers,
        )

        offset = segment.span.start
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            if turn.duration < self.min_duration:
                continue

            # trim original audio to turn start/end points
            turn_audio = audio.trim_duration(turn.start, turn.end)
            turn_span = Span(
                start=offset + turn.start,
                end=offset + turn.end,
            )
            speaker_attr = Attribute(label="speaker", value=speaker)
            turn_segment = Segment(
                label=self.output_label,
                span=turn_span,
                audio=turn_audio,
                attrs=[speaker_attr],
            )

            if self._prov_tracer is not None:
                self._prov_tracer.add_prov(turn_segment, self.description, [segment])
                self._prov_tracer.add_prov(speaker_attr, self.description, [segment])

            yield turn_segment