Source code for medkit.audio.segmentation.webrtc_voice_detector

"""
This module needs extra dependencies that are not installed as core dependencies of medkit.
To install them, use `pip install medkit-lib[webrtc-voice-detector]`.
"""

# Names exported by a star-import of this module.
__all__ = ["WebRTCVoiceDetector"]

import collections
from typing import Iterator, List, Optional
from typing_extensions import Literal

import numpy as np
import webrtcvad

from medkit.core.audio import SegmentationOperation, Segment, Span


# Sample rates (in Hz) accepted for input audio; segments with any other
# sample rate are rejected with a RuntimeError.
_SUPPORTED_SAMPLE_RATES = {8000, 16000, 32000, 48000}


class WebRTCVoiceDetector(SegmentationOperation):
    """Voice Activity Detection operation relying on the `webrtcvad` package.

    Per-frame VAD results of `webrtcvad` are aggregated with a switch algorithm
    considering the percentage of speech/non-speech frames in a wider sliding
    window.

    Input segments must be mono at 8 kHz, 16 kHz, 32 kHz or 48 kHz.
    """

    def __init__(
        self,
        output_label: str,
        aggressiveness: Literal[0, 1, 2, 3] = 2,
        frame_duration: Literal[10, 20, 30] = 30,
        nb_frames_in_window: int = 10,
        switch_ratio: float = 0.9,
        uid: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        output_label:
            Label of output speech segments.
        aggressiveness:
            Aggressiveness param passed to `webrtcvad` (the higher, the more
            aggressively non-speech is filtered out, i.e. the stricter the
            detector is about what it reports as speech).
        frame_duration:
            Duration in milliseconds of frames passed to `webrtcvad`.
        nb_frames_in_window:
            Number of frames in the sliding window used when aggregating
            per-frame VAD results.
        switch_ratio:
            Ratio (between 0 and 1) of speech/non-speech frames in the window
            that must be exceeded to switch the window speech state when
            aggregating per-frame VAD results.
        uid:
            Identifier of the detector.
        """
        # Pass all arguments to super (remove self).
        # NOTE: must stay the first statement so that locals() only holds
        # the constructor arguments.
        init_args = locals()
        init_args.pop("self")
        super().__init__(**init_args)

        self.output_label = output_label
        self.aggressiveness = aggressiveness
        self.frame_duration = frame_duration
        self.nb_frames_in_window = nb_frames_in_window
        self.switch_ratio = switch_ratio

        self._vad = webrtcvad.Vad(aggressiveness)

    def run(self, segments: List[Segment]) -> List[Segment]:
        """Return all speech segments detected for all input `segments`.

        Parameters
        ----------
        segments:
            Audio segments on which to perform VAD.

        Returns
        -------
        List[~medkit.core.audio.Segment]:
            Segments detected as containing speech activity.
        """
        return [
            voice_seg
            for seg in segments
            for voice_seg in self._detect_activity_in_segment(seg)
        ]

    def _detect_activity_in_segment(self, segment: Segment) -> Iterator[Segment]:
        """Yield one speech segment for each voiced region found in `segment`.

        Raises
        ------
        RuntimeError
            If the audio of `segment` has more than one channel or an
            unsupported sample rate.
        """
        audio = segment.audio
        if audio.nb_channels > 1:
            raise RuntimeError(
                f"Segment with identifier {segment.uid} has multi-channel audio, which"
                " is not supported"
            )
        if audio.sample_rate not in _SUPPORTED_SAMPLE_RATES:
            raise RuntimeError(
                f"Segment with identifier {segment.uid} has non-supported sample rate"
                f" {audio.sample_rate}"
            )

        sample_rate = audio.sample_rate
        nb_samples = audio.nb_samples
        if nb_samples == 0:
            # nothing to analyse: no frame can be built from an empty signal
            return

        signal = audio.read()
        # convert float32 signal to int16 (required by webrtcvad); clip first
        # so that samples slightly outside [-1, 1] do not overflow in the cast
        int_signal = (np.clip(signal, -1.0, 1.0) * 32767).astype(np.int16)

        frame_length = int(self.frame_duration * sample_rate / 1000)
        # zero-pad tail up to the next multiple of frame_length (no padding at
        # all when the signal already is a whole number of frames)
        padding_length = (-nb_samples) % frame_length
        if padding_length:
            padding = np.zeros((1, padding_length), dtype=np.int16)
            int_signal = np.concatenate((int_signal, padding), axis=1)
        # split in frames
        nb_frames = int_signal.shape[1] // frame_length
        frames = np.split(int_signal, nb_frames, axis=1)

        # run vad
        speech_frame_indices = self._get_aggregated_vad(frames, sample_rate)

        # generate segments
        for start_frame_index, end_frame_index in speech_frame_indices:
            # trim original audio (end frame index is exclusive)
            start = start_frame_index * frame_length
            end = end_frame_index * frame_length
            # do not go past the original signal into the zero padding
            end = min(end, nb_samples)
            voiced_audio = audio.trim(start, end)

            start_time = start / sample_rate
            end_time = end / sample_rate
            # build corresponding span, relative to the input segment's span
            voiced_span = Span(
                segment.span.start + start_time, segment.span.start + end_time
            )
            voiced_segment = Segment(
                label=self.output_label, span=voiced_span, audio=voiced_audio
            )

            if self._prov_tracer is not None:
                self._prov_tracer.add_prov(voiced_segment, self.description, [segment])

            yield voiced_segment

    # from https://github.com/wiseman/py-webrtcvad/blob/master/example.py
    def _get_aggregated_vad(self, frames, sample_rate):
        """Return index ranges of voiced frames using webrtcvad.

        Each range is a ``(start, end)`` tuple of frame indices in which
        ``start`` is inclusive and ``end`` is exclusive.
        """
        # deque for our sliding window ring buffer
        window_ring_buffer = collections.deque(maxlen=self.nb_frames_in_window)
        # number of frames in the window that must be exceeded to switch state
        switch_threshold = self.switch_ratio * window_ring_buffer.maxlen

        # we have two states: SPEECH and NONSPEECH (we start in NONSPEECH)
        is_speech = False
        speech_frame_ranges = []
        start_index = None

        for i, frame in enumerate(frames):
            # compute speech state for frame and push it to ring buffer of
            # frames in window
            frame_is_speech = self._vad.is_speech(frame.tobytes(), sample_rate)
            window_ring_buffer.append((i, frame_is_speech))

            if not is_speech:
                nb_speech_frames = sum(flag for _, flag in window_ring_buffer)
                # if we are NONSPEECH and the proportion of speech frames in
                # the ring buffer exceeds switch_ratio, then enter the SPEECH
                # state
                if nb_speech_frames > switch_threshold:
                    is_speech = True
                    # all frames in the ring buffer are retrospectively
                    # considered as SPEECH
                    start_index, _ = window_ring_buffer[0]
                    # all upcoming frames will also be considered SPEECH until
                    # we enter NONSPEECH state
                    window_ring_buffer.clear()
            else:
                nb_non_speech_frames = sum(
                    not flag for _, flag in window_ring_buffer
                )
                # if the proportion of non-speech frames in the ring buffer
                # exceeds switch_ratio, then enter NONSPEECH
                if nb_non_speech_frames > switch_threshold:
                    is_speech = False
                    # push indices of the SPEECH range that just ended; the
                    # frame that triggered the switch is used as exclusive end
                    end_index, _ = window_ring_buffer[-1]
                    speech_frame_ranges.append((start_index, end_index))
                    window_ring_buffer.clear()

        # handle trail: if still in SPEECH state, the range runs up to the
        # very last frame. This must not depend on the ring buffer content,
        # which is empty when the switch to SPEECH happened on the last frame.
        if is_speech:
            speech_frame_ranges.append((start_index, len(frames)))

        return speech_frame_ranges