# Source code for medkit.text.ner.nlstruct_entity_matcher

"""This module needs extra-dependencies not installed as core dependencies of medkit.
To install them, use `pip install medkit-lib[nlstruct]`.
"""
from __future__ import annotations

__all__ = ["NLStructEntityMatcher"]

from pathlib import Path
from typing import Iterator

import huggingface_hub
import nlstruct
import torch

from medkit.core import Attribute
from medkit.core.text import Entity, NEROperation, Segment, span_utils

# Glob patterns of the PyTorch checkpoint files: used both to restrict what is
# downloaded from the HuggingFace Hub and to locate the checkpoint inside the
# model directory.
_PYTORCH_FILES = ["*.ckpt", "*.pt"]
# Glob pattern of the text files downloaded along with the checkpoint (the
# word embeddings file referenced by the model config, when there is one).
_TXT_FILES = "*.txt"

# Values of the "module" entries of the nlstruct config that identify the parts
# of the config whose embeddings filename must be fixed: the encoder must be a
# "concat" encoder, and the sub-encoder to patch is the "word_embeddings" one.
_CONCAT_REGISTRY_NAME = "concat"
_EMBEDDING_REGISTRY_NAME = "word_embeddings"


class NLStructEntityMatcher(NEROperation):
    """Entity matcher based on a NLstruct InformationExtraction model.

    The matcher expects a directory with a torch checkpoint and a text file
    if the model was pretrained using word embeddings.

    The paper [1]_ presents a model trained with the NLstruct [2]_ library and
    the mimic learning approach. The model used a private teacher model to
    annotate the unlabeled
    [CAS clinical French corpus](https://aclanthology.org/W18-5614/).

    The weights of the CAS student model are shared via the HuggingFace Hub and
    you can load the model using the following model name
    `NesrineBannour/CAS-privacy-preserving-model` to create a
    NLStructEntityMatcher.

    References
    ----------
    .. [1] Nesrine Bannour, Perceval Wajsbürt, Bastien Rance, Xavier Tannier,
           and Aurélie Névéol. 2022. Privacy-preserving mimic models for
           clinical named entity recognition in French. Journal of Biomedical
           Informatics 130, (2022), 104073.
           DOI: https://doi.org/10.1016/j.jbi.2022.104073
    .. [2] Perceval Wajsbürt. 2021. Extraction and normalization of simple and
           structured entities in medical documents. Theses. Sorbonne
           Université. Retrieved from
           https://hal.archives-ouvertes.fr/tel-03624928
    """

    def __init__(
        self,
        model_name_or_dirpath: str | Path,
        attrs_to_copy: list[str] | None = None,
        device: int = -1,
        hf_auth_token: str | None = None,
        cache_dir: str | Path | None = None,
        name: str | None = None,
        uid: str | None = None,
    ):
        """Parameters
        ----------
        model_name_or_dirpath : str or Path
            Name (on the HuggingFace models hub) or dirpath of the NLstruct
            model. The model dir must contain a PyTorch file ('.ckpt', '.pt')
            and a text file (.txt) representing the FastText embeddings if
            required.
        attrs_to_copy : list of str, optional
            Labels of the attributes that should be copied from the input
            segment to the created entity. Useful for propagating context
            attributes (negation, antecedent, etc).
        device : int, default=-1
            Device to use for the NLstruct model. Follows the HuggingFace
            convention (-1 for "cpu" and device number for gpu, for instance 0
            for "cuda:0").
        hf_auth_token : str, optional
            HuggingFace Authentication token (to access private models on the
            hub)
        cache_dir : str or Path, optional
            Directory where to store downloaded models. If not set, the default
            HuggingFace cache dir is used.
        name : str, optional
            Name describing the matcher (defaults to the class name).
        uid : str, optional
            Identifier of the matcher.
        """
        # Pass all arguments to super (remove self and confidential hf_auth_token)
        # NOTE: locals() must be called before any other local is defined
        init_args = locals()
        init_args.pop("self")
        init_args.pop("hf_auth_token")
        super().__init__(**init_args)

        if attrs_to_copy is None:
            attrs_to_copy = []

        self.cache_dir = cache_dir
        self.attrs_to_copy = attrs_to_copy
        self.model_name_or_dirpath = Path(model_name_or_dirpath)

        # get checkpoint dir: use the local directory when it exists, otherwise
        # consider the value as a model name on the HuggingFace Hub
        if self.model_name_or_dirpath.exists():
            checkpoint_dir = self.model_name_or_dirpath
        else:
            # download only allowed files
            allow_patterns = [*_PYTORCH_FILES, _TXT_FILES]
            checkpoint_dir = huggingface_hub.snapshot_download(
                repo_id=str(model_name_or_dirpath),
                cache_dir=self.cache_dir,
                allow_patterns=allow_patterns,
                token=hf_auth_token,
            )
            checkpoint_dir = Path(checkpoint_dir)

        self.device = torch.device("cpu" if device < 0 else f"cuda:{device}")
        self.model = self._load_from_checkpoint_dir(checkpoint_dir, self.device)
        # 'load_state_dict' copies the weights into the parameters of the freshly
        # created model without changing their device, so the model has to be
        # moved explicitly for the 'device' argument to take effect
        # NOTE(review): assumes nlstruct's predict() moves inputs to the model device
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _load_from_checkpoint_dir(checkpoint_dir: Path, device: torch.device):
        """Get the location of the checkpoint and fix the path of the
        Fast Text file in the configuration.

        Parameters
        ----------
        checkpoint_dir : Path
            Directory containing the PyTorch checkpoint and, if the model uses
            word embeddings, the embeddings text file.
        device
            Passed as `map_location` when loading the checkpoint.

        Returns
        -------
        The nlstruct model created with the modified config.

        Raises
        ------
        FileNotFoundError
            If `checkpoint_dir` contains no PyTorch checkpoint file.
        ValueError
            If the config references an embeddings file that is not present in
            `checkpoint_dir`.
        """
        # keep the priority between patterns and sort inside each pattern so
        # the selected checkpoint does not depend on the filesystem order
        checkpoint_filepaths = [
            filepath for pattern in _PYTORCH_FILES for filepath in sorted(checkpoint_dir.glob(pattern))
        ]

        if not checkpoint_filepaths:
            msg = f"There was no PyTorch file with a NLstruct checkpoint in '{checkpoint_dir.name}'"
            raise FileNotFoundError(msg)

        # BUGFIX: (nlstruct) The config created from nlstruct defines a filename
        # without a relative path. This means that the text file needs to be in
        # the same place where the object is created.
        # Cf. 'nlstruct.common.WordEmbeddings'
        # Force the filename to use the checkpoint directory
        checkpoint_filepath = checkpoint_filepaths[0]
        # NOTE(security): torch.load relies on pickle, the checkpoint (possibly
        # downloaded from the hub) must come from a trusted source
        loaded = torch.load(checkpoint_filepath, map_location=device)
        config = loaded["config"]

        # modify config if the encoder is a 'concat' model
        if config["encoder"]["module"] == _CONCAT_REGISTRY_NAME:
            # looks for the 'word_embeddings' config
            key_and_filename = [
                (key, data["filename"])
                for key, data in config["encoder"]["encoders"].items()
                if data["module"] == _EMBEDDING_REGISTRY_NAME
            ]
            if key_and_filename:
                key, filename = key_and_filename[0]
                # if 'filename' is empty, pretrained without embeddings (c.f nlstruct)
                # keep the same config
                if filename:
                    new_path = checkpoint_dir / Path(filename).name
                    if not new_path.exists():
                        msg = f"The text file '{new_path}' with the fast text embeddings does not exist"
                        raise ValueError(msg)
                    # update the filename of the wordEmbeddings model
                    config["encoder"]["encoders"][key]["filename"] = str(new_path)

        # similar to nlstruct load pretrained
        # create the model using modified config
        model = nlstruct.get_instance(config)
        model.load_state_dict(loaded["state_dict"], strict=False)
        return model

    def run(self, segments: list[Segment]) -> list[Entity]:
        """Return entities for each match in `segments`.

        Parameters
        ----------
        segments : list of Segment
            List of segments into which to look for matches.

        Returns
        -------
        list of Entity
            Entities found in `segments`.
        """
        # predict matches by segments
        entities = []
        for segment in segments:
            matches = self.model.predict({"doc_id": segment.uid, "text": segment.text})
            entities.extend(self._matches_to_entities(matches, segment))
        return entities

    def _matches_to_entities(self, matches: dict, segment: Segment) -> Iterator[Entity]:
        """Yield an entity for each match predicted on `segment`.

        Parameters
        ----------
        matches : dict
            Prediction returned by the model for `segment`. Its "entities" item
            is iterated, each match providing "fragments" (with "begin" and
            "end" offsets relative to the segment text), "label" and
            "confidence".
        segment : Segment
            Segment on which the prediction was made.

        Yields
        ------
        Entity
            Entity with a "confidence" attribute and a copy of the attributes
            of `segment` whose labels are in `attrs_to_copy`.
        """
        for match in matches["entities"]:
            text_all, spans_all = [], []

            # build entity by fragments
            for fragment in match["fragments"]:
                text, spans = span_utils.extract(segment.text, segment.spans, [(fragment["begin"], fragment["end"])])
                text_all.append(text)
                spans_all.extend(spans)

            text_all = "".join(text_all)
            # support multilabel
            label = match["label"] if isinstance(match["label"], str) else "-".join(match["label"])

            entity = Entity(
                label=label,
                text=text_all,
                spans=spans_all,
            )

            # TBD: This confidence is not well described,
            # normally around 0.99, round to avoid problems in export
            score_attr = Attribute(label="confidence", value=float("{:.2f}".format(match["confidence"])))
            entity.attrs.add(score_attr)

            # handle provenance
            if self._prov_tracer is not None:
                self._prov_tracer.add_prov(entity, self.description, source_data_items=[segment])
                self._prov_tracer.add_prov(score_attr, self.description, source_data_items=[segment])

            # copy attrs from segment
            for attr_label in self.attrs_to_copy:
                for attr in segment.attrs.get(label=attr_label):
                    copied_attr = attr.copy()
                    entity.attrs.add(copied_attr)
                    # handle provenance
                    if self._prov_tracer is not None:
                        self._prov_tracer.add_prov(copied_attr, self.description, [attr])

            yield entity