Source code for medkit.tools.hf_utils

"""
This module needs extra-dependencies not installed as core dependencies of medkit.
To install them, use `pip install medkit-lib[hf-utils]`.
"""

__all__ = ["check_model_for_task_HF"]

from pathlib import Path
from typing import Union
import transformers


[docs]def check_model_for_task_HF(model: Union[str, Path], task: str) -> bool: """Check compatibility of a model with a task HuggingFace. The model could be in the HuggingFace hub or in local files. Parameters ---------- model: Name (on the HuggingFace models hub) or path of the model. task: A string representing the HF task to check i.e : 'token-classification' Returns ------- bool Model compatibility with the task """ try: config = transformers.AutoConfig.from_pretrained(model) except Exception as err: raise ValueError("Impossible to get the task from model") from err valid_config_names = [ config_class.__name__ for supported_classes in transformers.pipelines.SUPPORTED_TASKS[task]["pt"] for config_class in supported_classes._model_mapping.keys() ] return config.__class__.__name__ in valid_config_names