Source code for medkit.text.postprocessing.alignment_utils

__all__ = ["compute_nested_segments"]
from typing import Tuple, List
from intervaltree import IntervalTree

from medkit.core.text import Segment, span_utils


def _create_segments_tree(
    target_segments: List[Segment],
) -> IntervalTree:
    """Use the normalized spans of the segments to create an interval tree

    Parameters
    ----------
    target_segments:
        List of segments to align

    Returns
    -------
    IntervalTree
        Interval tree from the target segments"""
    tree = IntervalTree()
    for segment in target_segments:
        normalized_spans = span_utils.normalize_spans(segment.spans)

        if not normalized_spans:
            continue

        tree.addi(
            normalized_spans[0].start,
            normalized_spans[-1].end,
            data=segment,
        )
    return tree


[docs]def compute_nested_segments( source_segments: List[Segment], target_segments: List[Segment] ) -> List[Tuple[Segment, List[Segment]]]: """Return source segments aligned with its nested segments. Parameters ---------- source_segments: List of source segments target_segments: List of segments to align Returns ------- List[Tuple[~medkit.core.text.Segment,List[~medkit.core.text.Segment]]]: List of aligned segments """ tree = _create_segments_tree(target_segments) nested = [] for parent in source_segments: normalized_spans = span_utils.normalize_spans(parent.spans) if not normalized_spans: continue start, end = normalized_spans[0].start, normalized_spans[-1].end children = [child.data for child in tree.overlap(start, end)] nested.append((parent, children)) return nested