diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index aade28610b313..4b76509e4541f 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -4,12 +4,14 @@ import enum import json import os import time +from functools import cache from pathlib import Path -from typing import Any, Dict, Literal, Optional, Type, Union +from typing import Any, Callable, Dict, Literal, Optional, Type, Union import huggingface_hub -from huggingface_hub import (file_exists, hf_hub_download, list_repo_files, - try_to_load_from_cache) +from huggingface_hub import hf_hub_download +from huggingface_hub import list_repo_files as hf_list_repo_files +from huggingface_hub import try_to_load_from_cache from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, HFValidationError, LocalEntryNotFoundError, RepositoryNotFoundError, @@ -86,6 +88,65 @@ class ConfigFormat(str, enum.Enum): MISTRAL = "mistral" +def with_retry(func: Callable[[], Any], + log_msg: str, + max_retries: int = 2, + retry_delay: int = 2): + for attempt in range(max_retries): + try: + return func() + except Exception as e: + if attempt == max_retries - 1: + logger.error("%s: %s", log_msg, e) + raise + logger.error("%s: %s, retrying %d of %d", log_msg, e, attempt + 1, + max_retries) + time.sleep(retry_delay) + retry_delay *= 2 + + +# @cache doesn't cache exceptions +@cache +def list_repo_files( + repo_id: str, + *, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, +) -> list[str]: + + def lookup_files(): + try: + return hf_list_repo_files(repo_id, + revision=revision, + repo_type=repo_type, + token=token) + except huggingface_hub.errors.OfflineModeIsEnabled: + # Don't raise in offline mode, + # all we know is that we don't have this + # file cached. + return [] + + return with_retry(lookup_files, "Error retrieving file list") + + +def file_exists( + repo_id: str, + file_name: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Union[str, bool, None] = None, +) -> bool: + + file_list = list_repo_files(repo_id, + repo_type=repo_type, + revision=revision, + token=token) + return file_name in file_list + + +# In offline mode the result can be a false negative def file_or_path_exists(model: Union[str, Path], config_name: str, revision: Optional[str]) -> bool: if Path(model).exists(): @@ -103,31 +164,10 @@ def file_or_path_exists(model: Union[str, Path], config_name: str, # hf_hub. This will fail in offline mode. # Call HF to check if the file exists - # 2 retries and exponential backoff - max_retries = 2 - retry_delay = 2 - for attempt in range(max_retries): - try: - return file_exists(model, - config_name, - revision=revision, - token=HF_TOKEN) - except huggingface_hub.errors.OfflineModeIsEnabled: - # Don't raise in offline mode, - # all we know is that we don't have this - # file cached. - return False - except Exception as e: - logger.error( - "Error checking file existence: %s, retrying %d of %d", e, - attempt + 1, max_retries) - if attempt == max_retries - 1: - logger.error("Error checking file existence: %s", e) - raise - time.sleep(retry_delay) - retry_delay *= 2 - continue - return False + return file_exists(str(model), + config_name, + revision=revision, + token=HF_TOKEN) def patch_rope_scaling(config: PretrainedConfig) -> None: @@ -208,32 +248,7 @@ def get_config( revision=revision): config_format = ConfigFormat.MISTRAL else: - # If we're in offline mode and found no valid config format, then - # raise an offline mode error to indicate to the user that they - # don't have files cached and may need to go online. - # This is conveniently triggered by calling file_exists(). - - # Call HF to check if the file exists - # 2 retries and exponential backoff - max_retries = 2 - retry_delay = 2 - for attempt in range(max_retries): - try: - file_exists(model, - HF_CONFIG_NAME, - revision=revision, - token=HF_TOKEN) - except Exception as e: - logger.error( - "Error checking file existence: %s, retrying %d of %d", - e, attempt + 1, max_retries) - if attempt == max_retries: - logger.error("Error checking file existence: %s", e) - raise e - time.sleep(retry_delay) - retry_delay *= 2 - - raise ValueError(f"No supported config format found in {model}") + raise ValueError(f"No supported config format found in {model}.") if config_format == ConfigFormat.HF: config_dict, _ = PretrainedConfig.get_config_dict( @@ -339,10 +354,11 @@ def get_hf_file_to_dict(file_name: str, file_name=file_name, revision=revision) - if file_path is None and file_or_path_exists( - model=model, config_name=file_name, revision=revision): + if file_path is None: try: hf_hub_file = hf_hub_download(model, file_name, revision=revision) + except huggingface_hub.errors.OfflineModeIsEnabled: + return None except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError, LocalEntryNotFoundError) as e: logger.debug("File or repository not found in hf_hub_download", e) @@ -363,6 +379,7 @@ def get_hf_file_to_dict(file_name: str, return None +@cache def get_pooling_config(model: str, revision: Optional[str] = 'main'): """ This function gets the pooling and normalize @@ -390,6 +407,8 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'): if modules_dict is None: return None + logger.info("Found sentence-transformers modules configuration.") + pooling = next((item for item in modules_dict if item["type"] == "sentence_transformers.models.Pooling"), None) @@ -408,6 +427,7 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'): if pooling_type_name is not None: pooling_type_name = get_pooling_config_name(pooling_type_name) + logger.info("Found pooling configuration.") return {"pooling_type": pooling_type_name, "normalize": normalize} return None @@ -435,6 +455,7 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]: return None +@cache def get_sentence_transformer_tokenizer_config(model: str, revision: Optional[str] = 'main' ): @@ -491,6 +512,8 @@ def get_sentence_transformer_tokenizer_config(model: str, if not encoder_dict: return None + logger.info("Found sentence-transformers tokenize configuration.") + if all(k in encoder_dict for k in ("max_seq_length", "do_lower_case")): return encoder_dict return None