Source code for medkit.tools.hf_utils

"""This module needs extra dependencies that are not installed as core dependencies of medkit.
To install them, use `pip install medkit-lib[hf-utils]`.
"""
from __future__ import annotations

# Public API: the only name exported by ``from medkit.tools.hf_utils import *``.
__all__ = ["check_model_for_task_hf"]


from typing import TYPE_CHECKING

import transformers

if TYPE_CHECKING:
    from pathlib import Path


def check_model_for_task_hf(model: str | Path, task: str, hf_auth_token: str | None = None) -> bool:
    """Check compatibility of a model with a HuggingFace task.

    The model could be in the HuggingFace hub or in local files.

    Parameters
    ----------
    model : str or Path
        Name (on the HuggingFace models hub) or path of the model.
    task : str
        A string representing the HF task to check, e.g. 'token-classification'.
        Task aliases accepted by ``transformers.pipeline`` (such as 'ner') and
        'translation_XX_to_YY' tasks are resolved to their canonical task.
    hf_auth_token : str, optional
        HuggingFace Authentication token (to access private models on the hub)

    Returns
    -------
    bool
        Model compatibility with the task

    Raises
    ------
    KeyError
        If `task` is not a task known to HuggingFace pipelines.
    ValueError
        If ``transformers.AutoConfig`` raises a ValueError while loading the
        model configuration. Other errors raised by
        ``AutoConfig.from_pretrained`` are propagated unchanged.
    """
    # Validate and normalize the task first (cheap, no network access).
    # ``check_task`` resolves aliases (e.g. "ner" -> "token-classification")
    # exactly like ``transformers.pipeline`` does, and raises a KeyError listing
    # the available tasks when `task` is unknown.
    _, task_definition, _ = transformers.pipelines.check_task(task)

    try:
        config = transformers.AutoConfig.from_pretrained(model, token=hf_auth_token)
    except ValueError as err:
        msg = "Impossible to get the task from model"
        raise ValueError(msg) from err

    config_name = type(config).__name__
    # Each PyTorch auto-model class supporting the task maps config classes to
    # model classes; the model is compatible when its config class is one of
    # those keys. Classes registered through custom pipelines are not
    # necessarily auto classes and may have no ``_model_mapping``: skip them.
    return any(
        config_class.__name__ == config_name
        for auto_class in task_definition["pt"]
        for config_class in getattr(auto_class, "_model_mapping", ())
    )