Source code for medkit.audio.metrics.diarization

"""
This module needs extra-dependencies not installed as core dependencies of medkit.
To install them, use `pip install medkit-lib[metrics-diarization]`.
"""

__all__ = ["DiarizationEvaluator", "DiarizationEvaluatorResult"]

import dataclasses
import logging
from typing import Sequence

# When pyannote and spacy are both installed, a conflict might occur between the
# ujson library used by pandas (a pyannote dependency) and the ujson library used
# by srsrly (a spacy dependency), especially in docker environments.
# srsly seems to end up using the ujson library from pandas, which is older and does not
# support the escape_forward_slashes parameters, instead of its own.
# The bug seems to only happen when pandas is imported from pyannote, not if
# we import pandas manually first.
# So as a workaround, we always import pandas before importing something from pyannote
import pandas  # noqa: F401
from pyannote.core.annotation import (
    Annotation as PAAnnotation,
    Segment as PASegment,
    Timeline as PATimeline,
)
from pyannote.metrics.diarization import GreedyDiarizationErrorRate

from medkit.core.audio import AudioDocument, Segment

logger = logging.getLogger(__name__)


[docs]@dataclasses.dataclass(frozen=True) class DiarizationEvaluatorResult: """ Results returned by :class:`~.DiarizationEvaluator` Attributes ---------- der: Diarization Error Rate, combination of confusion, false alarm and missed detection confusion: Ratio of time detected as speech but attributed to a wrong speaker (over `total_speech`) false_alarm: Ratio of time corresponding to non-speech mistakenly detected as speech (over `total_speech`) missed_detection: Ratio of time corresponding to undetected speech (over `total_speech`) total_speech: Total duration of speech in the reference support: Total duration of audio """ der: float confusion: float false_alarm: float missed_detection: float total_speech: float support: float
[docs]class DiarizationEvaluator: """ Diarization Error Rate (DER) computation based on `pyannote`. The DER is the ratio of time that is not attributed correctly because of one of the following errors: - detected as non-speech when there was speech (missed detection); - detected as speech where there was none (false alarm); - attributed to the wrong speaker (confusion). This component expects as input reference documents containing the reference speech turn segments as well as corresponding predicted speech turn segments. The speech turn segments must each have a speaker attribute. Note that values of the reference and predicted speaker attributes (ie speaker labels) don't have to be the same, since they will be optimally remapped using the Hungarian algorithm. """ def __init__( self, turn_label: str = "turn", speaker_label: str = "speaker", collar: float = 0.0, ): """ Parameters ---------- turn_label: Label of the turn segments on the reference documents speaker_label: Label of the speaker attributes on the reference and predicted turn segments collar: Margin of error (in seconds) around start and end of reference segments """ self.turn_label = turn_label self.speaker_label = speaker_label self.collar = collar
[docs] def compute( self, reference: Sequence[AudioDocument], predicted: Sequence[Sequence[Segment]], ) -> DiarizationEvaluatorResult: """ Compute and return the DER for predicted speech turn segments, against reference annotated documents. Parameters ---------- reference: Reference documents containing speech turns segments with `turn_label` as label, each of them containing a speaker attribute with `speaker_label` as label. predicted: Predicted segments containing each a speaker attribute with `speaker_label` as label. This is a list of list that must be of the same length and ordering as `reference`. Returns ------- DiarizationEvaluatorResult Computed metrics """ assert len(reference) == len( predicted ), "reference and predicted must have the same length" # init pyannote metrics object into which results are accumulated pa_metric = GreedyDiarizationErrorRate(collar=self.collar) support = 0.0 for ref_doc, pred_segs in zip(reference, predicted): support += ref_doc.audio.duration ref_segs = ref_doc.anns.get(label=self.turn_label) # UEM timeline representing annotated timeline # (needed to get rid of pyannote warning) uem = PATimeline( segments=[PASegment(start=0.0, end=ref_doc.audio.duration)] ) # convert reference and predicted segment to pyannote annotation objects ref_pa_ann = self._get_pa_annotation(ref_segs) pred_pa_ann = self._get_pa_annotation(pred_segs) # pass them to pyannote metrics object pa_metric( reference=ref_pa_ann, hypothesis=pred_pa_ann, uem=uem, ) # retrieve accumulated results from pyannote metrics object, # in fractional form return DiarizationEvaluatorResult( der=abs(pa_metric), confusion=pa_metric["confusion"] / pa_metric["total"], false_alarm=pa_metric["false alarm"] / pa_metric["total"], missed_detection=pa_metric["missed detection"] / pa_metric["total"], total_speech=pa_metric["total"], support=support, )
def _get_pa_annotation(self, segments: Sequence[Segment]) -> PAAnnotation: """ Convert list of medkit speech turn segments with speaker attribute to pyannote annotation object """ pa_ann = PAAnnotation() for i, seg in enumerate(segments): # retrieve speaker speaker_attrs = seg.attrs.get(label=self.speaker_label) if not speaker_attrs: raise ValueError( f"Attribute with label '{self.speaker_label}' not found on" " turn segment" ) if len(speaker_attrs) > 1: logger.warning( f"Found several attributes with label '{self.speaker_label}'," " ignoring all but first" ) speaker = speaker_attrs[0].value # create pyannote segment object to hold boundaries # and add it to pyannote annotation object pa_seg = PASegment(seg.span.start, seg.span.end) pa_ann[pa_seg, i] = speaker return pa_ann