Source code for medkit.audio.metrics.transcription

"""
This module needs extra-dependencies not installed as core dependencies of medkit.
To install them, use `pip install medkit-lib[metrics-transcription]`.
"""

__all__ = ["TranscriptionEvaluator", "TranscriptionEvaluatorResult"]

import dataclasses
import functools
import logging
import string
from typing import List, Sequence

from speechbrain.utils.metric_stats import ErrorRateStats

from medkit.core.audio import AudioDocument, Segment
from medkit.text.utils.decoding import get_ascii_from_unicode

logger = logging.getLogger(__name__)


[docs]@dataclasses.dataclass(frozen=True) class TranscriptionEvaluatorResult: """ Results returned by :class:`~.TranscriptionEvaluator` Attributes ---------- wer: Word Error Rate, combination of word insertions, deletions and substitutions word_insertions: Ratio of extra words in prediction (over `word_support`) word_deletions: Ratio of missing words in prediction (over `word_support`) word_substitutions Ratio of replaced words in prediction (over `word_support`) word_support: Total number of words cer: Character Error Rate, same as `wer` but at character level char_insertions: Identical to `word_insertions` but at character level char_deletions: Identical to `word_deletions` but at character level char_substitutions: Identical to `word_substitutions` but at character level char_support: Total number of characters (not including whitespaces, post punctuation removal and unicode replacement) """ wer: float word_insertions: float word_deletions: float word_substitutions: float word_support: int cer: float char_insertions: float char_deletions: float char_substitutions: float char_support: int
[docs]class TranscriptionEvaluator: """ Word Error Rate (WER) and Character Error Rate (CER) computation based on `speechbrain`. The WER is the ratio of predictions errors at the word level, taking into accounts: - words present in the reference transcription but missing from the prediction; - extra predicted words not present in the reference; - reference words mistakenly replaced by other words in the prediction. The CER is identical to the WER but computed at the character level rather than at the word level. This component expects as input reference documents containing speech segments with reference transcription attributes, as well as corresponding speech segments with predicted transcription attributes. """ def __init__( self, speech_label: str = "speech", transcription_label: str = "transcription", case_sensitive: bool = False, remove_punctuation: bool = True, replace_unicode: bool = False, ): """ Parameters ---------- speech_label: Label of the speech segments on the reference documents transcription_label: Label of the transcription attributes on the reference and predicted speech segments case_sensitive: Whether to take case into consideration when comparing reference and prediction remove_punctuation: If True, punctuation in reference and predictions is removed before comparing (based on `string.punctuation`) replace_unicode: If True, special unicode characters in reference and predictions are replaced by their closest ASCII characters (when possible) before comparing """ self.speech_label = speech_label self.transcription_label = transcription_label self.case_sensitive = case_sensitive self.remove_punctuation = remove_punctuation self.replace_unicode = replace_unicode
[docs] def compute( self, reference: Sequence[AudioDocument], predicted: Sequence[Sequence[Segment]], ) -> TranscriptionEvaluatorResult: """ Compute and return the WER and CER for predicted transcription attributes, against reference annotated documents. Parameters ---------- reference: Reference documents containing speech segments with `speech_label` as label, each of them containing a transcription attribute with `transcription_label` as label. predicted: Predicted segments containing each a transcription attribute with `transcription_label` as label. This is a list of list that must be of the same length and ordering as `reference`. Returns ------- TranscriptionEvaluatorResult Computed metrics """ assert len(reference) == len( predicted ), "reference and predicted must have the same length" sb_wer_metric = ErrorRateStats() sb_cer_metric = ErrorRateStats(split_tokens=True) for i, (ref_doc, pred_segs) in enumerate(zip(reference, predicted)): ref_segs = ref_doc.anns.get(label=self.speech_label) ref_words = self._convert_speech_segs_to_words(ref_segs) pred_words = self._convert_speech_segs_to_words(pred_segs) sb_wer_metric.append(ids=[i], predict=[pred_words], target=[ref_words]) sb_cer_metric.append(ids=[i], predict=[pred_words], target=[ref_words]) wer_results = sb_wer_metric.summarize() nb_words = wer_results["num_scored_tokens"] cer_results = sb_cer_metric.summarize() nb_chars = cer_results["num_scored_tokens"] return TranscriptionEvaluatorResult( wer=wer_results["num_edits"] / nb_words, word_insertions=wer_results["insertions"] / nb_words, word_deletions=wer_results["deletions"] / nb_words, word_substitutions=wer_results["substitutions"] / nb_words, word_support=nb_words, cer=cer_results["num_edits"] / nb_chars, char_insertions=cer_results["insertions"] / nb_chars, char_deletions=cer_results["deletions"] / nb_chars, char_substitutions=cer_results["substitutions"] / nb_chars, char_support=nb_chars, )
def _convert_speech_segs_to_words(self, segments: Sequence[Segment]) -> List[str]: """ Convert list of speech segments with transcription attribute to list of words that can be passed to speechbrain metrics objects """ # get values of all transcription attributes and concatenate them into # one big string representing the transcription of the whole document # sort segments by time to concatenate in correct order segments = sorted(segments, key=lambda s: s.span) texts = [] for seg in segments: # retrieve transcription transcription_attrs = seg.attrs.get(label=self.transcription_label) if not transcription_attrs: raise ValueError( f"Attribute with label '{self.transcription_label}' not found on" " speech segment" ) if len(transcription_attrs) > 1: logger.warning( f"Found several attributes with label '{self.transcription_label}'," " ignoring all but first" ) transcription = transcription_attrs[0].value texts.append(transcription) text = " ".join(texts) # apply pre-WER transforms if not self.case_sensitive: text = text.lower() if self.remove_punctuation: punct_trans_table = _get_punctation_translation_table() text = text.translate(punct_trans_table) if self.replace_unicode: text = get_ascii_from_unicode(text, logger=logger) # split into words words = [w for w in text.split(" ") if w] return words
@functools.lru_cache def _get_punctation_translation_table(): """ Return a translation table mapping all punctuations chars to a single space, that can be used with `str.translate()` """ return str.maketrans(dict.fromkeys(string.punctuation, " "))