__all__ = ["DocumentSplitter"]
# Operation that creates mini-documents from the segments of text documents
from typing import List, Optional
from medkit.core import Attribute, Operation
from medkit.core.text import (
Entity,
ModifiedSpan,
Relation,
Segment,
Span,
TextDocument,
TextAnnotation,
span_utils,
)
from medkit.text.postprocessing.alignment_utils import compute_nested_segments
class DocumentSplitter(Operation):
    """Split text documents using its segments as a reference.

    The resulting 'mini-documents' contain the entities belonging to each
    segment along with their attributes.

    This operation can be used to create datasets from medkit text documents.
    """

    def __init__(
        self,
        segment_label: str,
        entity_labels: Optional[List[str]] = None,
        attr_labels: Optional[List[str]] = None,
        relation_labels: Optional[List[str]] = None,
        name: Optional[str] = None,
        uid: Optional[str] = None,
    ):
        """
        Instantiate the document splitter

        Parameters
        ----------
        segment_label:
            Label of the segments to use as references for the splitter
        entity_labels:
            Labels of entities to be included in the mini documents.
            If None, all entities from the document will be included.
        attr_labels:
            Labels of the attributes to be included into the new annotations.
            If None, all attributes will be included.
        relation_labels:
            Labels of relations to be included in the mini documents.
            If None, all relations will be included.
        name:
            Name describing the splitter (default to the class name).
        uid:
            Identifier of the operation
        """
        # Pass all arguments to super (remove self).
        # NOTE: `locals()` must stay the first statement so that no extra
        # local variable leaks into the arguments forwarded to the parent.
        init_args = locals()
        init_args.pop("self")
        super().__init__(**init_args)

        self.segment_label = segment_label
        self.entity_labels = entity_labels
        self.attr_labels = attr_labels
        self.relation_labels = relation_labels

    def run(self, docs: List[TextDocument]) -> List[TextDocument]:
        """Split docs into mini documents

        Parameters
        ----------
        docs:
            List of text documents to split

        Returns
        -------
        List[TextDocument]
            Documents created from the selected segments, one per segment
            returned by the segment/entity alignment.
        """
        segment_docs = []

        for doc in docs:
            segments = doc.anns.get_segments(label=self.segment_label)

            # filter entities
            if self.entity_labels is None:
                entities = doc.anns.get_entities()
            else:
                entities = [
                    ent
                    for label in self.entity_labels
                    for ent in doc.anns.get_entities(label=label)
                ]

            # align segment and entities (fully contained)
            segment_and_entities = compute_nested_segments(segments, entities)

            # filter relations in the document
            if self.relation_labels is None:
                relations = doc.anns.get_relations()
            else:
                relations = [
                    rel
                    for label in self.relation_labels
                    for rel in doc.anns.get_relations(label=label)
                ]

            # Iterate over all segments and corresponding nested entities
            for segment, nested_entities in segment_and_entities:
                # keep only relations whose both ends are entities of this
                # segment, so every relation can be remapped in the mini-doc
                entities_uid = {ent.uid for ent in nested_entities}
                nested_relations = [
                    relation
                    for relation in relations
                    if relation.source_id in entities_uid
                    and relation.target_id in entities_uid
                ]

                # create new document from segment
                segment_docs.append(
                    self._create_segment_doc(
                        segment=segment,
                        entities=nested_entities,
                        relations=nested_relations,
                        doc_source=doc,
                    )
                )

        return segment_docs

    def _create_segment_doc(
        self,
        segment: Segment,
        entities: List[Entity],
        relations: List[Relation],
        doc_source: TextDocument,
    ) -> TextDocument:
        """Create a TextDocument from a segment and its entities.

        The text of the new document is the slice of the source text going
        from the start of the first normalized span of the segment to the end
        of its last normalized span.

        Parameters
        ----------
        segment:
            Segment to use as reference for the new document
        entities:
            Entities inside the segment
        relations:
            Relations inside the segment. Their source and target ids must
            refer to items of `entities`.
        doc_source:
            Initial document from which annotations were extracted

        Returns
        -------
        TextDocument
            A new document holding relocated copies of the entities and
            relations. Its metadata is a copy of the source document metadata
            updated with the segment metadata.
        """
        normalized_spans = span_utils.normalize_spans(segment.spans)

        # create an empty mini-doc with the raw text of the segment
        offset, end_span = normalized_spans[0].start, normalized_spans[-1].end
        metadata = doc_source.metadata.copy()
        metadata.update(segment.metadata)

        segment_doc = TextDocument(
            text=doc_source.text[offset:end_span], metadata=metadata
        )
        self._record_provenance(segment_doc, segment)

        # Segment attributes become attributes of the new document
        self._copy_filtered_attrs(segment, segment_doc)

        # Add selected entities, remembering old uid -> new uid for relations
        uid_mapping = {}
        for ent in entities:
            relocated_ent = Entity(
                text=ent.text,
                label=ent.label,
                spans=self._shift_spans(ent.spans, offset),
                metadata=ent.metadata.copy(),
            )
            uid_mapping[ent.uid] = relocated_ent.uid

            self._record_provenance(relocated_ent, ent)
            self._copy_filtered_attrs(ent, relocated_ent)

            # add entity to the new document
            segment_doc.anns.add(relocated_ent)

        for rel in relations:
            relation = Relation(
                label=rel.label,
                source_id=uid_mapping[rel.source_id],
                target_id=uid_mapping[rel.target_id],
                metadata=rel.metadata.copy(),
            )

            self._record_provenance(relation, rel)
            self._copy_filtered_attrs(rel, relation)

            # add relation to the new document
            segment_doc.anns.add(relation)

        return segment_doc

    @staticmethod
    def _shift_spans(spans: List, offset: int) -> List:
        """Return copies of `spans` relocated by subtracting `offset`.

        `Span` items are shifted directly; any other item is handled as a
        `ModifiedSpan`: its length is kept and its replaced spans are shifted.
        """
        shifted = []
        for span in spans:
            if isinstance(span, Span):
                shifted.append(Span(span.start - offset, span.end - offset))
            else:
                replaced_spans = [
                    Span(sp.start - offset, sp.end - offset)
                    for sp in span.replaced_spans
                ]
                shifted.append(
                    ModifiedSpan(length=span.length, replaced_spans=replaced_spans)
                )
        return shifted

    def _record_provenance(self, data_item, source_item) -> None:
        """Register `data_item` as derived from `source_item`, if tracing is on."""
        if self._prov_tracer is not None:
            self._prov_tracer.add_prov(
                data_item, self.description, source_data_items=[source_item]
            )

    def _copy_filtered_attrs(self, source, target) -> None:
        """Copy the attributes of `source` selected by 'attr_labels' into `target.attrs`."""
        for attr in self._filter_attrs_from_ann(source):
            new_attr = attr.copy()
            target.attrs.add(new_attr)
            self._record_provenance(new_attr, attr)

    def _filter_attrs_from_ann(self, ann: TextAnnotation) -> List[Attribute]:
        """Filter attributes from an annotation using 'attr_labels'"""
        if self.attr_labels is None:
            return ann.attrs.get()
        return [
            attr
            for label in self.attr_labels
            for attr in ann.attrs.get(label=label)
        ]