"""Text annotation classes of medkit.core.text.annotation: TextAnnotation, Segment, Entity and Relation."""

from __future__ import annotations

__all__ = ["TextAnnotation", "Segment", "Entity", "Relation"]

import abc
import dataclasses
from typing import TYPE_CHECKING, Any

from typing_extensions import Self

from medkit.core import dict_conv
from medkit.core.attribute import Attribute
from medkit.core.attribute_container import AttributeContainer
from medkit.core.id import generate_id
from medkit.core.text.entity_attribute_container import EntityAttributeContainer
from medkit.core.text.span import AnySpan

if TYPE_CHECKING:
    from medkit.core.store import Store


@dataclasses.dataclass(init=False)
class TextAnnotation(abc.ABC, dict_conv.SubclassMapping):
    """Base abstract class for all text annotations

    Attributes
    ----------
    uid : str
        Unique identifier of the annotation.
    label : str
        The label for this annotation (e.g., SENTENCE)
    attrs : AttributeContainer
        Attributes of the annotation. Stored in a
        :class:`~medkit.core.AttributeContainer` but can be passed as a list at
        init.
    metadata : dict of str to Any
        The metadata of the annotation
    keys : set of str
        Pipeline output keys to which the annotation belongs to.
    """

    uid: str
    label: str
    attrs: AttributeContainer
    metadata: dict[str, Any]
    keys: set[str]

    @abc.abstractmethod
    def __init__(
        self,
        label: str,
        attrs: list[Attribute] | None = None,
        metadata: dict[str, Any] | None = None,
        uid: str | None = None,
        attr_container_class: type[AttributeContainer] = AttributeContainer,
    ):
        """Initialize the fields shared by every text annotation.

        Parameters
        ----------
        label : str
            The label of the annotation.
        attrs : list of Attribute, optional
            Attributes added one by one to the attribute container.
            Defaults to no attributes.
        metadata : dict of str to Any, optional
            Metadata of the annotation, stored as-is (not copied).
            Defaults to a new empty dict.
        uid : str, optional
            Identifier of the annotation. Generated with `generate_id()`
            when not provided.
        attr_container_class : type of AttributeContainer
            Class used to build `attrs`, instantiated with
            `owner_id=uid`.
        """
        if attrs is None:
            attrs = []
        if metadata is None:
            metadata = {}
        if uid is None:
            uid = generate_id()

        self.uid = uid
        self.label = label
        self.metadata = metadata
        self.keys = set()

        # The container is bound to this annotation's uid, so it has to be
        # built after `self.uid` is assigned.
        self.attrs = attr_container_class(owner_id=self.uid)
        for attr in attrs:
            self.attrs.add(attr)

    def __init_subclass__(cls, **kwargs):
        # Register every subclass on TextAnnotation so that `from_dict` can
        # presumably resolve it through `get_subclass_for_data_dict`.
        TextAnnotation.register_subclass(cls)
        # Forward class keyword arguments cooperatively (PEP 487) instead of
        # silently dropping them.
        super().__init_subclass__(**kwargs)

    @classmethod
    def from_dict(cls, ann_dict: dict[str, Any]) -> Self:
        """Create an annotation from a dict by dispatching to the matching subclass.

        Parameters
        ----------
        ann_dict : dict of str to Any
            A dictionary from a serialized annotation, as generated by the
            `to_dict()` method of a subclass.

        Returns
        -------
        TextAnnotation
            An instance of the subclass identified by `ann_dict`.

        Raises
        ------
        NotImplementedError
            If no subclass matches `ann_dict`.
        """
        subclass = cls.get_subclass_for_data_dict(ann_dict)
        if subclass is None:
            msg = (
                "TextAnnotation is an abstract class. Its class method `from_dict` is"
                " only used for calling the correct subclass `from_dict`. Subclass is"
                f" {subclass}"
            )
            raise NotImplementedError(msg)
        return subclass.from_dict(ann_dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the annotation to a dict; must be implemented by subclasses."""
        raise NotImplementedError
@dataclasses.dataclass(init=False)
class Segment(TextAnnotation):
    """Text segment referencing part of an :class:`~medkit.core.text.TextDocument`.

    Attributes
    ----------
    uid : str
        The segment identifier.
    label : str
        The label for this segment (e.g., SENTENCE)
    text : str
        Text of the segment.
    spans : list of AnySpan
        List of spans indicating which parts of the segment text correspond to
        which part of the document's full text.
    attrs : AttributeContainer
        Attributes of the segment. Stored in a
        :class:`~medkit.core.AttributeContainer` but can be passed as a list at
        init.
    metadata : dict of str to Any
        The metadata of the segment
    keys : set of str
        Pipeline output keys to which the segment belongs to.
    """

    spans: list[AnySpan]
    text: str

    def __init__(
        self,
        label: str,
        text: str,
        spans: list[AnySpan],
        attrs: list[Attribute] | None = None,
        metadata: dict[str, Any] | None = None,
        uid: str | None = None,
        store: Store | None = None,
        attr_container_class: type[AttributeContainer] = AttributeContainer,
    ):
        """Build a segment and check that its spans match its text.

        Parameters
        ----------
        label : str
            The label of the segment.
        text : str
            Text of the segment.
        spans : list of AnySpan
            Spans locating `text` in the document's full text.
        attrs : list of Attribute, optional
            Attributes of the segment.
        metadata : dict of str to Any, optional
            Metadata of the segment.
        uid : str, optional
            Identifier of the segment, generated when not provided.
        store : Store, optional
            Accepted but not used by this constructor.
        attr_container_class : type of AttributeContainer
            Class used to build the attribute container.

        Raises
        ------
        AssertionError
            If the summed length of `spans` differs from `len(text)`.
        """
        super().__init__(
            label=label,
            attrs=attrs,
            metadata=metadata,
            uid=uid,
            attr_container_class=attr_container_class,
        )

        self.text = text
        self.spans = spans

        # The spans must cover exactly as many characters as `text` holds.
        # An explicit raise (rather than an `assert` statement) keeps this check
        # active under `python -O`. The exception type and message are unchanged,
        # so callers relying on AssertionError still work.
        length = sum(s.length for s in self.spans)
        if len(self.text) != length:
            msg = "Spans length does not match text length"
            raise AssertionError(msg)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the segment to a dict (the `keys` set is not included)."""
        spans = [s.to_dict() for s in self.spans]
        attrs = [a.to_dict() for a in self.attrs]
        segment_dict = {
            "uid": self.uid,
            "label": self.label,
            "text": self.text,
            "spans": spans,
            "attrs": attrs,
            "metadata": self.metadata,
        }
        dict_conv.add_class_name_to_data_dict(self, segment_dict)
        return segment_dict

    @classmethod
    def from_dict(cls, segment_dict: dict[str, Any]) -> Self:
        """Creates a Segment from a dict

        Parameters
        ----------
        segment_dict : dict of str to Any
            A dictionary from a serialized segment as generated by to_dict()

        Returns
        -------
        Segment
            An instance of `cls` (so subclasses such as Entity get their own type).
        """
        spans = [AnySpan.from_dict(s) for s in segment_dict["spans"]]
        attrs = [Attribute.from_dict(a) for a in segment_dict["attrs"]]
        return cls(
            uid=segment_dict["uid"],
            label=segment_dict["label"],
            text=segment_dict["text"],
            spans=spans,
            attrs=attrs,
            metadata=segment_dict["metadata"],
        )
@dataclasses.dataclass(init=False)
class Entity(Segment):
    """Text entity referencing part of an :class:`~medkit.core.text.TextDocument`.

    Attributes
    ----------
    uid : str
        The entity identifier.
    label : str
        The label for this entity (e.g., DISEASE)
    text : str
        Text of the entity.
    spans : list of AnySpan
        List of spans indicating which parts of the entity text correspond to
        which part of the document's full text.
    attrs : EntityAttributeContainer
        Attributes of the entity. Stored in a
        :class:`~medkit.core.EntityAttributeContainer` but can be passed as a
        list at init.
    metadata : dict of str to Any
        The metadata of the entity
    keys : set of str
        Pipeline output keys to which the entity belongs to.
    """

    attrs: EntityAttributeContainer

    def __init__(
        self,
        label: str,
        text: str,
        spans: list[AnySpan],
        attrs: list[Attribute] | None = None,
        metadata: dict[str, Any] | None = None,
        uid: str | None = None,
        store: Store | None = None,
        attr_container_class: type[EntityAttributeContainer] = EntityAttributeContainer,
    ):
        """Build an entity; identical to Segment but with an entity attribute container by default."""
        # Forward by keyword (as Segment and Relation do) so that a future
        # reordering of Segment.__init__ parameters cannot silently mis-assign.
        super().__init__(
            label=label,
            text=text,
            spans=spans,
            attrs=attrs,
            metadata=metadata,
            uid=uid,
            store=store,
            attr_container_class=attr_container_class,
        )
@dataclasses.dataclass(init=False)
class Relation(TextAnnotation):
    """Relation between two text entities.

    Attributes
    ----------
    uid : str
        The identifier of the relation
    label : str
        The relation label
    source_id : str
        The identifier of the entity from which the relation is defined
    target_id : str
        The identifier of the entity to which the relation is defined
    attrs : AttributeContainer
        The attributes of the relation
    metadata : dict of str to Any
        The metadata of the relation
    keys : set of str
        Pipeline output keys to which the relation belongs to
    """

    source_id: str
    target_id: str

    def __init__(
        self,
        label: str,
        source_id: str,
        target_id: str,
        attrs: list[Attribute] | None = None,
        metadata: dict[str, Any] | None = None,
        uid: str | None = None,
        store: Store | None = None,
        attr_container_class: type[AttributeContainer] = AttributeContainer,
    ):
        """Build a relation going from `source_id` to `target_id`.

        `store` is accepted but not used by this constructor; every other
        argument except the two endpoint ids is handled by TextAnnotation.
        """
        super().__init__(
            uid=uid,
            label=label,
            attrs=attrs,
            metadata=metadata,
            attr_container_class=attr_container_class,
        )
        self.source_id, self.target_id = source_id, target_id

    def to_dict(self) -> dict[str, Any]:
        """Serialize the relation (endpoints by id) to a dict."""
        serialized = dict(
            uid=self.uid,
            label=self.label,
            source_id=self.source_id,
            target_id=self.target_id,
            attrs=[attribute.to_dict() for attribute in self.attrs],
            metadata=self.metadata,
        )
        dict_conv.add_class_name_to_data_dict(self, serialized)
        return serialized

    @classmethod
    def from_dict(cls, relation_dict: dict[str, Any]) -> Self:
        """Rebuild a Relation from the output of `to_dict()`.

        Parameters
        ----------
        relation_dict : dict of str to Any
            A dictionary from a serialized relation as generated by to_dict()
        """
        rebuilt_attrs = [Attribute.from_dict(raw) for raw in relation_dict["attrs"]]
        uid = relation_dict["uid"]
        label = relation_dict["label"]
        source_id = relation_dict["source_id"]
        target_id = relation_dict["target_id"]
        metadata = relation_dict["metadata"]
        return cls(
            label,
            source_id,
            target_id,
            attrs=rebuilt_attrs,
            metadata=metadata,
            uid=uid,
        )