# Source code for medkit.audio.transcription.transcribed_text_document

from __future__ import annotations

__all__ = ["TranscribedTextDocument"]

import dataclasses
from typing import Any, Dict, List, Optional, Sequence
from typing_extensions import Self

from medkit.core import dict_conv, Attribute
from medkit.core.audio import Span as AudioSpan
from medkit.core.text import (
    TextDocument,
    Span as TextSpan,
    AnySpan as AnyTextSpan,
    TextAnnotation,
    Segment as TextSegment,
    span_utils as text_span_utils,
)


@dataclasses.dataclass(init=False)
class TranscribedTextDocument(TextDocument):
    """Subclass for :class:`~medkit.core.text.document.TextDocument` instances
    generated by audio transcription.

    Attributes
    ----------
    uid:
        Document identifier.
    text:
        The full transcribed text.
    text_spans_to_audio_spans:
        Mapping between text characters spans in this document and
        corresponding audio spans in the original audio.
    audio_doc_id:
        Id of the original :class:`~medkit.core.audio.document.AudioDocument`
        that was transcribed, if known.
    anns:
        Annotations of the document.
    attrs:
        Attributes of the document.
    metadata:
        Document metadata.
    raw_segment:
        Auto-generated segment containing the raw full transcribed text.
    """

    text_spans_to_audio_spans: Dict[TextSpan, AudioSpan]
    audio_doc_id: Optional[str]

    def __init__(
        self,
        text: str,
        text_spans_to_audio_spans: Dict[TextSpan, AudioSpan],
        audio_doc_id: Optional[str],
        anns: Optional[Sequence[TextAnnotation]] = None,
        attrs: Optional[Sequence[Attribute]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        uid: Optional[str] = None,
    ):
        # Every mapped text span must end within the transcribed text,
        # otherwise the mapping would reference characters that don't exist.
        assert all(s.end <= len(text) for s in text_spans_to_audio_spans), (
            "All text spans in text_spans_to_audio_spans must end within text"
        )

        super().__init__(text=text, anns=anns, attrs=attrs, metadata=metadata, uid=uid)

        self.audio_doc_id = audio_doc_id
        self.text_spans_to_audio_spans = text_spans_to_audio_spans

    def get_containing_audio_spans(
        self, text_ann_spans: List[AnyTextSpan]
    ) -> List[AudioSpan]:
        """Return the audio spans used to transcribe the text referenced by a
        text annotation.

        For instance, if the audio ranging from 1.0 to 20.0 seconds is
        transcribed to some text ranging from character 10 to 56 in the
        transcribed document, and then a text annotation is created
        referencing the span 15 to 25, then the containing audio span will be
        the one ranging from 1.0 to 20.0 seconds.

        Note that some text annotations may be contained in more than one
        audio span.

        Parameters
        ----------
        text_ann_spans:
            Text spans of a text annotation referencing some characters in
            the transcribed document.

        Returns
        -------
        List[AudioSpan]
            Audio spans used to transcribe the text referenced by
            `text_ann_spans`. For each normalized annotation span, every
            overlapping entry of `text_spans_to_audio_spans` contributes its
            audio span. An audio span therefore appears once per normalized
            annotation span it overlaps, so duplicates are possible.
        """
        # normalize_spans is expected to turn AnyTextSpan values into plain
        # TextSpan objects comparable with the mapping keys (see
        # medkit.core.text.span_utils).
        ann_text_spans = text_span_utils.normalize_spans(text_ann_spans)
        # TODO: use interval tree instead of nested iteration
        return [
            audio_span
            for ann_text_span in ann_text_spans
            for text_span, audio_span in self.text_spans_to_audio_spans.items()
            if text_span.overlaps(ann_text_span)
        ]

    def to_dict(self, with_anns: bool = True) -> Dict[str, Any]:
        """Serialize the document into a dict.

        The text-to-audio span mapping is stored as two parallel lists,
        "text_spans" and "audio_spans", in the mapping's iteration order.

        Parameters
        ----------
        with_anns:
            Whether to include the annotations of the document under the
            "anns" key.

        Returns
        -------
        Dict[str, Any]
            A dict that can be passed back to :meth:`from_dict`. The "attrs"
            key is only present when the document has attributes.
        """
        # Keys and values of a dict are iterated in the same order, so the two
        # lists stay aligned index by index.
        text_spans = [s.to_dict() for s in self.text_spans_to_audio_spans]
        audio_spans = [s.to_dict() for s in self.text_spans_to_audio_spans.values()]
        doc_dict = dict(
            uid=self.uid,
            text=self.text,
            metadata=self.metadata,
            text_spans=text_spans,
            audio_spans=audio_spans,
            audio_doc_id=self.audio_doc_id,
        )
        if with_anns:
            doc_dict["anns"] = [a.to_dict() for a in self.anns]
        if self.attrs:
            doc_dict["attrs"] = [a.to_dict() for a in self.attrs]
        dict_conv.add_class_name_to_data_dict(self, doc_dict)
        return doc_dict

    @classmethod
    def from_dict(cls, doc_dict: Dict[str, Any]) -> Self:
        """Create a `TranscribedTextDocument` from a dict.

        Parameters
        ----------
        doc_dict:
            A dictionary from a serialized `TranscribedTextDocument` as
            generated by :meth:`to_dict`. The "anns" and "attrs" keys are
            optional, so dicts produced with ``to_dict(with_anns=False)`` or
            for a document without attributes are accepted.

        Returns
        -------
        TranscribedTextDocument
            The deserialized document.
        """
        text_spans = [TextSpan.from_dict(s) for s in doc_dict["text_spans"]]
        audio_spans = [AudioSpan.from_dict(s) for s in doc_dict["audio_spans"]]
        # "text_spans" and "audio_spans" are parallel lists (see to_dict).
        text_spans_to_audio_spans = dict(zip(text_spans, audio_spans))
        # "anns" is omitted by to_dict(with_anns=False), so it must be optional.
        # NOTE(review): every annotation is rebuilt with TextSegment.from_dict,
        # so other annotation types presumably come back as plain segments —
        # confirm this is intended.
        anns = [TextSegment.from_dict(a) for a in doc_dict.get("anns", [])]
        attrs = [Attribute.from_dict(a) for a in doc_dict.get("attrs", [])]
        return cls(
            uid=doc_dict["uid"],
            text=doc_dict["text"],
            text_spans_to_audio_spans=text_spans_to_audio_spans,
            audio_doc_id=doc_dict["audio_doc_id"],
            anns=anns,
            attrs=attrs,
            metadata=doc_dict["metadata"],
        )