"""
This module needs extra-dependencies not installed as core dependencies of medkit.
To install them, use `pip install medkit-lib[umls-coder-normalizer]`.
"""
__all__ = ["UMLSCoderNormalizer"]
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
from typing_extensions import Literal
from pathlib import Path
import pandas as pd
import torch
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer, FeatureExtractionPipeline
import yaml
from medkit.core import Operation
from medkit.core.text import Entity
import medkit.core.utils
from medkit.text.ner.umls_norm_attribute import UMLSNormAttribute
from medkit.text.ner.umls_utils import (
load_umls,
preprocess_term_to_match,
guess_umls_version,
)
_PARAMS_FILENAME = "params.yml"
_TERMS_FILENAME = "terms.feather"
_UMLS_EMBEDDINGS_CHUNK_SIZE = 65536
_UMLS_EMBEDDINGS_FILE_EXT = ".pt"
class _UMLSEmbeddingsParams(NamedTuple):
umls_version: str
language: str
model: str
summary_method: Literal["mean", "cls"]
normalize_embeddings: bool
lowercase: bool
normalize_unicode: bool
def to_dict(self) -> Dict[str, Any]:
return dict(
umls_version=self.umls_version,
language=self.language,
model=self.model,
summary_method=self.summary_method,
normalize_embeddings=self.normalize_embeddings,
lowercase=self.lowercase,
normalize_unicode=self.normalize_unicode,
)
[docs]class UMLSCoderNormalizer(Operation):
"""Normalizer adding UMLS normalization attributes to
pre-existing entities. Based on https://github.com/GanjinZero/CODER/.
An UMLS `MRCONSO.RRF` file is needed. The normalizer identifies UMLS concepts by
comparing embeddings of reference UMLS terms with the embeddings of the input
entities. Any text transformer model from the HuggingFace Hub can be used,
but "GanjinZero/UMLSBert_ENG" was specifically trained for this task (for english).
When `UMLSCoderNormalizer` is used for the first time for a given `MRCONSO.RRF`,
the embeddings of all umls terms are pre-computed (this can take a very long time)
and stored in `embeddings_cache_dir`, so they can be reused next time.
If another `MRCONSO.RRF` file is used, or if a parameter impacting the computation
of embeddings (`model`, `summary_method`, etc) is changed, then another `embeddings_cache_dir`
must be used, or `embeddings_cache_dir` must be deleted so it can be created properly.
If the umls embeddings are too big to be held in memory, use `nb_umls_embeddings_chunks`.
"""
def __init__(
self,
umls_mrconso_file: Union[str, Path],
language: str,
model: Union[str, Path],
embeddings_cache_dir: Union[str, Path],
summary_method: Literal["mean", "cls"] = "cls",
normalize_embeddings: bool = True,
lowercase: bool = False,
normalize_unicode: bool = False,
threshold: Optional[float] = None,
max_nb_matches: int = 1,
device: int = -1,
batch_size: int = 128,
nb_umls_embeddings_chunks: Optional[int] = None,
hf_cache_dir: Optional[Union[str, Path]] = None,
name: Optional[str] = None,
uid: Optional[str] = None,
):
"""
Parameters
----------
umls_mrconso_file:
Path to the UMLS `MRCONSO.RRF` file.
language:
Language of the UMLS terms to use (ex: `"ENG"`, `"FRE"`).
model:
Name on the Hugging Face hub or path to the transformers model that will be used to extract
embeddings (ex: `"GanjinZero/UMLSBert_ENG"`).
embeddings_cache_dir:
Path to the directory into which pre-computed embeddings of UMLS terms should be cached.
If it doesn't exist yet, the embeddings will be automatically generated (it can take a long
time) and stored there, ready to be reused on further instantiations.
If it already exists, a check will be done to make sure the params used when the embeddings
were computed are consistent with the params of the current instance.
summary_method:
If set to `"mean"`, the embeddings extracted will be the mean of the pooling layers
of the model. Otherwise, when set to `"cls"`, the last hidden layer will be used.
normalize_embeddings:
Whether to normalize the extracted embeddings.
lowercase:
Whether to use lowercased versions of UMLS terms and input entities.
normalize_unicode:
Whether to ASCII-only versions of UMLS terms and input entities
(non-ASCII chars replaced by closest ASCII chars).
threshold:
Minimum similarity threshold (between 0.0 and 1.0) between the embeddings
of an entity and of an UMLS term for a normalization attribute to be added.
max_nb_matches:
Maximum number of normalization attributes to add to each entity.
device:
Device to use for transformers models. Follows the Hugging Face convention
(-1 for "cpu" and device number for gpu, for instance 0 for "cuda:0").
batch_size:
Number of entities in batches processed by the embeddings extraction pipeline.
hf_cache_dir:
Directory where to store downloaded models. If not set, the default
HuggingFace cache dir is used.
nb_umls_embeddings_chunks:
Number of umls embeddings chunks to load at the same time when computing
embeddings similarities. (a chunk contains 65536 embeddings).
If `None`, all pre-computed umls embeddings are pre-loaded in memory and
similaries are computed in one shot. Otherwise, at each call to `run()`,
umls embeddings are loaded by groups of chunks and similaries are computed
for each group.
Use this when umls embeddings are too big to be fully loaded in memory.
The higher this value, the more memory needed.
name:
Name describing the normalizer (defaults to the class name).
uid:
Identifier of the normalizer.
"""
# Pass all arguments to super (remove self)
init_args = locals()
init_args.pop("self")
super().__init__(**init_args)
self.umls_mrconso_file = Path(umls_mrconso_file)
self.embeddings_cache_dir = Path(embeddings_cache_dir)
self.language = language
self.model = model
self.summary_method = summary_method
self.normalize_embeddings = normalize_embeddings
self.lowercase = lowercase
self.normalize_unicode = normalize_unicode
self.threshold = threshold
self.max_nb_matches = max_nb_matches
self.device = device
self.nb_umls_embeddings_chunks = nb_umls_embeddings_chunks
self._pipeline: _EmbeddingsPipeline = transformers.pipeline(
"feature-extraction",
model=self.model,
pipeline_class=_EmbeddingsPipeline,
summary_method=summary_method,
normalize=normalize_embeddings,
device=device,
batch_size=batch_size,
model_kwargs={"cache_dir": hf_cache_dir},
)
# guess UMLS version
self._umls_version = guess_umls_version(umls_mrconso_file)
# pre-compute embeddings of UMLS terms if necessary
self._build_umls_embeddings()
# preload all pre-computed UMLS embeddings if nb_umls_embeddings_chunks not set
if self.nb_umls_embeddings_chunks is None:
umls_embeddings_files = sorted(
self.embeddings_cache_dir.glob(f"*{_UMLS_EMBEDDINGS_FILE_EXT}")
)
self._umls_embeddings = self._load_umls_embeddings(umls_embeddings_files)
else:
self._umls_embeddings = None
# load corresponding UMLS terms and associated CUIs
umls_terms_file = self.embeddings_cache_dir / _TERMS_FILENAME
self._umls_entries = pd.read_feather(umls_terms_file)
[docs] def run(self, entities: List[Entity]):
"""Add normalization attributes to each entity in `entities`.
Each entity will have zero, one or more normalization attributes depending
on `max_nb_matches` and on how many matches with a similarity above `threshold`
are found.
Parameters
----------
entities:
List of entities to add normalization attributes to
"""
# find best matches and assocatied score for all entities
all_match_indices, all_match_scores = self._find_best_matches(entities)
# add normalization attributes to each entity
for entity, match_indices, match_scores in zip(
entities, all_match_indices, all_match_scores
):
self._normalize_entity(entity, match_indices, match_scores)
def _find_best_matches(
self, entities: List[Entity]
) -> Tuple[List[List[int]], List[List[float]]]:
entity_terms = [entity.text for entity in entities]
entity_embeddings = self._pipeline(entity_terms)
entity_embeddings = torch.cat(entity_embeddings, dim=0)
if self.nb_umls_embeddings_chunks is not None:
# compute similarities for each batch of pre-computed umls embeddings
all_similarities = []
umls_embeddings_files = sorted(
self.embeddings_cache_dir.glob(f"*{_UMLS_EMBEDDINGS_FILE_EXT}")
)
for files in medkit.core.utils.batch_list(
umls_embeddings_files, self.nb_umls_embeddings_chunks
):
umls_embeddings = self._load_umls_embeddings(files)
all_similarities.append(
torch.matmul(entity_embeddings, umls_embeddings.T)
)
similarities = torch.cat(all_similarities, dim=1)
else:
# compute similarity on all pre-loaded pre-computed umls embeddings
assert self._umls_embeddings is not None
similarities = torch.matmul(entity_embeddings, self._umls_embeddings.T)
all_matches_scores, all_matches_indices = torch.topk(
similarities, k=self.max_nb_matches
)
# round scores to avoid floating point precision errors and get
# 1.0 for exact matches instead of values slightly above or below
all_matches_scores = torch.round(all_matches_scores, decimals=4)
return all_matches_indices.tolist(), all_matches_scores.tolist()
def _load_umls_embeddings(self, files: List[Path]) -> torch.Tensor:
torch_device = "cpu" if self.device < 0 else f"cuda:{self.device}"
umls_embeddings = torch.cat(
[torch.load(file, map_location=torch_device) for file in files]
)
return umls_embeddings
def _normalize_entity(
self, entity: Entity, match_indices: List[int], match_scores: List[float]
):
for match_index, match_score in zip(match_indices, match_scores):
if self.threshold is not None and match_score < self.threshold:
continue
umls_entry = self._umls_entries.iloc[match_index]
norm_attr = UMLSNormAttribute(
cui=umls_entry.cui,
umls_version=self._umls_version,
term=umls_entry.term,
score=match_score,
)
entity.attrs.add(norm_attr)
if self._prov_tracer is not None:
self._prov_tracer.add_prov(
norm_attr, self.description, source_data_items=[entity]
)
def _build_umls_embeddings(self, show_progress=True):
# build description of computation params
params = _UMLSEmbeddingsParams(
umls_version=self._umls_version,
language=self.language,
model=self.model,
summary_method=self.summary_method,
normalize_embeddings=self.normalize_embeddings,
lowercase=self.lowercase,
normalize_unicode=self.normalize_unicode,
)
# check if embeddings have already been computed
params_file = self.embeddings_cache_dir / _PARAMS_FILENAME
if params_file.exists():
# check consistency of params
with open(params_file) as fp:
existing_params = _UMLSEmbeddingsParams(**yaml.safe_load(fp))
if existing_params != params:
raise Exception(
f"Cache directory {self.embeddings_cache_dir} contains UMLS"
f" embeddings pre-computed with different params: {params} vs"
f" {existing_params}"
)
# nothing to do, embeddings have already been computed
return
if show_progress:
print(
"No pre-existing UMLS embeddings found in cache directory"
f" {self.embeddings_cache_dir}, pre-computing them right now"
)
self.embeddings_cache_dir.mkdir(exist_ok=True)
# remove all previous embedding files for safety
[
f.unlink()
for f in self.embeddings_cache_dir.glob(f"*{_UMLS_EMBEDDINGS_FILE_EXT}")
]
# get iterator to all UMLS entries
entries_iter = load_umls(
self.umls_mrconso_file,
languages=[self.language],
show_progress=show_progress,
)
entries_by_term_to_match = {}
# iterate over chunks of umls entries
if show_progress:
print("Loading UMLS entries and computing embeddings...")
for i, chunk_entries in enumerate(
medkit.core.utils.batch_iter(entries_iter, _UMLS_EMBEDDINGS_CHUNK_SIZE)
):
# get preprocess version of each term for matching
terms_to_match = [
preprocess_term_to_match(e.term, self.lowercase, self.normalize_unicode)
for e in chunk_entries
]
# skip entry with duplicate terms to match
chunk_entries_by_term_to_match = {
term_to_match: entry
for term_to_match, entry in zip(terms_to_match, chunk_entries)
if term_to_match not in entries_by_term_to_match
}
if not chunk_entries_by_term_to_match:
continue
entries_by_term_to_match.update(chunk_entries_by_term_to_match)
# compute embedding for each term
chunk_embeddings_iter = self._pipeline(
term_to_match for term_to_match in chunk_entries_by_term_to_match
)
chunk_embeddings = torch.cat(list(chunk_embeddings_iter), dim=0)
chunk_embeddings_file = (
self.embeddings_cache_dir / f"{i:010d}{_UMLS_EMBEDDINGS_FILE_EXT}"
)
# save chunk embeddings
torch.save(chunk_embeddings, chunk_embeddings_file)
# store entries in feather file (faster than csv and yaml)
entries_df = pd.DataFrame.from_records(
[e.to_dict() for e in entries_by_term_to_match.values()]
)
terms_file = self.embeddings_cache_dir / _TERMS_FILENAME
if show_progress:
print("Writing UMLS terms... ", end="")
entries_df.to_feather(terms_file)
if show_progress:
print("Done")
# store params into yaml
with open(params_file, mode="w") as fp:
yaml.safe_dump(
params.to_dict(),
fp,
encoding="utf-8",
allow_unicode=True,
sort_keys=False,
)
class _EmbeddingsPipeline(FeatureExtractionPipeline):
"""Extract embeddings from a pipeline"""
_EPS = 1e-12
def __init__(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
summary_method: Literal["mean", "cls"],
normalize: bool = True,
*args,
**kwargs,
):
super().__init__(model, tokenizer, *args, **kwargs)
self.summary_method = summary_method
self.normalize = normalize
def preprocess(self, inputs, truncation=True) -> Dict[str, torch.Tensor]:
return self.tokenizer(
inputs,
max_length=32,
add_special_tokens=True,
truncation=truncation,
padding="max_length",
return_tensors="pt",
)
def postprocess(self, model_outputs) -> torch.Tensor:
if self.summary_method == "cls":
embeddings = model_outputs[1]
else:
assert self.summary_method == "mean"
embeddings = torch.mean(model_outputs[0], dim=1)
if self.normalize:
norm = torch.norm(embeddings, p=2, dim=1, keepdim=True).clamp(min=self._EPS)
embeddings = embeddings / norm
return embeddings