# Source code for medkit.text.postprocessing.document_splitter

__all__ = ["DocumentSplitter"]
# functions to create minidocs from segments
from typing import List, Optional
from medkit.core import Attribute, Operation
from medkit.core.text import (
    Entity,
    ModifiedSpan,
    Relation,
    Segment,
    Span,
    TextDocument,
    TextAnnotation,
    span_utils,
)
from medkit.text.postprocessing.alignment_utils import compute_nested_segments


class DocumentSplitter(Operation):
    """Split text documents using its segments as a reference.

    The resulting 'mini-documents' contain the entities belonging to each
    segment along with their attributes.

    This operation can be used to create datasets from medkit text documents.
    """

    def __init__(
        self,
        segment_label: str,
        entity_labels: Optional[List[str]] = None,
        attr_labels: Optional[List[str]] = None,
        relation_labels: Optional[List[str]] = None,
        name: Optional[str] = None,
        uid: Optional[str] = None,
    ):
        """Instantiate the document splitter.

        Parameters
        ----------
        segment_label:
            Label of the segments to use as references for the splitter
        entity_labels:
            Labels of entities to be included in the mini documents.
            If None, all entities from the document will be included.
        attr_labels:
            Labels of the attributes to be included into the new annotations.
            If None, all attributes will be included.
        relation_labels:
            Labels of relations to be included in the mini documents.
            If None, all relations will be included.
        name:
            Name describing the splitter (default to the class name).
        uid:
            Identifier of the operation
        """
        # Pass all arguments to super (remove self).
        # NOTE: `locals()` must stay the first statement so that it only
        # captures the constructor arguments and no other local variable.
        init_args = locals()
        init_args.pop("self")
        super().__init__(**init_args)

        self.segment_label = segment_label
        self.entity_labels = entity_labels
        self.attr_labels = attr_labels
        self.relation_labels = relation_labels

    def run(self, docs: List[TextDocument]) -> List[TextDocument]:
        """Split docs into mini documents.

        Parameters
        ----------
        docs:
            List of text documents to split

        Returns
        -------
        List[TextDocument]
            List of documents created from the selected segments
        """
        segment_docs = []

        for doc in docs:
            segments = doc.anns.get_segments(label=self.segment_label)

            # filter entities
            entities = (
                doc.anns.get_entities()
                if self.entity_labels is None
                else [
                    ent
                    for label in self.entity_labels
                    for ent in doc.anns.get_entities(label=label)
                ]
            )

            # align segment and entities (fully contained)
            segment_and_entities = compute_nested_segments(segments, entities)

            # filter relations in the document
            relations = (
                doc.anns.get_relations()
                if self.relation_labels is None
                else [
                    rel
                    for label in self.relation_labels
                    for rel in doc.anns.get_relations(label=label)
                ]
            )

            # Iterate over all segments and corresponding nested entities
            for segment, nested_entities in segment_and_entities:
                # keep only the relations whose both ends are entities of
                # this segment, otherwise they could not be remapped later
                entities_uid = {ent.uid for ent in nested_entities}
                nested_relations = [
                    relation
                    for relation in relations
                    if relation.source_id in entities_uid
                    and relation.target_id in entities_uid
                ]

                # create new document from segment
                segment_doc = self._create_segment_doc(
                    segment=segment,
                    entities=nested_entities,
                    relations=nested_relations,
                    doc_source=doc,
                )
                segment_docs.append(segment_doc)

        return segment_docs

    def _create_segment_doc(
        self,
        segment: Segment,
        entities: List[Entity],
        relations: List[Relation],
        doc_source: TextDocument,
    ) -> TextDocument:
        """Create a TextDocument from a segment and its entities.

        The original zone of the segment becomes the text of the document.

        Parameters
        ----------
        segment:
            Segment to use as reference for the new document
        entities:
            Entities inside the segment
        relations:
            Relations inside the segment
        doc_source:
            Initial document from which annotations were extracted

        Returns
        -------
        TextDocument
            A new document with entities, the metadata includes the original
            span and metadata
        """
        normalized_spans = span_utils.normalize_spans(segment.spans)

        # create an empty mini-doc with the raw text of the segment: the text
        # runs from the start of the first normalized span to the end of the
        # last one, and `offset` is later used to relocate entity spans
        offset, end_span = normalized_spans[0].start, normalized_spans[-1].end
        metadata = doc_source.metadata.copy()
        metadata.update(segment.metadata)

        segment_doc = TextDocument(
            text=doc_source.text[offset:end_span], metadata=metadata
        )
        self._trace_prov(segment_doc, segment)

        # Copy segment attributes (they become attributes of the mini-doc)
        self._copy_attrs(source=segment, target=segment_doc)

        # Add selected entities
        uid_mapping = {}
        for ent in entities:
            # define the new entity, with spans relative to the mini-doc text
            relocated_ent = Entity(
                text=ent.text,
                label=ent.label,
                spans=self._relocate_spans(ent.spans, offset),
                metadata=ent.metadata.copy(),
            )
            # add mapping for relations
            uid_mapping[ent.uid] = relocated_ent.uid

            self._trace_prov(relocated_ent, ent)

            # Copy entity attributes
            self._copy_attrs(source=ent, target=relocated_ent)

            # add entity to the new document
            segment_doc.anns.add(relocated_ent)

        for rel in relations:
            # rebuild the relation between the relocated entities
            relation = Relation(
                label=rel.label,
                source_id=uid_mapping[rel.source_id],
                target_id=uid_mapping[rel.target_id],
                metadata=rel.metadata.copy(),
            )
            self._trace_prov(relation, rel)

            # Copy relation attributes
            self._copy_attrs(source=rel, target=relation)

            # add relation to the new document
            segment_doc.anns.add(relation)

        return segment_doc

    @staticmethod
    def _relocate_spans(spans: List, offset: int) -> List:
        """Return copies of `spans` shifted left by `offset` characters.

        Parameters
        ----------
        spans:
            Spans to relocate. Items that are not `Span` instances are
            handled as `ModifiedSpan` (their `replaced_spans` are shifted
            and their `length` is kept).
        offset:
            Number of characters to subtract from each start/end position

        Returns
        -------
        List
            Newly created spans, in the same order as `spans`
        """
        relocated = []
        for span in spans:
            if isinstance(span, Span):
                relocated.append(Span(span.start - offset, span.end - offset))
            else:
                replaced_spans = [
                    Span(sp.start - offset, sp.end - offset)
                    for sp in span.replaced_spans
                ]
                relocated.append(
                    ModifiedSpan(length=span.length, replaced_spans=replaced_spans)
                )
        return relocated

    def _trace_prov(self, new_item, source_item) -> None:
        """Record that `new_item` derives from `source_item`.

        Does nothing when no provenance tracer is set on the operation.
        """
        if self._prov_tracer is not None:
            self._prov_tracer.add_prov(
                new_item, self.description, source_data_items=[source_item]
            )

    def _copy_attrs(self, source, target) -> None:
        """Copy the selected attributes of `source` into `target.attrs`.

        Attributes are selected with 'attr_labels'; each one is copied, added
        to `target.attrs` and then traced with its original as source.

        Parameters
        ----------
        source:
            Annotation whose attributes are copied
        target:
            Object exposing an `attrs` container receiving the copies
        """
        for attr in self._filter_attrs_from_ann(source):
            new_attr = attr.copy()
            target.attrs.add(new_attr)
            self._trace_prov(new_attr, attr)

    def _filter_attrs_from_ann(self, ann: TextAnnotation) -> List[Attribute]:
        """Filter attributes from an annotation using 'attr_labels'.

        Returns all the attributes of `ann` when 'attr_labels' is None,
        otherwise the attributes found for each label, in label order.
        """
        attrs = (
            ann.attrs.get()
            if self.attr_labels is None
            else [
                attr
                for label in self.attr_labels
                for attr in ann.attrs.get(label=label)
            ]
        )
        return attrs