from __future__ import annotations
__all__ = ["TranscribedTextDocument"]
import dataclasses
from typing import Any, Dict, List, Optional, Sequence
from typing_extensions import Self
from medkit.core import dict_conv, Attribute
from medkit.core.audio import Span as AudioSpan
from medkit.core.text import (
TextDocument,
Span as TextSpan,
AnySpan as AnyTextSpan,
TextAnnotation,
Segment as TextSegment,
span_utils as text_span_utils,
)
@dataclasses.dataclass(init=False)
class TranscribedTextDocument(TextDocument):
    """Subclass for :class:`~medkit.core.text.document.TextDocument` instances generated
    by audio transcription.

    Attributes
    ----------
    uid:
        Document identifier.
    text:
        The full transcribed text.
    text_spans_to_audio_spans:
        Mapping between text character spans in this document and the
        corresponding audio spans in the original audio.
    audio_doc_id:
        Id of the original
        :class:`~medkit.core.audio.document.AudioDocument` that was
        transcribed, if known.
    anns:
        Annotations of the document.
    attrs:
        Attributes of the document.
    metadata:
        Document metadata.
    raw_segment:
        Auto-generated segment containing the raw full transcribed text.
    """

    text_spans_to_audio_spans: Dict[TextSpan, AudioSpan]
    audio_doc_id: Optional[str]

    def __init__(
        self,
        text: str,
        text_spans_to_audio_spans: Dict[TextSpan, AudioSpan],
        audio_doc_id: Optional[str],
        anns: Optional[Sequence[TextAnnotation]] = None,
        attrs: Optional[Sequence[Attribute]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        uid: Optional[str] = None,
    ):
        # Every transcribed text span must end within the document text.
        # NOTE: this is an ``assert`` and is therefore skipped under ``python -O``.
        assert all(
            s.end <= len(text) for s in text_spans_to_audio_spans
        ), "All text spans of text_spans_to_audio_spans must end within text"

        super().__init__(text=text, anns=anns, attrs=attrs, metadata=metadata, uid=uid)

        self.audio_doc_id = audio_doc_id
        self.text_spans_to_audio_spans = text_spans_to_audio_spans

    def get_containing_audio_spans(
        self, text_ann_spans: List[AnyTextSpan]
    ) -> List[AudioSpan]:
        """Return the audio spans used to transcribe the text referenced by a text
        annotation.

        For instance, if the audio ranging from 1.0 to 20.0 seconds is
        transcribed to some text ranging from character 10 to 56 in the
        transcribed document, and then a text annotation is created referencing
        the span 15 to 25, then the containing audio span will be the one ranging
        from 1.0 to 20.0 seconds.

        Note that some text annotations may be contained in more than one
        audio span. Each audio span is returned at most once, in the order in
        which it is first matched.

        Parameters
        ----------
        text_ann_spans:
            Text spans of a text annotation referencing some characters in the
            transcribed document.

        Returns
        -------
        List[AudioSpan]
            Audio spans used to transcribe the text referenced by `text_ann_spans`.
        """
        ann_text_spans = text_span_utils.normalize_spans(text_ann_spans)

        audio_spans: List[AudioSpan] = []
        # Transcription spans already matched: an annotation made of several
        # disjoint spans may hit the same transcription span more than once,
        # and its audio span must then be reported only once.
        matched_text_spans = set()
        # TODO: use interval tree instead of nested iteration
        for ann_text_span in ann_text_spans:
            for text_span, audio_span in self.text_spans_to_audio_spans.items():
                if text_span in matched_text_spans:
                    continue
                if text_span.overlaps(ann_text_span):
                    matched_text_spans.add(text_span)
                    audio_spans.append(audio_span)
        return audio_spans

    def to_dict(self, with_anns: bool = True) -> Dict[str, Any]:
        """Serialize the document to a dict.

        Parameters
        ----------
        with_anns:
            Whether to include the serialized annotations under the ``"anns"``
            key.

        Returns
        -------
        Dict[str, Any]
            Dict with ``uid``, ``text``, ``metadata``, ``audio_doc_id`` and
            the parallel lists ``text_spans`` / ``audio_spans`` (in the
            iteration order of `text_spans_to_audio_spans`), optionally
            ``anns`` and ``attrs`` (the latter only when the document has
            attributes). The dict is finally passed through
            ``dict_conv.add_class_name_to_data_dict``.
        """
        # Keys and values of the mapping are serialized as two parallel lists;
        # from_dict() zips them back together.
        text_spans = [s.to_dict() for s in self.text_spans_to_audio_spans]
        audio_spans = [s.to_dict() for s in self.text_spans_to_audio_spans.values()]
        doc_dict = dict(
            uid=self.uid,
            text=self.text,
            metadata=self.metadata,
            text_spans=text_spans,
            audio_spans=audio_spans,
            audio_doc_id=self.audio_doc_id,
        )
        if with_anns:
            doc_dict["anns"] = [a.to_dict() for a in self.anns]
        if self.attrs:
            doc_dict["attrs"] = [a.to_dict() for a in self.attrs]
        dict_conv.add_class_name_to_data_dict(self, doc_dict)
        return doc_dict

    @classmethod
    def from_dict(cls, doc_dict: Dict[str, Any]) -> Self:
        """Create a `TranscribedTextDocument` from a dict.

        Parameters
        ----------
        doc_dict:
            A dictionary from a serialized `TranscribedTextDocument` as
            generated by :meth:`to_dict`. The ``"anns"`` and ``"attrs"`` keys
            are optional, so dicts produced with ``with_anns=False`` can be
            loaded back.

        Returns
        -------
        TranscribedTextDocument
            The deserialized document.

        Raises
        ------
        ValueError
            If ``"text_spans"`` and ``"audio_spans"`` do not have the same length.
        """
        text_spans = [TextSpan.from_dict(s) for s in doc_dict["text_spans"]]
        audio_spans = [AudioSpan.from_dict(s) for s in doc_dict["audio_spans"]]
        # to_dict() writes these as parallel lists; zip() would silently drop
        # the surplus entries of the longer one, so reject a mismatch instead.
        if len(text_spans) != len(audio_spans):
            raise ValueError(
                f"Mismatched span lists: {len(text_spans)} text spans but "
                f"{len(audio_spans)} audio spans"
            )
        text_spans_to_audio_spans = dict(zip(text_spans, audio_spans))

        # "anns" is omitted by to_dict(with_anns=False), "attrs" when empty
        anns = [TextSegment.from_dict(a) for a in doc_dict.get("anns", [])]
        attrs = [Attribute.from_dict(a) for a in doc_dict.get("attrs", [])]
        return cls(
            uid=doc_dict["uid"],
            text=doc_dict["text"],
            text_spans_to_audio_spans=text_spans_to_audio_spans,
            audio_doc_id=doc_dict["audio_doc_id"],
            anns=anns,
            attrs=attrs,
            metadata=doc_dict["metadata"],
        )