Source code for medkit.core.pipeline

# Public names exported by this module (the names picked up by `import *`)
__all__ = [
    "Pipeline",
    "PipelineStep",
    "PipelineCompatibleOperation",
    "DescribableOperation",
    "ProvCompatibleOperation",
]

import dataclasses
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
    cast,
)
from typing_extensions import Protocol, runtime_checkable

from medkit.core.annotation import Annotation
from medkit.core.data_item import IdentifiableDataItem, IdentifiableDataItemWithAttrs
from medkit.core.id import generate_id
from medkit.core.operation_desc import OperationDescription
from medkit.core.prov_tracer import ProvTracer


@runtime_checkable
class PipelineCompatibleOperation(Protocol):
    """Structural type of the operations that can be used in a pipeline step.

    Any object exposing a `run()` method satisfies this protocol.
    """

    def run(
        self, *all_input_data: List[Any]
    ) -> Union[None, List[Any], Tuple[List[Any], ...]]:
        """Process the given data items.

        Parameters
        ----------
        all_input_data:
            One or several lists of data items to process (as many lists as
            the operation has inputs)

        Returns
        -------
        Union[None, List[Any], Tuple[List[Any], ...]]
            Tuple holding one list of newly created data items per output of
            the operation.
            `None` may be returned when the operation creates no new data item
            and only modifies existing items in-place (for instance by adding
            attributes to existing annotations).
            When there is a single list of created data items, that list may be
            returned directly, without wrapping it in a tuple.
        """
        ...
@runtime_checkable
class ProvCompatibleOperation(Protocol):
    """Structural type of the operations that accept a provenance tracer.

    Any object exposing a `set_prov_tracer()` method satisfies this protocol.
    """

    def set_prov_tracer(self, prov_tracer: ProvTracer):
        """Hand over the provenance tracer to the operation.

        Parameters
        ----------
        prov_tracer:
            The provenance tracer given to the operation
        """
        ...
@runtime_checkable
class DescribableOperation(Protocol):
    """Structural type of the operations exposing a `description` attribute."""

    # description of the operation, as an `OperationDescription`
    description: OperationDescription
@dataclasses.dataclass
class PipelineStep:
    """`Pipeline` item describing how a processing operation is connected to
    the other operations of the pipeline

    Parameters
    ----------
    operation:
        The operation to use at that step
    input_keys:
        For each input of `operation`, the key to use to retrieve the
        corresponding annotations (either retrieved from a document
        or generated by an earlier pipeline step)
    output_keys:
        For each output of `operation`, the key used to pass output annotations
        to the next Pipeline step. Can be empty if `operation` doesn't return
        new annotations.
    aggregate_input_keys:
        If True, all the annotations from multiple input keys are aggregated in
        a single list. Defaults to False
    """

    operation: PipelineCompatibleOperation
    input_keys: List[str]
    output_keys: List[str]
    aggregate_input_keys: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return the step as a dict.

        The "operation" entry holds the description of the operation when the
        operation exposes one, `None` otherwise. The other entries hold the
        values of the corresponding fields.
        """
        if isinstance(self.operation, DescribableOperation):
            op_description = self.operation.description
        else:
            op_description = None

        return {
            "operation": op_description,
            "input_keys": self.input_keys,
            "output_keys": self.output_keys,
            "aggregate_input_keys": self.aggregate_input_keys,
        }
class Pipeline:
    """Graph of processing operations

    A pipeline is made of pipeline steps, connecting together different
    processing operations by the use of input/output keys. Each operation can
    be seen as a node and the keys are its edges. Two operations can be chained
    by using the same string as an output key for the first operation and as an
    input key to the second.

    Steps must be added in the order of execution, there isn't any sort of
    dependency detection mechanism.
    """

    def __init__(
        self,
        steps: List[PipelineStep],
        input_keys: List[str],
        output_keys: List[str],
        name: Optional[str] = None,
        uid: Optional[str] = None,
    ):
        """Initialize the pipeline

        Parameters
        ----------
        steps:
            List of pipeline steps

            Steps will be executed in the order in which they were added,
            so make sure to add first the steps generating data used by other
            steps.
        input_keys:
            List of keys corresponding to the inputs passed to `run()`
        output_keys:
            List of keys corresponding to the outputs returned by `run()`
        name:
            Name describing the pipeline (defaults to the class name)
        uid:
            Identifier of the pipeline (generated when not provided)
        """
        if uid is None:
            uid = generate_id()

        self.uid: str = uid
        self.name: Optional[str] = name
        self.steps: List[PipelineStep] = steps
        self.input_keys: List[str] = input_keys
        self.output_keys: List[str] = output_keys

        # tracer in which the provenance of the pipeline itself is recorded
        self._prov_tracer: Optional[ProvTracer] = None
        # tracer handed over to the operations of the pipeline steps
        self._sub_prov_tracer: Optional[ProvTracer] = None

    @property
    def description(self) -> OperationDescription:
        """Description of the pipeline: identifier, class name, name (class
        name when no name was given) and a config holding the steps (as dicts)
        and the input/output keys."""
        steps = [s.to_dict() for s in self.steps]
        config = dict(
            steps=steps,
            input_keys=self.input_keys,
            output_keys=self.output_keys,
        )
        return OperationDescription(
            uid=self.uid,
            class_name=self.__class__.__name__,
            name=self.__class__.__name__ if self.name is None else self.name,
            config=config,
        )

    def set_prov_tracer(self, prov_tracer: ProvTracer):
        """Enable provenance tracing for the pipeline and all its operations.

        A sub tracer sharing the store of `prov_tracer` is created and handed
        over to the operation of each step.

        Parameters
        ----------
        prov_tracer:
            The provenance tracer to use for the pipeline

        Raises
        ------
        TypeError
            If the operation of at least one step is not a
            `ProvCompatibleOperation`. In that case neither the pipeline nor
            any of its operations is modified.
        """
        # validate every step before touching any state, so that a failure
        # doesn't leave the pipeline half-configured
        if not all(
            isinstance(step.operation, ProvCompatibleOperation)
            for step in self.steps
        ):
            raise TypeError(
                "Some operations in the pipeline steps are not"
                " provenance-compatible"
            )

        self._prov_tracer = prov_tracer
        self._sub_prov_tracer = ProvTracer(prov_tracer.store)
        for step in self.steps:
            step.operation.set_prov_tracer(self._sub_prov_tracer)

    def run(
        self, *all_input_data: List[Any]
    ) -> Union[None, List[Any], Tuple[List[Any], ...]]:
        """Run the pipeline.

        Parameters
        ----------
        *all_input_data:
            Input data expected by the pipeline, must be of same length as the
            pipeline `input_keys`.

            For each input key, the corresponding input data must be a list of
            items that can be of any type.

        Returns
        -------
        Union[None, List[Any], Tuple[List[Any], ...]]
            All output data returned by the pipeline, will be of same length as
            the pipeline `output_keys`.

            For each output key, the corresponding output will be a list of
            items that can be of any type.

            If the pipeline has only one output key, then the corresponding
            output will be directly returned, not wrapped in a tuple. If the
            pipeline doesn't have any output key, nothing (ie `None`) will be
            returned.

        Raises
        ------
        RuntimeError
            If the number of inputs doesn't match the number of input keys, if
            a step can't find data for one of its input keys, or if the number
            of outputs of an operation doesn't match the number of output keys
            of its step.
        """
        if len(all_input_data) != len(self.input_keys):
            raise RuntimeError(
                f"Number of input ({len(all_input_data)}) does not match number of"
                f" input keys ({len(self.input_keys)})"
            )

        data_by_key = dict(zip(self.input_keys, all_input_data))
        for step in self.steps:
            self._perform_step(step, data_by_key)

        all_output_data = tuple(data_by_key[key] for key in self.output_keys)

        # on returned annotations, keep only the keys that are pipeline output
        # keys (keys of intermediate steps are internal to the pipeline)
        for output_data in all_output_data:
            for data_item in output_data:
                if isinstance(data_item, Annotation):
                    data_item.keys.intersection_update(self.output_keys)

        if self._prov_tracer is not None:
            self._add_provenance(all_output_data)

        if len(all_output_data) == 0:
            # no output
            return None
        elif len(all_output_data) == 1:
            # unwrap out of tuple if only 1 output
            return all_output_data[0]
        else:
            return all_output_data

    def _perform_step(self, step: PipelineStep, data_by_key: Dict[str, Any]):
        """Run the operation of `step` and store its outputs in `data_by_key`.

        Parameters
        ----------
        step:
            The step to perform
        data_by_key:
            Data available so far, by key. Updated with the outputs of the
            operation. When an output key is already present, the new items
            are appended in a new list (the list previously stored is left
            untouched).

        Raises
        ------
        RuntimeError
            If no data is available for one of the step input keys, or if the
            number of outputs returned by the operation doesn't match the
            number of output keys of the step.
        """
        # find data to feed to operation
        all_input_data = []
        for input_key in step.input_keys:
            input_data = data_by_key.get(input_key)
            if input_data is None:
                message = f"No data found for input key {input_key}"
                # the key is generated by some step, but not by a step
                # already executed
                if any(input_key in s.output_keys for s in self.steps):
                    message += (
                        ". Did you add the steps in the correct order in the"
                        " pipeline?"
                    )
                raise RuntimeError(message)
            all_input_data.append(input_data)

        if step.aggregate_input_keys:
            # merge the data of all input keys into one single input
            all_input_data = [
                [ann for input_key_data in all_input_data for ann in input_key_data]
            ]

        # call operation
        all_output_data = step.operation.run(*all_input_data)

        # wrap output in tuple if necessary
        # (operations performing in-place modifications
        # have no output and return None,
        # operations with single output may return a
        # single list instead of a tuple of lists)
        if all_output_data is None:
            all_output_data = tuple()
        elif not isinstance(all_output_data, tuple):
            all_output_data = (all_output_data,)

        if len(all_output_data) != len(step.output_keys):
            raise RuntimeError(
                f"Number of outputs ({len(all_output_data)}) does not match number of"
                f" output keys ({len(step.output_keys)})"
            )

        # store output data
        for output_key, output_data in zip(step.output_keys, all_output_data):
            if output_key not in data_by_key:
                data_by_key[output_key] = output_data
            else:
                # build a new list rather than extending in-place: the list
                # already stored may be owned by the caller of run() (pipeline
                # input) or by the operation that returned it
                data_by_key[output_key] = [*data_by_key[output_key], *output_data]

            # remember on each annotation the key under which it was generated
            for data_item in output_data:
                if isinstance(data_item, Annotation):
                    data_item.keys.add(output_key)

    def _add_provenance(self, all_output_data: Tuple[List[Any], ...]):
        """Record in the pipeline tracer the provenance of the data items
        generated by the pipeline, using what the operations recorded in the
        sub tracer.

        Parameters
        ----------
        all_output_data:
            The lists of data items returned by the pipeline
        """
        assert self._prov_tracer is not None and self._sub_prov_tracer is not None

        # flatten all output data to have a list of data items generated by this pipeline
        data_items = [
            data_item for output_data in all_output_data for data_item in output_data
        ]
        # data items must have identifiers to be provenance-compatible
        data_items = cast(List[IdentifiableDataItem], data_items)

        # ugly hack for attributes generated by pipeline
        # they are not in all_output_data because they are not returned by the operation
        # but directly added to input data items (annotations), but we still want to consider
        # them as an output of the pipeline in terms of provenance

        # find all attributes that were generated by this pipeline,
        # ie that have provenance in the pipeline's sub_prov_tracer
        attrs = [
            attr
            for data_item in data_items
            if isinstance(data_item, IdentifiableDataItemWithAttrs)
            for attr in data_item.attrs
            if (
                self._sub_prov_tracer.has_prov(attr.uid)
                # ignore stub provenance with no operation
                # (were passed as input to the pipeline but not generated by the pipeline)
                and self._sub_prov_tracer.get_prov(attr.uid).op_desc is not None
            )
        ]
        # add them to the list of data items generated by this pipeline
        data_items += attrs

        self._prov_tracer.add_prov_from_sub_tracer(
            data_items, self.description, self._sub_prov_tracer
        )

    def check_sanity(self):
        """Check the consistency of the keys used by the pipeline and its steps.

        Raises
        ------
        Exception
            If a pipeline input key is not the input key of any step, if a
            pipeline output key is not the output key of any step, if a step
            input key is neither a step output key nor a pipeline input key, or
            if a step input key is only generated by a later step.
        """
        # sets give constant-time membership tests
        steps_input_keys = {k for s in self.steps for k in s.input_keys}
        for input_key in self.input_keys:
            if input_key not in steps_input_keys:
                raise Exception(
                    f"Pipeline input key {input_key} does not correspond to any"
                    " step input key"
                )

        steps_output_keys = {k for s in self.steps for k in s.output_keys}
        for output_key in self.output_keys:
            if output_key not in steps_output_keys:
                raise Exception(
                    f"Pipeline output key {output_key} does not correspond to any"
                    " step output key"
                )

        pipeline_input_keys = set(self.input_keys)
        for step in self.steps:
            for input_key in step.input_keys:
                if (
                    input_key not in steps_output_keys
                    and input_key not in pipeline_input_keys
                ):
                    raise Exception(
                        f"Step input key {input_key} does not correspond to any step"
                        " output key nor any pipeline input key"
                    )

        # walk the steps in execution order, tracking the keys available so far
        available_keys = set(self.input_keys)
        for step in self.steps:
            for input_key in step.input_keys:
                if input_key not in available_keys:
                    raise Exception(
                        f"Step input key {input_key} is not available yet at this step"
                    )
            available_keys.update(step.output_keys)