# Source code for medkit.tools.hf_utils
"""
This module requires extra dependencies that are not installed with the core dependencies of medkit.
To install them, use `pip install medkit-lib[hf-utils]`.
"""
# Public API of this module: the names exported by a star-import.
__all__ = ["check_model_for_task_HF"]
from pathlib import Path
from typing import Optional, Union
import transformers
def check_model_for_task_HF(
    model: Union[str, Path], task: str, hf_auth_token: Optional[str] = None
) -> bool:
    """Tell whether a HuggingFace model is compatible with a HuggingFace task.

    The model may live on the HuggingFace hub or in local files. Its
    configuration is loaded, and the name of the configuration class is
    compared with the names of the configuration classes known to the model
    classes listed for the task in ``transformers.pipelines.SUPPORTED_TASKS``.

    Parameters
    ----------
    model:
        Name (on the HuggingFace models hub) or path of the model.
    task:
        Name of the HF task to check, e.g. 'token-classification'.
        It is used as a key of ``transformers.pipelines.SUPPORTED_TASKS``;
        the lookup is not guarded, so an unknown name presumably raises
        ``KeyError``.
    hf_auth_token:
        HuggingFace authentication token (to access private models on the
        hub).

    Returns
    -------
    bool
        True if the model's configuration class name matches one of the
        configuration classes supported for the task, False otherwise.

    Raises
    ------
    ValueError
        If the configuration of the model cannot be loaded.
    """
    try:
        model_config = transformers.AutoConfig.from_pretrained(
            model, token=hf_auth_token
        )
    except Exception as err:
        # Any loading failure (unknown model, bad token, ...) is reported the
        # same way; the original error stays attached as the cause.
        raise ValueError("Impossible to get the task from model") from err

    # Only the model classes under the task's "pt" entry are considered
    # (presumably the PyTorch implementations).
    task_model_classes = transformers.pipelines.SUPPORTED_TASKS[task]["pt"]

    # Gather the name of every configuration class that at least one of these
    # model classes can be built from.
    compatible_config_names = set()
    for model_class in task_model_classes:
        for config_class in model_class._model_mapping.keys():
            compatible_config_names.add(config_class.__name__)

    # Classes are matched by name rather than by identity, as in the original.
    return model_config.__class__.__name__ in compatible_config_names