Source code for medkit.text.postprocessing.document_splitter

from __future__ import annotations

__all__ = ["DocumentSplitter"]

from medkit.core import Attribute, Operation
from medkit.core.text import (
    Entity,
    ModifiedSpan,
    Relation,
    Segment,
    Span,
    TextAnnotation,
    TextDocument,
    span_utils,
)
from medkit.text.postprocessing.alignment_utils import compute_nested_segments


[docs] class DocumentSplitter(Operation): """Split text documents using its segments as a reference. The resulting 'mini-documents' contain the entities belonging to each segment along with their attributes. This operation can be used to create datasets from medkit text documents. """ def __init__( self, segment_label: str, entity_labels: list[str] | None = None, attr_labels: list[str] | None = None, relation_labels: list[str] | None = None, name: str | None = None, uid: str | None = None, ): """Instantiate the document splitter Parameters ---------- segment_label : str Label of the segments to use as references for the splitter entity_labels : list of str, optional Labels of entities to be included in the mini documents. If None, all entities from the document will be included. attr_labels : list of str, optional Labels of the attributes to be included into the new annotations. If None, all attributes will be included. relation_labels : list of str, optional Labels of relations to be included in the mini documents. If None, all relations will be included. name : str, optional Name describing the splitter (default to the class name). uid : str, Optional Identifier of the operation """ # Pass all arguments to super (remove self) init_args = locals() init_args.pop("self") super().__init__(**init_args) self.segment_label = segment_label self.entity_labels = entity_labels self.attr_labels = attr_labels self.relation_labels = relation_labels
[docs] def run(self, docs: list[TextDocument]) -> list[TextDocument]: """Split docs into mini documents Parameters ---------- docs: list of TextDocument List of text documents to split Returns ------- list of TextDocument List of documents created from the selected segments """ segment_docs = [] for doc in docs: segments = doc.anns.get_segments(label=self.segment_label) # filter entities entities = ( doc.anns.get_entities() if self.entity_labels is None else [ent for label in self.entity_labels for ent in doc.anns.get_entities(label=label)] ) # align segment and entities (fully contained) segment_and_entities = compute_nested_segments(segments, entities) # filter relations in the document relations = ( doc.anns.get_relations() if self.relation_labels is None else [rel for label in self.relation_labels for rel in doc.anns.get_relations(label=label)] ) # Iterate over all segments and corresponding nested entities for segment, nested_entities in segment_and_entities: # filter relations in nested entities entities_uid = {ent.uid for ent in nested_entities} nested_relations = [ relation for relation in relations if relation.source_id in entities_uid and relation.target_id in entities_uid ] # create new document from segment segment_doc = self._create_segment_doc( segment=segment, entities=nested_entities, relations=nested_relations, doc_source=doc, ) segment_docs.append(segment_doc) return segment_docs
def _create_segment_doc( self, segment: Segment, entities: list[Entity], relations: list[Relation], doc_source: TextDocument, ) -> TextDocument: """Create a TextDocument from a segment and its entities. The original zone of the segment becomes the text of the document. Parameters ---------- segment : Segment Segment to use as reference for the new document entities : list of Entity Entities inside the segment relations : list of Relation Relations inside the segment doc_source : TextDocument Initial document from which annotations where extracted Returns ------- TextDocument A new document with entities, the metadata includes the original span and metadata """ normalized_spans = span_utils.normalize_spans(segment.spans) # create an empty mini-doc with the raw text of the segment offset, end_span = normalized_spans[0].start, normalized_spans[-1].end metadata = doc_source.metadata.copy() metadata.update(segment.metadata) segment_doc = TextDocument(text=doc_source.text[offset:end_span], metadata=metadata) # handle provenance if self._prov_tracer is not None: self._prov_tracer.add_prov(segment_doc, self.description, source_data_items=[segment]) # Copy segment attributes segment_attrs = self._filter_attrs_from_ann(segment) for attr in segment_attrs: new_doc_attr = attr.copy() segment_doc.attrs.add(new_doc_attr) # handle provenance if self._prov_tracer is not None: self._prov_tracer.add_prov( new_doc_attr, self.description, source_data_items=[attr], ) # Add selected entities uid_mapping = {} for ent in entities: spans = [] for span in ent.spans: # relocate entity spans using segment offset if isinstance(span, Span): spans.append(Span(span.start - offset, span.end - offset)) else: replaced_spans = [Span(sp.start - offset, sp.end - offset) for sp in span.replaced_spans] spans.append(ModifiedSpan(length=span.length, replaced_spans=replaced_spans)) # define the new entity relocated_ent = Entity( text=ent.text, label=ent.label, spans=spans, metadata=ent.metadata.copy(), ) # add mapping for relations uid_mapping[ent.uid] = relocated_ent.uid # handle provenance if self._prov_tracer is not None: self._prov_tracer.add_prov(relocated_ent, self.description, source_data_items=[ent]) # Copy entity attributes entity_attrs = self._filter_attrs_from_ann(ent) for attr in entity_attrs: new_ent_attr = attr.copy() relocated_ent.attrs.add(new_ent_attr) # handle provenance if self._prov_tracer is not None: self._prov_tracer.add_prov( new_ent_attr, self.description, source_data_items=[attr], ) # add entity to the new document segment_doc.anns.add(relocated_ent) for rel in relations: relation = Relation( label=rel.label, source_id=uid_mapping[rel.source_id], target_id=uid_mapping[rel.target_id], metadata=rel.metadata.copy(), ) # handle provenance if self._prov_tracer is not None: self._prov_tracer.add_prov(relation, self.description, source_data_items=[rel]) # Copy relation attributes relation_attrs = self._filter_attrs_from_ann(rel) for attr in relation_attrs: new_rel_attr = attr.copy() relation.attrs.add(new_rel_attr) # handle provenance if self._prov_tracer is not None: self._prov_tracer.add_prov( new_rel_attr, self.description, source_data_items=[attr], ) # add relation to the new document segment_doc.anns.add(relation) return segment_doc def _filter_attrs_from_ann(self, ann: TextAnnotation) -> list[Attribute]: """Filter attributes from an annotation using 'attr_labels'""" return ( ann.attrs.get() if self.attr_labels is None else [attr for label in self.attr_labels for attr in ann.attrs.get(label=label)] )