Source code for medkit.io.srt
"""This module needs extra dependencies not installed as core dependencies of medkit.
To install them, use `pip install medkit-lib[srt-io-convert]`.
"""

from __future__ import annotations

__all__ = ["SRTInputConverter", "SRTOutputConverter"]

import logging
from pathlib import Path

import pysrt

from medkit.core import (
    Attribute,
    InputConverter,
    OperationDescription,
    OutputConverter,
    ProvTracer,
    generate_id,
)
from medkit.core.audio import AudioDocument, FileAudioBuffer, Segment, Span

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
[docs]
class SRTInputConverter(InputConverter):
    """Convert .srt files containing transcription information into turn segments
    with transcription attributes.

    For each turn in a .srt file, a
    :class:`~medkit.core.audio.annotation.Segment` will be created, with an
    associated :class:`~medkit.core.Attribute` holding the transcribed text as
    value. The segments can be retrieved directly or as part of an
    :class:`~medkit.core.audio.document.AudioDocument` instance.

    If a :class:`~medkit.core.ProvTracer` is set, provenance information will be
    added for each segment and each attribute (referencing the input converter
    as the operation).
    """

    def __init__(
        self,
        turn_segment_label: str = "turn",
        transcription_attr_label: str = "transcribed_text",
        converter_id: str | None = None,
    ):
        """Parameters
        ----------
        turn_segment_label : str, default="turn"
            Label to use for segments representing turns in the .srt file.
        transcription_attr_label : str, default="transcribed_text"
            Label to use for segments attributes containing the transcribed text.
        converter_id : str, optional
            Identifier of the converter. A new identifier is generated if not
            provided.
        """
        if converter_id is None:
            converter_id = generate_id()

        self.uid = converter_id
        self.turn_segment_label = turn_segment_label
        self.transcription_attr_label = transcription_attr_label
        # Provenance tracing is disabled until set_prov_tracer() is called.
        self._prov_tracer: ProvTracer | None = None

    @property
    def description(self) -> OperationDescription:
        """Contains all the input converter init parameters."""
        return OperationDescription(
            uid=self.uid,
            name=self.__class__.__name__,
            class_name=self.__class__.__name__,
            config={
                "turn_segment_label": self.turn_segment_label,
                "transcription_attr_label": self.transcription_attr_label,
            },
        )

    def set_prov_tracer(self, prov_tracer: ProvTracer):
        """Enable provenance tracing.

        Parameters
        ----------
        prov_tracer : ProvTracer
            The provenance tracer used to trace the provenance.
        """
        self._prov_tracer = prov_tracer

    def load(
        self,
        srt_dir: str | Path,
        audio_dir: str | Path | None = None,
        audio_ext: str = ".wav",
    ) -> list[AudioDocument]:
        """Load all .srt files in a directory into a list of
        :class:`~medkit.core.audio.document.AudioDocument` objects.

        For each .srt file, there must be a corresponding audio file with the
        same basename, either in the same directory or in a separate audio
        directory.

        Parameters
        ----------
        srt_dir : str or Path
            Directory containing the .srt files.
        audio_dir : str or Path, optional
            Directory containing the audio files corresponding to the .srt files,
            if they are not in `srt_dir`.
        audio_ext : str, default=".wav"
            File extension to use for audio files.

        Returns
        -------
        list of AudioDocument
            List of generated documents, in sorted .srt file name order.
        """
        srt_dir = Path(srt_dir)
        audio_dir = Path(audio_dir) if audio_dir else None

        docs = []
        for srt_file in sorted(srt_dir.glob("*.srt")):
            # The audio file has the same base name as the .srt file with only
            # the final ".srt" suffix swapped for audio_ext. Deriving the name
            # from the full .srt path (not from its stem) keeps any dots that
            # belong to the base name: "rec.01.srt" -> "rec.01.wav".
            same_dir_audio_file = srt_file.with_suffix(audio_ext)
            if audio_dir is not None:
                audio_file = audio_dir / same_dir_audio_file.name
            else:
                audio_file = same_dir_audio_file
            docs.append(self.load_doc(srt_file, audio_file))

        if not docs:
            logger.warning("No .srt found in '%s'", srt_dir)
        return docs

    def load_doc(self, srt_file: str | Path, audio_file: str | Path) -> AudioDocument:
        """Load a single .srt file into an
        :class:`~medkit.core.audio.document.AudioDocument` containing
        turn segments with transcription attributes.

        Parameters
        ----------
        srt_file : str or Path
            Path to the .srt file.
        audio_file : str or Path
            Path to the corresponding audio file.

        Returns
        -------
        AudioDocument
            Generated document.
        """
        full_audio, segments = self._load_audio_and_segments(srt_file, audio_file)
        doc = AudioDocument(audio=full_audio)
        for segment in segments:
            doc.anns.add(segment)
        return doc

    def load_segments(self, srt_file: str | Path, audio_file: str | Path) -> list[Segment]:
        """Load a .srt file and return a list of
        :class:`~medkit.core.audio.annotation.Segment` objects corresponding to
        turns, with transcription attributes.

        Parameters
        ----------
        srt_file : str or Path
            Path to the .srt file.
        audio_file : str or Path
            Path to the corresponding audio file.

        Returns
        -------
        list of Segment
            Turn segments as found in the .srt file, with transcription
            attributes attached.
        """
        _, segments = self._load_audio_and_segments(srt_file, audio_file)
        return segments

    def _load_audio_and_segments(
        self, srt_file: str | Path, audio_file: str | Path
    ) -> tuple[FileAudioBuffer, list[Segment]]:
        """Parse `srt_file` and build one turn segment per subtitle item.

        Returns the full audio buffer for `audio_file` together with the
        segments, each of which references a trimmed slice of that buffer.
        """
        srt_items = pysrt.open(str(srt_file))
        full_audio = FileAudioBuffer(path=Path(audio_file))
        segments = [self._build_segment(srt_item, full_audio) for srt_item in srt_items]
        return full_audio, segments

    def _build_segment(self, srt_item: pysrt.SubRipItem, full_audio: FileAudioBuffer) -> Segment:
        """Build a turn segment (with its transcription attribute) from one
        subtitle item, recording provenance if a tracer is set."""
        # SubRipTime.ordinal is expressed in milliseconds; spans use seconds
        start = srt_item.start.ordinal / 1000
        end = srt_item.end.ordinal / 1000
        audio = full_audio.trim_duration(start, end)

        segment = Segment(label=self.turn_segment_label, span=Span(start, end), audio=audio)
        transcription_attr = Attribute(label=self.transcription_attr_label, value=srt_item.text)
        segment.attrs.add(transcription_attr)

        if self._prov_tracer is not None:
            # Segments and attributes are created from external data, hence
            # no source data items
            self._prov_tracer.add_prov(segment, self.description, source_data_items=[])
            self._prov_tracer.add_prov(transcription_attr, self.description, source_data_items=[])

        return segment
[docs]
class SRTOutputConverter(OutputConverter):
    """Build .srt files containing transcription information from
    :class:`~medkit.core.audio.annotation.Segment` objects.

    There must be a segment for each turn, with an associated
    :class:`~medkit.core.Attribute` holding the transcribed text as
    value. The segments can be passed directly or as part of
    :class:`~medkit.core.audio.document.AudioDocument` instances.
    """

    def __init__(
        self,
        segment_turn_label: str = "turn",
        transcription_attr_label: str = "transcribed_text",
    ):
        """Parameters
        ----------
        segment_turn_label : str, default="turn"
            Label of segments representing turns in the audio documents.
        transcription_attr_label : str, default="transcribed_text"
            Label of segments attributes containing the transcribed text.
        """
        super().__init__()
        self.segment_turn_label = segment_turn_label
        self.transcription_attr_label = transcription_attr_label

    def save(
        self,
        docs: list[AudioDocument],
        srt_dir: str | Path,
        doc_names: list[str] | None = None,
    ):
        """Save :class:`~medkit.core.audio.document.AudioDocument` instances as
        .srt files in a directory.

        Parameters
        ----------
        docs : list of AudioDocument
            List of audio documents to save.
        srt_dir : str or Path
            Directory into which the generated .srt files will be stored.
            It is created (with its parents) if it does not exist.
        doc_names : list of str, optional
            Optional list of names to use as basenames for the generated .srt
            files. Defaults to the document identifiers.

        Raises
        ------
        ValueError
            If `doc_names` is provided and its length differs from `docs`.
        """
        srt_dir = Path(srt_dir)
        if doc_names is None:
            doc_names = [doc.uid for doc in docs]
        elif len(doc_names) != len(docs):
            msg = "doc_names must have the same length as docs when provided"
            raise ValueError(msg)

        srt_dir.mkdir(parents=True, exist_ok=True)
        for doc_name, doc in zip(doc_names, docs):
            self.save_doc(doc, srt_file=srt_dir / f"{doc_name}.srt")

    def save_doc(
        self,
        doc: AudioDocument,
        srt_file: str | Path,
    ):
        """Save a single :class:`~medkit.core.audio.document.AudioDocument` as a
        .srt file.

        Only the document annotations labelled `segment_turn_label` are saved.

        Parameters
        ----------
        doc : AudioDocument
            Audio document to save.
        srt_file : str or Path
            Path of the generated .srt file.
        """
        segments = doc.anns.get(label=self.segment_turn_label)
        self.save_segments(segments, Path(srt_file))

    def save_segments(self, segments: list[Segment], srt_file: str | Path):
        """Save :class:`~medkit.core.audio.annotation.Segment` objects representing
        turns into a .srt file.

        Each segment must carry at least one attribute labelled
        `transcription_attr_label`; the value of the first one is used as the
        subtitle text. Cues are numbered from 1, in the order of `segments`.

        Parameters
        ----------
        segments : list of Segment
            Turn segments to save.
        srt_file : str or Path
            Path of the generated .srt file.
        """
        srt_items = pysrt.SubRipFile(path=str(srt_file))
        # SRT cue numbers are 1-based
        for cue_number, segment in enumerate(segments, start=1):
            transcription_attr = segment.attrs.get(label=self.transcription_attr_label)[0]
            srt_item = pysrt.SubRipItem(
                index=cue_number,
                start=self._to_srt_time(segment.span.start),
                end=self._to_srt_time(segment.span.end),
                text=transcription_attr.value,
            )
            srt_items.append(srt_item)
        srt_items.save()

    @staticmethod
    def _to_srt_time(seconds: float) -> pysrt.SubRipTime:
        """Convert a time in seconds to a :class:`pysrt.SubRipTime`.

        The value is rounded to the nearest millisecond, the resolution of the
        SRT format. Passing float seconds straight to ``SubRipTime`` yields a
        float millisecond ordinal (e.g. 1.001 s -> 1000.9999999999999 ms) that
        is written truncated ("00:00:01,000"), which loses precision on a
        load/save round-trip.
        """
        return pysrt.SubRipTime(milliseconds=round(seconds * 1000))