Source code for medkit.io.srt
# Public API of this module: the two .srt converters defined below.
__all__ = ["SRTInputConverter", "SRTOutputConverter"]
import logging
from pathlib import Path
from typing import List, Optional, Union
import pysrt
from medkit.core import (
generate_id,
InputConverter,
OutputConverter,
OperationDescription,
ProvTracer,
Attribute,
)
from medkit.core.audio import AudioDocument, Segment, Span, FileAudioBuffer
# Module-level logger, used to warn when a directory contains no .srt file.
logger = logging.getLogger(__name__)
class SRTInputConverter(InputConverter):
    """
    Convert .srt files containing transcription information into turn segments
    with transcription attributes.

    For each turn in a .srt file, a
    :class:`~medkit.core.audio.annotation.Segment` will be created, with an
    associated :class:`~medkit.core.Attribute` holding the transcribed text as
    value. The segments can be retrieved directly or as part of an
    :class:`~medkit.core.audio.document.AudioDocument` instance.

    If a :class:`~medkit.core.ProvTracer` is set, provenance information will be
    added for each segment and each attribute (referencing the input converter
    as the operation).
    """

    def __init__(
        self,
        turn_segment_label: str = "turn",
        transcription_attr_label: str = "transcribed_text",
        converter_id: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        turn_segment_label:
            Label to use for segments representing turns in the .srt file.
        transcription_attr_label:
            Label to use for segments attributes containing the transcribed text.
        converter_id:
            Identifier of the converter. A new identifier is generated when
            not provided.
        """
        if converter_id is None:
            converter_id = generate_id()

        self.uid = converter_id
        self.turn_segment_label = turn_segment_label
        self.transcription_attr_label = transcription_attr_label

        # no provenance tracing until set_prov_tracer() is called
        self._prov_tracer: Optional[ProvTracer] = None

    @property
    def description(self) -> OperationDescription:
        """Contains all the input converter init parameters."""
        return OperationDescription(
            uid=self.uid,
            name=self.__class__.__name__,
            class_name=self.__class__.__name__,
            config={
                "turn_segment_label": self.turn_segment_label,
                "transcription_attr_label": self.transcription_attr_label,
            },
        )

    def set_prov_tracer(self, prov_tracer: ProvTracer):
        """Enable provenance tracing.

        Parameters
        ----------
        prov_tracer:
            The provenance tracer used to trace the provenance.
        """
        self._prov_tracer = prov_tracer

    def load(
        self,
        srt_dir: Union[str, Path],
        audio_dir: Optional[Union[str, Path]] = None,
        audio_ext: str = ".wav",
    ) -> List[AudioDocument]:
        """
        Load all .srt files in a directory into a list of
        :class:`~medkit.core.audio.document.AudioDocument` objects.

        For each .srt file, there must be a corresponding audio file with the
        same basename, either in the same directory or in a separate audio
        directory.

        Parameters
        ----------
        srt_dir:
            Directory containing the .srt files.
        audio_dir:
            Directory containing the audio files corresponding to the .srt files,
            if they are not in `srt_dir`.
        audio_ext:
            File extension to use for audio files (including the leading dot).

        Returns
        -------
        List[AudioDocument]
            List of generated documents, in sorted order of the .srt file paths.
        """
        srt_dir = Path(srt_dir)
        audio_dir = Path(audio_dir) if audio_dir else None

        docs = []
        for srt_file in sorted(srt_dir.glob("*.srt")):
            # The audio file has the same base name as the .srt file, with only
            # the final ".srt" suffix swapped for the audio extension. The name
            # is computed from the full .srt file name (rather than by calling
            # with_suffix() on the stem) so that base names containing dots,
            # such as "rec.v2.srt", map to "rec.v2.wav" whether or not
            # audio_dir is provided.
            audio_name = srt_file.with_suffix(audio_ext).name
            parent_dir = audio_dir if audio_dir is not None else srt_file.parent
            docs.append(self.load_doc(srt_file, parent_dir / audio_name))

        if not docs:
            logger.warning(f"No .srt found in '{srt_dir}'")

        return docs

    def load_doc(
        self, srt_file: Union[str, Path], audio_file: Union[str, Path]
    ) -> AudioDocument:
        """Load a single .srt file into an
        :class:`~medkit.core.audio.document.AudioDocument` containing
        turn segments with transcription attributes.

        Parameters
        ----------
        srt_file:
            Path to the .srt file.
        audio_file:
            Path to the corresponding audio file.

        Returns
        -------
        AudioDocument:
            Generated document.
        """
        full_audio = FileAudioBuffer(path=Path(audio_file))
        segments = self._read_segments(srt_file, full_audio)

        doc = AudioDocument(audio=full_audio)
        for segment in segments:
            doc.anns.add(segment)
        return doc

    def load_segments(
        self, srt_file: Union[str, Path], audio_file: Union[str, Path]
    ) -> List[Segment]:
        """Load a .srt file and return a list of
        :class:`~medkit.core.audio.annotation.Segment` objects corresponding to
        turns, with transcription attributes.

        Parameters
        ----------
        srt_file:
            Path to the .srt file.
        audio_file:
            Path to the corresponding audio file.

        Returns
        -------
        List[:class:`~medkit.core.audio.annotation.Segment`]:
            Turn segments as found in the .srt file, with transcription
            attributes attached.
        """
        full_audio = FileAudioBuffer(path=Path(audio_file))
        return self._read_segments(srt_file, full_audio)

    def _read_segments(
        self, srt_file: Union[str, Path], full_audio: FileAudioBuffer
    ) -> List[Segment]:
        # Parse the .srt file and build one turn segment per subtitle item;
        # shared by load_doc() and load_segments().
        srt_items = pysrt.open(str(srt_file))
        return [self._build_segment(srt_item, full_audio) for srt_item in srt_items]

    def _build_segment(
        self, srt_item: pysrt.SubRipItem, full_audio: FileAudioBuffer
    ) -> Segment:
        # Build a turn segment (with its transcription attribute) from one
        # subtitle item of the .srt file.

        # milliseconds to seconds
        start = srt_item.start.ordinal / 1000
        end = srt_item.end.ordinal / 1000

        audio = full_audio.trim_duration(start, end)
        segment = Segment(
            label=self.turn_segment_label, span=Span(start, end), audio=audio
        )
        transcription_attr = Attribute(
            label=self.transcription_attr_label, value=srt_item.text
        )
        segment.attrs.add(transcription_attr)

        if self._prov_tracer is not None:
            # both the segment and its attribute are traced with this converter
            # as the operation and with no source data items
            self._prov_tracer.add_prov(segment, self.description, source_data_items=[])
            self._prov_tracer.add_prov(
                transcription_attr, self.description, source_data_items=[]
            )

        return segment
class SRTOutputConverter(OutputConverter):
    """
    Build .srt files containing transcription information from
    :class:`~medkit.core.audio.annotation.Segment` objects.

    There must be a segment for each turn, with an associated
    :class:`~medkit.core.Attribute` holding the transcribed text as
    value. The segments can be passed directly or as part of
    :class:`~medkit.core.audio.document.AudioDocument` instances.
    """

    def __init__(
        self,
        segment_turn_label: str = "turn",
        transcription_attr_label: str = "transcribed_text",
    ):
        """
        Parameters
        ----------
        segment_turn_label:
            Label of segments representing turns in the audio documents.
        transcription_attr_label:
            Label of segments attributes containing the transcribed text.
        """
        super().__init__()

        self.segment_turn_label = segment_turn_label
        self.transcription_attr_label = transcription_attr_label

    def save(
        self,
        docs: List[AudioDocument],
        srt_dir: Union[str, Path],
        doc_names: Optional[List[str]] = None,
    ):
        """Save :class:`~medkit.core.audio.document.AudioDocument` instances as
        .srt files in a directory.

        Parameters
        ----------
        docs:
            List of audio documents to save.
        srt_dir:
            Directory into which the generated .srt files will be stored. It is
            created (with its parents) if it does not exist.
        doc_names:
            Optional list of names to use as basenames for the generated .srt
            files. When not provided, the uid of each document is used.

        Raises
        ------
        ValueError
            If `doc_names` is provided but does not have the same length as
            `docs`.
        """
        srt_dir = Path(srt_dir)

        if doc_names is not None:
            if len(doc_names) != len(docs):
                raise ValueError(
                    "doc_names must have the same length as docs when provided"
                )
        else:
            doc_names = [doc.uid for doc in docs]

        srt_dir.mkdir(parents=True, exist_ok=True)

        for doc_name, doc in zip(doc_names, docs):
            srt_file = srt_dir / f"{doc_name}.srt"
            self.save_doc(doc, srt_file=srt_file)

    def save_doc(
        self,
        doc: AudioDocument,
        srt_file: Union[str, Path],
    ):
        """Save a single :class:`~medkit.core.audio.document.AudioDocument` as a
        .srt file.

        Only the annotations of the document labelled with
        `segment_turn_label` are saved.

        Parameters
        ----------
        doc:
            Audio document to save.
        srt_file:
            Path of the generated .srt file.
        """
        srt_file = Path(srt_file)
        segments = doc.anns.get(label=self.segment_turn_label)
        self.save_segments(segments, srt_file)

    def save_segments(self, segments: List[Segment], srt_file: Union[str, Path]):
        """Save :class:`~medkit.core.audio.annotation.Segment` objects representing
        turns into a .srt file.

        Each segment must carry at least one attribute labelled with
        `transcription_attr_label`; the first one found is used as the text
        of the subtitle item.

        Parameters
        ----------
        segments:
            Turn segments to save.
        srt_file:
            Path of the generated .srt file.
        """
        srt_items = pysrt.SubRipFile(path=str(srt_file))

        # subtitle items in .srt files are conventionally numbered from 1
        for index, segment in enumerate(segments, start=1):
            # NOTE(review): fails on the [0] lookup if a segment has no
            # attribute with the transcription label (presumably IndexError,
            # if attrs.get() returns a list — confirm).
            transcription_attr = segment.attrs.get(
                label=self.transcription_attr_label
            )[0]
            srt_item = pysrt.SubRipItem(
                index=index,
                start=self._to_srt_time(segment.span.start),
                end=self._to_srt_time(segment.span.end),
                text=transcription_attr.value,
            )
            srt_items.append(srt_item)

        srt_items.save()

    @staticmethod
    def _to_srt_time(seconds: float) -> pysrt.SubRipTime:
        # Convert a time in seconds to a SubRipTime with millisecond precision.
        # The value is rounded to a whole number of milliseconds before being
        # handed to pysrt: multiplying a float number of seconds by 1000 can
        # land just below the intended integer (1.001 * 1000 ==
        # 1000.9999999999999), which would otherwise be truncated and lose a
        # millisecond when the timestamp is written.
        return pysrt.SubRipTime(milliseconds=round(seconds * 1000))