Source code for medkit.io.rttm
from __future__ import annotations
__all__ = ["RTTMInputConverter", "RTTMOutputConverter"]
import csv
import logging
from pathlib import Path
from typing import Any
from medkit.core import (
Attribute,
InputConverter,
OperationDescription,
OutputConverter,
ProvTracer,
generate_id,
)
from medkit.core.audio import AudioDocument, FileAudioBuffer, Segment, Span
logger = logging.getLogger(__name__)
# Names of the 10 space-separated columns of a .rttm line, in file order.
# Used as csv field names by both the input and the output converter.
# cf https://github.com/nryant/dscore#rttm
_RTTM_FIELDS = [
"type",  # written as "SPEAKER" by the output converter
"file_id",  # identifier of the recording, must be the same on every row of a file
"channel",  # written as "1" by the output converter
"onset",  # start of the turn, parsed as a float (presumably seconds)
"duration",  # length of the turn, parsed as a float (same unit as onset)
"na_1",  # not used by the converters, written as "<NA>"
"na_2",  # not used by the converters, written as "<NA>"
"speaker_name",  # value of the speaker attribute attached to the turn segment
"na_3",  # not used by the converters, written as "<NA>"
"na_4",  # not used by the converters, written as "<NA>"
]
[docs]
class RTTMInputConverter(InputConverter):
"""Convert Rich Transcription Time Marked (.rttm) files containing diarization
information into turn segments.
For each turn in a .rttm file, a
:class:`~medkit.core.audio.annotation.Segment` will be created, with an
associated :class:`~medkit.core.Attribute` holding the name of the turn
speaker as value. The segments can be retrieved directly or as part of an
:class:`~medkit.core.audio.document.AudioDocument` instance.
If a :class:`~medkit.core.ProvTracer` is set, provenance information will be
added for each segment and each attribute (referencing the input converter
as the operation).
"""
def __init__(
self,
turn_label: str = "turn",
speaker_label: str = "speaker",
converter_id: str | None = None,
):
"""Parameters
----------
turn_label : str, default="turn"
Label of segments representing turns in the .rttm file.
speaker_label : str, default="speaker"
Label of speaker attributes to add to each segment.
converter_id : str, optional
Identifier of the converter.
"""
if converter_id is None:
converter_id = generate_id()
self.uid = converter_id
self.turn_label = turn_label
self.speaker_label = speaker_label
self._prov_tracer: ProvTracer | None = None
@property
def description(self) -> OperationDescription:
    """Operation description of this converter.

    The description carries the converter uid and uses the class name both
    as operation name and as class name.
    """
    cls_name = self.__class__.__name__
    return OperationDescription(uid=self.uid, name=cls_name, class_name=cls_name)
[docs]
def set_prov_tracer(self, prov_tracer: ProvTracer):
"""Enable provenance tracing.
Parameters
----------
prov_tracer:
The provenance tracer used to trace the provenance.
"""
self._prov_tracer = prov_tracer
[docs]
def load(
self,
rttm_dir: str | Path,
audio_dir: str | Path | None = None,
audio_ext: str = ".wav",
) -> list[AudioDocument]:
"""Load all .rttm files in a directory into a list of
:class:`~medkit.core.audio.document.AudioDocument` objects.
For each .rttm file, they must be a corresponding audio file with the
same basename, either in the same directory or in an separated audio
directory.
Parameters
----------
rttm_dir : str or Path
Directory containing the .rttm files.
audio_dir : str or Path, optional
Directory containing the audio files corresponding to the .rttm files,
if they are not in `rttm_dir`.
audio_ext : str, default=".wav"
File extension to use for audio files.
Returns
-------
list of AudioDocument
List of generated documents.
"""
rttm_dir = Path(rttm_dir)
if audio_dir is not None:
audio_dir = Path(audio_dir)
docs = []
for rttm_file in sorted(rttm_dir.glob("*.rttm")):
# corresponding audio file must have same base name with audio extension,
# either in the same directory or in audio_dir if provided
audio_file = (
(audio_dir / rttm_file.stem).with_suffix(audio_ext) if audio_dir else rttm_file.with_suffix(audio_ext)
)
doc = self.load_doc(rttm_file, audio_file)
docs.append(doc)
if len(docs) == 0:
logger.warning("No .rttm found in '%s'", rttm_dir)
return docs
[docs]
def load_doc(self, rttm_file: str | Path, audio_file: str | Path) -> AudioDocument:
    """Build an :class:`~medkit.core.audio.document.AudioDocument` from a
    single .rttm file and its audio file.

    Parameters
    ----------
    rttm_file : str or Path
        Path to the .rttm file.
    audio_file : str or Path
        Path to the corresponding audio file.

    Returns
    -------
    AudioDocument
        Document holding the full audio, with one turn segment added to its
        annotations for each row of the .rttm file.
    """
    rows = self._load_rows(Path(rttm_file))
    full_audio = FileAudioBuffer(path=Path(audio_file))

    # all turn segments are built before the document itself is created
    turn_segments = []
    for row in rows:
        turn_segments.append(self._build_turn_segment(row, full_audio))

    doc = AudioDocument(audio=full_audio)
    for segment in turn_segments:
        doc.anns.add(segment)
    return doc
[docs]
def load_turns(self, rttm_file: str | Path, audio_file: str | Path) -> list[Segment]:
    """Read a .rttm file and build one
    :class:`~medkit.core.audio.annotation.Segment` per turn.

    Parameters
    ----------
    rttm_file : str or Path
        Path to the .rttm file.
    audio_file : str or Path
        Path to the corresponding audio file.

    Returns
    -------
    list of Segment
        Turn segments, in the order of the rows of the .rttm file.
    """
    rows = self._load_rows(Path(rttm_file))
    # a single buffer over the whole audio file, shared by every segment built below
    full_audio = FileAudioBuffer(path=Path(audio_file))

    turn_segments = []
    for row in rows:
        turn_segments.append(self._build_turn_segment(row, full_audio))
    return turn_segments
@staticmethod
def _load_rows(rttm_file: Path) -> list[dict[str, Any]]:
    """Read the rows of a .rttm file as dicts keyed by the RTTM column names.

    Parameters
    ----------
    rttm_file : Path
        Path to the .rttm file.

    Returns
    -------
    list of dict
        One dict per row of the file, empty if the file contains no row.

    Raises
    ------
    RuntimeError
        If the rows do not all have the same file id.
    """
    # utf-8 is the encoding used when writing .rttm files, and newline="" is
    # what the csv module expects from the file objects it reads
    with Path(rttm_file).open(encoding="utf-8", newline="") as fp:
        csv_reader = csv.DictReader(fp, fieldnames=_RTTM_FIELDS, delimiter=" ")
        rows = list(csv_reader)

    # an empty file has no turn and no file id to check
    if not rows:
        return rows

    file_id = rows[0]["file_id"]
    if any(r["file_id"] != file_id for r in rows):
        msg = "Multi-file .rttm are not supported, all entries should have same file_id or <NA>"
        raise RuntimeError(msg)
    return rows
def _build_turn_segment(self, row: dict[str, Any], full_audio: FileAudioBuffer) -> Segment:
    """Build the turn segment described by one row of a .rttm file.

    Parameters
    ----------
    row : dict
        Row of the .rttm file, keyed by RTTM column name.
    full_audio : FileAudioBuffer
        Audio buffer from which the audio of the turn is trimmed.

    Returns
    -------
    Segment
        Segment labelled with `turn_label`, carrying a `speaker_label`
        attribute whose value is the speaker name of the row.
    """
    onset = float(row["onset"])
    # the row stores a duration, the span needs an end time
    offset = onset + float(row["duration"])
    turn_audio = full_audio.trim_duration(onset, offset)

    segment = Segment(label=self.turn_label, span=Span(onset, offset), audio=turn_audio)
    speaker_attr = Attribute(label=self.speaker_label, value=row["speaker_name"])
    segment.attrs.add(speaker_attr)

    tracer = self._prov_tracer
    if tracer is None:
        return segment

    # segment and attribute come straight from the file: no source data items
    tracer.add_prov(segment, self.description, source_data_items=[])
    tracer.add_prov(speaker_attr, self.description, source_data_items=[])
    return segment
[docs]
class RTTMOutputConverter(OutputConverter):
"""Build Rich Transcription Time Marked (.rttm) files containing diarization
information from :class:`~medkit.core.audio.annotation.Segment` objects.
There must be a segment for each turn, with an associated
:class:`~medkit.core.Attribute` holding the name of the turn speaker as
value. The segments can be passed directly or as part of
:class:`~medkit.core.audio.document.AudioDocument` instances.
"""
def __init__(self, turn_label: str = "turn", speaker_label: str = "speaker"):
"""Parameters
----------
turn_label : str, default="turn"
Label of segments representing turns in the audio documents.
speaker_label : str, default="speaker"
Label of speaker attributes attached to each turn segment.
"""
super().__init__()
self.turn_label = turn_label
self.speaker_label = speaker_label
[docs]
def save(
self,
docs: list[AudioDocument],
rttm_dir: str | Path,
doc_names: list[str] | None = None,
):
"""Save :class:`~medkit.core.audio.document.AudioDocument` instances as
.rttm files in a directory.
Parameters
----------
docs : list of AudioDocument
List of audio documents to save.
rttm_dir : str or Path
Directory into which the generated .rttm files will be stored.
doc_names : list of str, optional
Optional list of names to use as basenames and file ids for the
generated .rttm files (2d column). If none provided, the document
ids will be used.
"""
rttm_dir = Path(rttm_dir)
if doc_names is not None:
if len(doc_names) != len(docs):
msg = "doc_names must have the same length as docs when provided"
raise ValueError(msg)
else:
doc_names = [doc.uid for doc in docs]
rttm_dir.mkdir(parents=True, exist_ok=True)
for doc_name, doc in zip(doc_names, docs):
rttm_file = rttm_dir / f"{doc_name}.rttm"
self.save_doc(doc, rttm_file=rttm_file, rttm_doc_id=doc_name)
[docs]
def save_doc(
self,
doc: AudioDocument,
rttm_file: str | Path,
rttm_doc_id: str | None = None,
):
"""Save a single :class:`~medkit.core.audio.document.AudioDocument` as a
.rttm file.
Parameters
----------
doc : AudioDocument
Audio document to save.
rttm_file : str or Path
Path of the generated .rttm file.
rttm_doc_id : str, optional
File uid to use for the generated .rttm file (2d column). If none
provided, the document uid will be used.
"""
rttm_file = Path(rttm_file)
if rttm_doc_id is None:
rttm_doc_id = doc.uid
turns = doc.anns.get(label=self.turn_label)
self.save_turn_segments(turns, rttm_file, rttm_doc_id)
[docs]
def save_turn_segments(
    self,
    turn_segments: list[Segment],
    rttm_file: str | Path,
    rttm_doc_id: str | None,
):
    """Save :class:`~medkit.core.audio.annotation.Segment` objects into a .rttm file.

    The rows are written in increasing order of turn onset.

    Parameters
    ----------
    turn_segments : list of Segment
        Turn segments to save.
    rttm_file : str or Path
        Path of the generated .rttm file.
    rttm_doc_id : str or None
        File id to use in the generated .rttm file (2d column). "<NA>" is
        written when None.

    Raises
    ------
    RuntimeError
        If a turn segment has no speaker attribute.
    """
    rows = [self._build_rttm_row(s, rttm_doc_id) for s in turn_segments]
    # Onsets are stored as formatted strings: they must be compared as numbers,
    # a plain string sort would put "10.000" before "9.000".
    rows.sort(key=lambda r: float(r["onset"]))

    # newline="" leaves line endings to the csv writer, which avoids doubled
    # carriage returns on platforms that translate "\n" when writing text
    with Path(rttm_file).open(mode="w", encoding="utf-8", newline="") as fp:
        csv_writer = csv.DictWriter(fp, fieldnames=_RTTM_FIELDS, delimiter=" ")
        csv_writer.writerows(rows)
def _build_rttm_row(self, turn_segment: Segment, rttm_doc_id: str | None) -> dict[str, Any]:
speaker_attrs = turn_segment.attrs.get(label=self.speaker_label)
if len(speaker_attrs) == 0:
msg = f"Found no attribute with label '{self.speaker_label}' on turn segment"
raise RuntimeError(msg)
speaker_attr = speaker_attrs[0]
span = turn_segment.span
return {
"type": "SPEAKER",
"file_id": rttm_doc_id if rttm_doc_id is not None else "<NA>",
"channel": "1",
"onset": f"{span.start:.3f}",
"duration": f"{span.length:.3f}",
"na_1": "<NA>",
"na_2": "<NA>",
"speaker_name": speaker_attr.value,
"na_3": "<NA>",
"na_4": "<NA>",
}