[Chore]: Reorganize model repo operating functions in transformers_utils (#29680)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py authored 2025-11-29 00:46:51 +08:00, committed by GitHub
parent 6f9d81d03b
commit f946a8d743
8 changed files with 304 additions and 280 deletions
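For downstream code, the practical effect of this chore is an import-path move: the repo-operating helpers now live in vllm.transformers_utils.repo_utils. A minimal migration sketch, with module paths taken from the diff below:

# Old location (before this commit):
# from vllm.transformers_utils.config import list_filtered_repo_files, get_model_path

# New location (after this commit):
from vllm.transformers_utils.repo_utils import (
    get_hf_file_bytes,
    get_model_path,
    list_filtered_repo_files,
    list_repo_files,
)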

View File

@@ -8,7 +8,7 @@ from unittest.mock import MagicMock, call, patch
import pytest
from vllm.transformers_utils.config import list_filtered_repo_files
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
@pytest.mark.parametrize(
@@ -44,7 +44,7 @@ def test_list_filtered_repo_files(
    # Patch list_repo_files called by fn
    with patch(
        "vllm.transformers_utils.config.list_repo_files",
        "vllm.transformers_utils.repo_utils.list_repo_files",
        MagicMock(return_value=_glob_path()),
    ) as mock_list_repo_files:
        out_files = sorted(
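The patch target changes along with the import because unittest.mock.patch must name the module where list_repo_files is looked up at call time, which is now repo_utils. A minimal sketch of the pattern, with an illustrative return value that is not from the diff:

from unittest.mock import MagicMock, patch

# Patch the name in the namespace that uses it, not where it was originally defined.
with patch(
    "vllm.transformers_utils.repo_utils.list_repo_files",
    MagicMock(return_value=["config.json", "model.safetensors"]),
):
    ...  # list_filtered_repo_files will now see the mocked file listing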

View File

@@ -83,10 +83,10 @@ from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
from vllm.transformers_utils.config import (
    get_model_path,
    is_interleaved,
    maybe_override_with_speculators,
)
from vllm.transformers_utils.repo_utils import get_model_path
from vllm.transformers_utils.utils import is_cloud_storage, is_gguf
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GiB_bytes

View File

@@ -31,7 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import (
    safetensors_weights_iterator,
)
from vllm.platforms import current_platform
from vllm.transformers_utils.config import list_filtered_repo_files
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
logger = init_logger(__name__)

View File

@@ -14,9 +14,9 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.config import VerifyAndUpdateConfig
from vllm.transformers_utils.config import (
    get_hf_file_bytes,
    try_get_dense_modules,
)
from vllm.transformers_utils.repo_utils import get_hf_file_bytes
from .interfaces_base import VllmModelForPooling, is_pooling_model

View File vllm/transformers_utils/config.py

@@ -1,30 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import fnmatch
import json
import os
import time
from collections.abc import Callable
from dataclasses import asdict
from functools import cache, partial
from importlib.metadata import version
from pathlib import Path
from typing import Any, Literal, TypeAlias, TypeVar
from typing import Any, Literal, TypeAlias
import huggingface_hub
from huggingface_hub import (
    get_safetensors_metadata,
    hf_hub_download,
    try_to_load_from_cache,
)
from huggingface_hub import list_repo_files as hf_list_repo_files
from huggingface_hub.utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    LocalEntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
)
from packaging.version import Version
from transformers import GenerationConfig, PretrainedConfig
@@ -40,6 +27,14 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
from vllm import envs
from vllm.logger import init_logger
from vllm.transformers_utils.config_parser_base import ConfigParserBase
from vllm.transformers_utils.repo_utils import (
    _get_hf_token,
    file_or_path_exists,
    get_hf_file_to_dict,
    list_repo_files,
    try_get_local_file,
    with_retry,
)
from vllm.transformers_utils.utils import (
    check_gguf_file,
    is_gguf,
@@ -58,21 +53,6 @@ MISTRAL_CONFIG_NAME = "params.json"
logger = init_logger(__name__)
def _get_hf_token() -> str | None:
    """
    Get the HuggingFace token from environment variable.

    Returns None if the token is not set, is an empty string,
    or contains only whitespace.

    This follows the same pattern as huggingface_hub library which
    treats empty string tokens as None to avoid authentication errors.
    """
    token = os.getenv("HF_TOKEN")
    if token and token.strip():
        return token
    return None
class LazyConfigDict(dict):
    def __getitem__(self, key):
        if isinstance(value := super().__getitem__(key), type):
@@ -308,143 +288,6 @@ def register_config_parser(config_format: str):
    return _wrapper
_R = TypeVar("_R")
def with_retry(
    func: Callable[[], _R],
    log_msg: str,
    max_retries: int = 2,
    retry_delay: int = 2,
) -> _R:
    for attempt in range(max_retries):
        try:
            return func()
        except Exception as e:
            if attempt == max_retries - 1:
                logger.error("%s: %s", log_msg, e)
                raise
            logger.error(
                "%s: %s, retrying %d of %d", log_msg, e, attempt + 1, max_retries
            )
            time.sleep(retry_delay)
            retry_delay *= 2

    raise AssertionError("Should not be reached")
# @cache doesn't cache exceptions
@cache
def list_repo_files(
    repo_id: str,
    *,
    revision: str | None = None,
    repo_type: str | None = None,
    token: str | bool | None = None,
) -> list[str]:
    def lookup_files() -> list[str]:
        # directly list files if model is local
        if (local_path := Path(repo_id)).exists():
            return [
                str(file.relative_to(local_path))
                for file in local_path.rglob("*")
                if file.is_file()
            ]
        # if model is remote, use hf_hub api to list files
        try:
            if envs.VLLM_USE_MODELSCOPE:
                from vllm.transformers_utils.utils import modelscope_list_repo_files

                return modelscope_list_repo_files(
                    repo_id,
                    revision=revision,
                    token=os.getenv("MODELSCOPE_API_TOKEN", None),
                )
            return hf_list_repo_files(
                repo_id, revision=revision, repo_type=repo_type, token=token
            )
        except huggingface_hub.errors.OfflineModeIsEnabled:
            # Don't raise in offline mode,
            # all we know is that we don't have this
            # file cached.
            return []

    return with_retry(lookup_files, "Error retrieving file list")
def list_filtered_repo_files(
    model_name_or_path: str,
    allow_patterns: list[str],
    revision: str | None = None,
    repo_type: str | None = None,
    token: str | bool | None = None,
) -> list[str]:
    try:
        all_files = list_repo_files(
            repo_id=model_name_or_path,
            revision=revision,
            token=token,
            repo_type=repo_type,
        )
    except Exception:
        logger.error(
            "Error retrieving file list. Please ensure your `model_name_or_path`, "
            "`repo_type`, `token` and `revision` arguments are correctly set. "
            "Returning an empty list."
        )
        return []

    file_list = []
    # Filter patterns on filenames
    for pattern in allow_patterns:
        file_list.extend(
            [
                file
                for file in all_files
                if fnmatch.fnmatch(os.path.basename(file), pattern)
            ]
        )
    return file_list
def file_exists(
    repo_id: str,
    file_name: str,
    *,
    repo_type: str | None = None,
    revision: str | None = None,
    token: str | bool | None = None,
) -> bool:
    file_list = list_repo_files(
        repo_id, repo_type=repo_type, revision=revision, token=token
    )
    return file_name in file_list
# In offline mode the result can be a false negative
def file_or_path_exists(
    model: str | Path, config_name: str, revision: str | None
) -> bool:
    if (local_path := Path(model)).exists():
        return (local_path / config_name).is_file()

    # Offline mode support: Check if config file is cached already
    cached_filepath = try_to_load_from_cache(
        repo_id=model, filename=config_name, revision=revision
    )
    if isinstance(cached_filepath, str):
        # The config file exists in cache - we can continue trying to load
        return True

    # NB: file_exists will only check for the existence of the config file on
    # hf_hub. This will fail in offline mode.

    # Call HF to check if the file exists
    return file_exists(
        str(model), config_name, revision=revision, token=_get_hf_token()
    )
def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> None:
"""Some models may have no rope_theta in their config but still use RoPE.
This function sets a default rope_theta if it's missing."""
@@ -836,72 +679,6 @@ def get_config(
    return config
def try_get_local_file(
    model: str | Path, file_name: str, revision: str | None = "main"
) -> Path | None:
    file_path = Path(model) / file_name
    if file_path.is_file():
        return file_path
    else:
        try:
            cached_filepath = try_to_load_from_cache(
                repo_id=model, filename=file_name, revision=revision
            )
            if isinstance(cached_filepath, str):
                return Path(cached_filepath)
        except ValueError:
            ...
    return None
def get_hf_file_to_dict(
    file_name: str, model: str | Path, revision: str | None = "main"
):
    """
    Downloads a file from the Hugging Face Hub and returns
    its contents as a dictionary.

    Parameters:
    - file_name (str): The name of the file to download.
    - model (str): The name of the model on the Hugging Face Hub.
    - revision (str): The specific version of the model.

    Returns:
    - config_dict (dict): A dictionary containing
    the contents of the downloaded file.
    """
    file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)

    if file_path is None:
        try:
            hf_hub_file = hf_hub_download(model, file_name, revision=revision)
        except huggingface_hub.errors.OfflineModeIsEnabled:
            return None
        except (
            RepositoryNotFoundError,
            RevisionNotFoundError,
            EntryNotFoundError,
            LocalEntryNotFoundError,
        ) as e:
            logger.debug("File or repository not found in hf_hub_download: %s", e)
            return None
        except HfHubHTTPError as e:
            logger.warning(
                "Cannot connect to Hugging Face Hub. Skipping file download for '%s':",
                file_name,
                exc_info=e,
            )
            return None
        file_path = Path(hf_hub_file)

    if file_path is not None and file_path.is_file():
        with open(file_path) as file:
            return json.load(file)

    return None
@cache
def get_pooling_config(model: str, revision: str | None = "main") -> dict | None:
"""
@@ -1316,41 +1093,3 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
    )
    return max_position_embeddings
def get_model_path(model: str | Path, revision: str | None = None):
    if os.path.exists(model):
        return model

    assert huggingface_hub.constants.HF_HUB_OFFLINE
    common_kwargs = {
        "local_files_only": huggingface_hub.constants.HF_HUB_OFFLINE,
        "revision": revision,
    }

    if envs.VLLM_USE_MODELSCOPE:
        from modelscope.hub.snapshot_download import snapshot_download

        return snapshot_download(model_id=model, **common_kwargs)

    from huggingface_hub import snapshot_download

    return snapshot_download(repo_id=model, **common_kwargs)
def get_hf_file_bytes(
    file_name: str, model: str | Path, revision: str | None = "main"
) -> bytes | None:
    """Get file contents from HuggingFace repository as bytes."""
    file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)

    if file_path is None:
        hf_hub_file = hf_hub_download(
            model, file_name, revision=revision, token=_get_hf_token()
        )
        file_path = Path(hf_hub_file)

    if file_path is not None and file_path.is_file():
        with open(file_path, "rb") as file:
            return file.read()

    return None
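Note that config.py re-imports a subset of the moved helpers (see the repo_utils import block above), so those names still resolve through the old module, while helpers outside that list must be imported from repo_utils directly; that is exactly what the call-site hunks in this commit do. A hedged sketch of the resulting import surface:

# Still resolves: these names are re-imported into config.py by this commit.
from vllm.transformers_utils.config import file_or_path_exists, list_repo_files

# Not re-imported: import from the new module directly.
from vllm.transformers_utils.repo_utils import get_model_path, list_filtered_repo_files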

View File

@@ -9,7 +9,7 @@ from gguf.constants import Keys, VisionProjectorType
from transformers import Gemma3Config, PretrainedConfig, SiglipVisionConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import list_filtered_repo_files
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
logger = init_logger(__name__)

View File vllm/transformers_utils/repo_utils.py (new file)

@@ -0,0 +1,287 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for model repo interaction."""
import fnmatch
import json
import os
import time
from collections.abc import Callable
from functools import cache
from pathlib import Path
from typing import TypeVar
import huggingface_hub
from huggingface_hub import (
    hf_hub_download,
    try_to_load_from_cache,
)
from huggingface_hub import list_repo_files as hf_list_repo_files
from huggingface_hub.utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    LocalEntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
)
from vllm import envs
from vllm.logger import init_logger
logger = init_logger(__name__)
def _get_hf_token() -> str | None:
    """
    Get the HuggingFace token from environment variable.

    Returns None if the token is not set, is an empty string,
    or contains only whitespace.

    This follows the same pattern as huggingface_hub library which
    treats empty string tokens as None to avoid authentication errors.
    """
    token = os.getenv("HF_TOKEN")
    if token and token.strip():
        return token
    return None
_R = TypeVar("_R")
def with_retry(
    func: Callable[[], _R],
    log_msg: str,
    max_retries: int = 2,
    retry_delay: int = 2,
) -> _R:
    for attempt in range(max_retries):
        try:
            return func()
        except Exception as e:
            if attempt == max_retries - 1:
                logger.error("%s: %s", log_msg, e)
                raise
            logger.error(
                "%s: %s, retrying %d of %d", log_msg, e, attempt + 1, max_retries
            )
            time.sleep(retry_delay)
            retry_delay *= 2

    raise AssertionError("Should not be reached")
# @cache doesn't cache exceptions
@cache
def list_repo_files(
    repo_id: str,
    *,
    revision: str | None = None,
    repo_type: str | None = None,
    token: str | bool | None = None,
) -> list[str]:
    def lookup_files() -> list[str]:
        # directly list files if model is local
        if (local_path := Path(repo_id)).exists():
            return [
                str(file.relative_to(local_path))
                for file in local_path.rglob("*")
                if file.is_file()
            ]
        # if model is remote, use hf_hub api to list files
        try:
            if envs.VLLM_USE_MODELSCOPE:
                from vllm.transformers_utils.utils import modelscope_list_repo_files

                return modelscope_list_repo_files(
                    repo_id,
                    revision=revision,
                    token=os.getenv("MODELSCOPE_API_TOKEN", None),
                )
            return hf_list_repo_files(
                repo_id, revision=revision, repo_type=repo_type, token=token
            )
        except huggingface_hub.errors.OfflineModeIsEnabled:
            # Don't raise in offline mode,
            # all we know is that we don't have this
            # file cached.
            return []

    return with_retry(lookup_files, "Error retrieving file list")
def list_filtered_repo_files(
    model_name_or_path: str,
    allow_patterns: list[str],
    revision: str | None = None,
    repo_type: str | None = None,
    token: str | bool | None = None,
) -> list[str]:
    try:
        all_files = list_repo_files(
            repo_id=model_name_or_path,
            revision=revision,
            token=token,
            repo_type=repo_type,
        )
    except Exception:
        logger.error(
            "Error retrieving file list. Please ensure your `model_name_or_path`, "
            "`repo_type`, `token` and `revision` arguments are correctly set. "
            "Returning an empty list."
        )
        return []

    file_list = []
    # Filter patterns on filenames
    for pattern in allow_patterns:
        file_list.extend(
            [
                file
                for file in all_files
                if fnmatch.fnmatch(os.path.basename(file), pattern)
            ]
        )
    return file_list
def file_exists(
    repo_id: str,
    file_name: str,
    *,
    repo_type: str | None = None,
    revision: str | None = None,
    token: str | bool | None = None,
) -> bool:
    file_list = list_repo_files(
        repo_id, repo_type=repo_type, revision=revision, token=token
    )
    return file_name in file_list
# In offline mode the result can be a false negative
def file_or_path_exists(
    model: str | Path, config_name: str, revision: str | None
) -> bool:
    if (local_path := Path(model)).exists():
        return (local_path / config_name).is_file()

    # Offline mode support: Check if config file is cached already
    cached_filepath = try_to_load_from_cache(
        repo_id=model, filename=config_name, revision=revision
    )
    if isinstance(cached_filepath, str):
        # The config file exists in cache - we can continue trying to load
        return True

    # NB: file_exists will only check for the existence of the config file on
    # hf_hub. This will fail in offline mode.

    # Call HF to check if the file exists
    return file_exists(
        str(model), config_name, revision=revision, token=_get_hf_token()
    )
def get_model_path(model: str | Path, revision: str | None = None):
    if os.path.exists(model):
        return model

    assert huggingface_hub.constants.HF_HUB_OFFLINE
    common_kwargs = {
        "local_files_only": huggingface_hub.constants.HF_HUB_OFFLINE,
        "revision": revision,
    }

    if envs.VLLM_USE_MODELSCOPE:
        from modelscope.hub.snapshot_download import snapshot_download

        return snapshot_download(model_id=model, **common_kwargs)

    from huggingface_hub import snapshot_download

    return snapshot_download(repo_id=model, **common_kwargs)
def get_hf_file_bytes(
    file_name: str, model: str | Path, revision: str | None = "main"
) -> bytes | None:
    """Get file contents from HuggingFace repository as bytes."""
    file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)

    if file_path is None:
        hf_hub_file = hf_hub_download(
            model, file_name, revision=revision, token=_get_hf_token()
        )
        file_path = Path(hf_hub_file)

    if file_path is not None and file_path.is_file():
        with open(file_path, "rb") as file:
            return file.read()

    return None
def try_get_local_file(
    model: str | Path, file_name: str, revision: str | None = "main"
) -> Path | None:
    file_path = Path(model) / file_name
    if file_path.is_file():
        return file_path
    else:
        try:
            cached_filepath = try_to_load_from_cache(
                repo_id=model, filename=file_name, revision=revision
            )
            if isinstance(cached_filepath, str):
                return Path(cached_filepath)
        except ValueError:
            ...
    return None
def get_hf_file_to_dict(
    file_name: str, model: str | Path, revision: str | None = "main"
):
    """
    Downloads a file from the Hugging Face Hub and returns
    its contents as a dictionary.

    Parameters:
    - file_name (str): The name of the file to download.
    - model (str): The name of the model on the Hugging Face Hub.
    - revision (str): The specific version of the model.

    Returns:
    - config_dict (dict): A dictionary containing
    the contents of the downloaded file.
    """
    file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)

    if file_path is None:
        try:
            hf_hub_file = hf_hub_download(model, file_name, revision=revision)
        except huggingface_hub.errors.OfflineModeIsEnabled:
            return None
        except (
            RepositoryNotFoundError,
            RevisionNotFoundError,
            EntryNotFoundError,
            LocalEntryNotFoundError,
        ) as e:
            logger.debug("File or repository not found in hf_hub_download: %s", e)
            return None
        except HfHubHTTPError as e:
            logger.warning(
                "Cannot connect to Hugging Face Hub. Skipping file download for '%s':",
                file_name,
                exc_info=e,
            )
            return None
        file_path = Path(hf_hub_file)

    if file_path is not None and file_path.is_file():
        with open(file_path) as file:
            return json.load(file)

    return None
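A short usage sketch of the relocated helpers, using the signatures defined above; the repo id is illustrative and not part of this commit:

from vllm.transformers_utils.repo_utils import (
    get_hf_file_to_dict,
    list_filtered_repo_files,
)

# Patterns are fnmatch-style and matched against basenames; lookup errors
# are logged and yield an empty list.
weight_files = list_filtered_repo_files(
    "facebook/opt-125m",  # illustrative repo id
    allow_patterns=["*.safetensors", "*.bin"],
)

# Returns the parsed JSON dict, or None when offline or the file is absent.
config_dict = get_hf_file_to_dict("config.json", "facebook/opt-125m")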

View File

@@ -15,11 +15,9 @@ from typing_extensions import assert_never
from vllm import envs
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
    get_sentence_transformer_tokenizer_config,
    list_filtered_repo_files,
)
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
from vllm.transformers_utils.gguf_utils import get_gguf_file_path_from_hf
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import (
    check_gguf_file,