mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 07:05:02 +08:00
[Chore]: Reorganize model repo operating functions in transformers_utils (#29680)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
6f9d81d03b
commit
f946a8d743
@ -8,7 +8,7 @@ from unittest.mock import MagicMock, call, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.transformers_utils.config import list_filtered_repo_files
|
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -44,7 +44,7 @@ def test_list_filtered_repo_files(
|
|||||||
|
|
||||||
# Patch list_repo_files called by fn
|
# Patch list_repo_files called by fn
|
||||||
with patch(
|
with patch(
|
||||||
"vllm.transformers_utils.config.list_repo_files",
|
"vllm.transformers_utils.repo_utils.list_repo_files",
|
||||||
MagicMock(return_value=_glob_path()),
|
MagicMock(return_value=_glob_path()),
|
||||||
) as mock_list_repo_files:
|
) as mock_list_repo_files:
|
||||||
out_files = sorted(
|
out_files = sorted(
|
||||||
|
|||||||
@ -83,10 +83,10 @@ from vllm.platforms import CpuArchEnum, current_platform
|
|||||||
from vllm.plugins import load_general_plugins
|
from vllm.plugins import load_general_plugins
|
||||||
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
|
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
|
||||||
from vllm.transformers_utils.config import (
|
from vllm.transformers_utils.config import (
|
||||||
get_model_path,
|
|
||||||
is_interleaved,
|
is_interleaved,
|
||||||
maybe_override_with_speculators,
|
maybe_override_with_speculators,
|
||||||
)
|
)
|
||||||
|
from vllm.transformers_utils.repo_utils import get_model_path
|
||||||
from vllm.transformers_utils.utils import is_cloud_storage, is_gguf
|
from vllm.transformers_utils.utils import is_cloud_storage, is_gguf
|
||||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
from vllm.utils.mem_constants import GiB_bytes
|
from vllm.utils.mem_constants import GiB_bytes
|
||||||
|
|||||||
@ -31,7 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
|||||||
safetensors_weights_iterator,
|
safetensors_weights_iterator,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.transformers_utils.config import list_filtered_repo_files
|
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -14,9 +14,9 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.models.config import VerifyAndUpdateConfig
|
from vllm.model_executor.models.config import VerifyAndUpdateConfig
|
||||||
from vllm.transformers_utils.config import (
|
from vllm.transformers_utils.config import (
|
||||||
get_hf_file_bytes,
|
|
||||||
try_get_dense_modules,
|
try_get_dense_modules,
|
||||||
)
|
)
|
||||||
|
from vllm.transformers_utils.repo_utils import get_hf_file_bytes
|
||||||
|
|
||||||
from .interfaces_base import VllmModelForPooling, is_pooling_model
|
from .interfaces_base import VllmModelForPooling, is_pooling_model
|
||||||
|
|
||||||
|
|||||||
@ -1,30 +1,17 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import fnmatch
|
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from functools import cache, partial
|
from functools import cache, partial
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, TypeAlias, TypeVar
|
from typing import Any, Literal, TypeAlias
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
from huggingface_hub import (
|
from huggingface_hub import (
|
||||||
get_safetensors_metadata,
|
get_safetensors_metadata,
|
||||||
hf_hub_download,
|
|
||||||
try_to_load_from_cache,
|
|
||||||
)
|
|
||||||
from huggingface_hub import list_repo_files as hf_list_repo_files
|
|
||||||
from huggingface_hub.utils import (
|
|
||||||
EntryNotFoundError,
|
|
||||||
HfHubHTTPError,
|
|
||||||
LocalEntryNotFoundError,
|
|
||||||
RepositoryNotFoundError,
|
|
||||||
RevisionNotFoundError,
|
|
||||||
)
|
)
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
from transformers import GenerationConfig, PretrainedConfig
|
from transformers import GenerationConfig, PretrainedConfig
|
||||||
@ -40,6 +27,14 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
|
|||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.config_parser_base import ConfigParserBase
|
from vllm.transformers_utils.config_parser_base import ConfigParserBase
|
||||||
|
from vllm.transformers_utils.repo_utils import (
|
||||||
|
_get_hf_token,
|
||||||
|
file_or_path_exists,
|
||||||
|
get_hf_file_to_dict,
|
||||||
|
list_repo_files,
|
||||||
|
try_get_local_file,
|
||||||
|
with_retry,
|
||||||
|
)
|
||||||
from vllm.transformers_utils.utils import (
|
from vllm.transformers_utils.utils import (
|
||||||
check_gguf_file,
|
check_gguf_file,
|
||||||
is_gguf,
|
is_gguf,
|
||||||
@ -58,21 +53,6 @@ MISTRAL_CONFIG_NAME = "params.json"
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _get_hf_token() -> str | None:
|
|
||||||
"""
|
|
||||||
Get the HuggingFace token from environment variable.
|
|
||||||
|
|
||||||
Returns None if the token is not set, is an empty string,
|
|
||||||
or contains only whitespace.
|
|
||||||
This follows the same pattern as huggingface_hub library which
|
|
||||||
treats empty string tokens as None to avoid authentication errors.
|
|
||||||
"""
|
|
||||||
token = os.getenv("HF_TOKEN")
|
|
||||||
if token and token.strip():
|
|
||||||
return token
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class LazyConfigDict(dict):
|
class LazyConfigDict(dict):
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if isinstance(value := super().__getitem__(key), type):
|
if isinstance(value := super().__getitem__(key), type):
|
||||||
@ -308,143 +288,6 @@ def register_config_parser(config_format: str):
|
|||||||
return _wrapper
|
return _wrapper
|
||||||
|
|
||||||
|
|
||||||
_R = TypeVar("_R")
|
|
||||||
|
|
||||||
|
|
||||||
def with_retry(
|
|
||||||
func: Callable[[], _R],
|
|
||||||
log_msg: str,
|
|
||||||
max_retries: int = 2,
|
|
||||||
retry_delay: int = 2,
|
|
||||||
) -> _R:
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
return func()
|
|
||||||
except Exception as e:
|
|
||||||
if attempt == max_retries - 1:
|
|
||||||
logger.error("%s: %s", log_msg, e)
|
|
||||||
raise
|
|
||||||
logger.error(
|
|
||||||
"%s: %s, retrying %d of %d", log_msg, e, attempt + 1, max_retries
|
|
||||||
)
|
|
||||||
time.sleep(retry_delay)
|
|
||||||
retry_delay *= 2
|
|
||||||
|
|
||||||
raise AssertionError("Should not be reached")
|
|
||||||
|
|
||||||
|
|
||||||
# @cache doesn't cache exceptions
|
|
||||||
@cache
|
|
||||||
def list_repo_files(
|
|
||||||
repo_id: str,
|
|
||||||
*,
|
|
||||||
revision: str | None = None,
|
|
||||||
repo_type: str | None = None,
|
|
||||||
token: str | bool | None = None,
|
|
||||||
) -> list[str]:
|
|
||||||
def lookup_files() -> list[str]:
|
|
||||||
# directly list files if model is local
|
|
||||||
if (local_path := Path(repo_id)).exists():
|
|
||||||
return [
|
|
||||||
str(file.relative_to(local_path))
|
|
||||||
for file in local_path.rglob("*")
|
|
||||||
if file.is_file()
|
|
||||||
]
|
|
||||||
# if model is remote, use hf_hub api to list files
|
|
||||||
try:
|
|
||||||
if envs.VLLM_USE_MODELSCOPE:
|
|
||||||
from vllm.transformers_utils.utils import modelscope_list_repo_files
|
|
||||||
|
|
||||||
return modelscope_list_repo_files(
|
|
||||||
repo_id,
|
|
||||||
revision=revision,
|
|
||||||
token=os.getenv("MODELSCOPE_API_TOKEN", None),
|
|
||||||
)
|
|
||||||
return hf_list_repo_files(
|
|
||||||
repo_id, revision=revision, repo_type=repo_type, token=token
|
|
||||||
)
|
|
||||||
except huggingface_hub.errors.OfflineModeIsEnabled:
|
|
||||||
# Don't raise in offline mode,
|
|
||||||
# all we know is that we don't have this
|
|
||||||
# file cached.
|
|
||||||
return []
|
|
||||||
|
|
||||||
return with_retry(lookup_files, "Error retrieving file list")
|
|
||||||
|
|
||||||
|
|
||||||
def list_filtered_repo_files(
|
|
||||||
model_name_or_path: str,
|
|
||||||
allow_patterns: list[str],
|
|
||||||
revision: str | None = None,
|
|
||||||
repo_type: str | None = None,
|
|
||||||
token: str | bool | None = None,
|
|
||||||
) -> list[str]:
|
|
||||||
try:
|
|
||||||
all_files = list_repo_files(
|
|
||||||
repo_id=model_name_or_path,
|
|
||||||
revision=revision,
|
|
||||||
token=token,
|
|
||||||
repo_type=repo_type,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.error(
|
|
||||||
"Error retrieving file list. Please ensure your `model_name_or_path`"
|
|
||||||
"`repo_type`, `token` and `revision` arguments are correctly set. "
|
|
||||||
"Returning an empty list."
|
|
||||||
)
|
|
||||||
return []
|
|
||||||
|
|
||||||
file_list = []
|
|
||||||
# Filter patterns on filenames
|
|
||||||
for pattern in allow_patterns:
|
|
||||||
file_list.extend(
|
|
||||||
[
|
|
||||||
file
|
|
||||||
for file in all_files
|
|
||||||
if fnmatch.fnmatch(os.path.basename(file), pattern)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
return file_list
|
|
||||||
|
|
||||||
|
|
||||||
def file_exists(
|
|
||||||
repo_id: str,
|
|
||||||
file_name: str,
|
|
||||||
*,
|
|
||||||
repo_type: str | None = None,
|
|
||||||
revision: str | None = None,
|
|
||||||
token: str | bool | None = None,
|
|
||||||
) -> bool:
|
|
||||||
file_list = list_repo_files(
|
|
||||||
repo_id, repo_type=repo_type, revision=revision, token=token
|
|
||||||
)
|
|
||||||
return file_name in file_list
|
|
||||||
|
|
||||||
|
|
||||||
# In offline mode the result can be a false negative
|
|
||||||
def file_or_path_exists(
|
|
||||||
model: str | Path, config_name: str, revision: str | None
|
|
||||||
) -> bool:
|
|
||||||
if (local_path := Path(model)).exists():
|
|
||||||
return (local_path / config_name).is_file()
|
|
||||||
|
|
||||||
# Offline mode support: Check if config file is cached already
|
|
||||||
cached_filepath = try_to_load_from_cache(
|
|
||||||
repo_id=model, filename=config_name, revision=revision
|
|
||||||
)
|
|
||||||
if isinstance(cached_filepath, str):
|
|
||||||
# The config file exists in cache- we can continue trying to load
|
|
||||||
return True
|
|
||||||
|
|
||||||
# NB: file_exists will only check for the existence of the config file on
|
|
||||||
# hf_hub. This will fail in offline mode.
|
|
||||||
|
|
||||||
# Call HF to check if the file exists
|
|
||||||
return file_exists(
|
|
||||||
str(model), config_name, revision=revision, token=_get_hf_token()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> None:
|
def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> None:
|
||||||
"""Some models may have no rope_theta in their config but still use RoPE.
|
"""Some models may have no rope_theta in their config but still use RoPE.
|
||||||
This function sets a default rope_theta if it's missing."""
|
This function sets a default rope_theta if it's missing."""
|
||||||
@ -836,72 +679,6 @@ def get_config(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def try_get_local_file(
|
|
||||||
model: str | Path, file_name: str, revision: str | None = "main"
|
|
||||||
) -> Path | None:
|
|
||||||
file_path = Path(model) / file_name
|
|
||||||
if file_path.is_file():
|
|
||||||
return file_path
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
cached_filepath = try_to_load_from_cache(
|
|
||||||
repo_id=model, filename=file_name, revision=revision
|
|
||||||
)
|
|
||||||
if isinstance(cached_filepath, str):
|
|
||||||
return Path(cached_filepath)
|
|
||||||
except ValueError:
|
|
||||||
...
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_hf_file_to_dict(
|
|
||||||
file_name: str, model: str | Path, revision: str | None = "main"
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Downloads a file from the Hugging Face Hub and returns
|
|
||||||
its contents as a dictionary.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- file_name (str): The name of the file to download.
|
|
||||||
- model (str): The name of the model on the Hugging Face Hub.
|
|
||||||
- revision (str): The specific version of the model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- config_dict (dict): A dictionary containing
|
|
||||||
the contents of the downloaded file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)
|
|
||||||
|
|
||||||
if file_path is None:
|
|
||||||
try:
|
|
||||||
hf_hub_file = hf_hub_download(model, file_name, revision=revision)
|
|
||||||
except huggingface_hub.errors.OfflineModeIsEnabled:
|
|
||||||
return None
|
|
||||||
except (
|
|
||||||
RepositoryNotFoundError,
|
|
||||||
RevisionNotFoundError,
|
|
||||||
EntryNotFoundError,
|
|
||||||
LocalEntryNotFoundError,
|
|
||||||
) as e:
|
|
||||||
logger.debug("File or repository not found in hf_hub_download", e)
|
|
||||||
return None
|
|
||||||
except HfHubHTTPError as e:
|
|
||||||
logger.warning(
|
|
||||||
"Cannot connect to Hugging Face Hub. Skipping file download for '%s':",
|
|
||||||
file_name,
|
|
||||||
exc_info=e,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
file_path = Path(hf_hub_file)
|
|
||||||
|
|
||||||
if file_path is not None and file_path.is_file():
|
|
||||||
with open(file_path) as file:
|
|
||||||
return json.load(file)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def get_pooling_config(model: str, revision: str | None = "main") -> dict | None:
|
def get_pooling_config(model: str, revision: str | None = "main") -> dict | None:
|
||||||
"""
|
"""
|
||||||
@ -1316,41 +1093,3 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return max_position_embeddings
|
return max_position_embeddings
|
||||||
|
|
||||||
|
|
||||||
def get_model_path(model: str | Path, revision: str | None = None):
|
|
||||||
if os.path.exists(model):
|
|
||||||
return model
|
|
||||||
assert huggingface_hub.constants.HF_HUB_OFFLINE
|
|
||||||
common_kwargs = {
|
|
||||||
"local_files_only": huggingface_hub.constants.HF_HUB_OFFLINE,
|
|
||||||
"revision": revision,
|
|
||||||
}
|
|
||||||
|
|
||||||
if envs.VLLM_USE_MODELSCOPE:
|
|
||||||
from modelscope.hub.snapshot_download import snapshot_download
|
|
||||||
|
|
||||||
return snapshot_download(model_id=model, **common_kwargs)
|
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
return snapshot_download(repo_id=model, **common_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def get_hf_file_bytes(
|
|
||||||
file_name: str, model: str | Path, revision: str | None = "main"
|
|
||||||
) -> bytes | None:
|
|
||||||
"""Get file contents from HuggingFace repository as bytes."""
|
|
||||||
file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)
|
|
||||||
|
|
||||||
if file_path is None:
|
|
||||||
hf_hub_file = hf_hub_download(
|
|
||||||
model, file_name, revision=revision, token=_get_hf_token()
|
|
||||||
)
|
|
||||||
file_path = Path(hf_hub_file)
|
|
||||||
|
|
||||||
if file_path is not None and file_path.is_file():
|
|
||||||
with open(file_path, "rb") as file:
|
|
||||||
return file.read()
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from gguf.constants import Keys, VisionProjectorType
|
|||||||
from transformers import Gemma3Config, PretrainedConfig, SiglipVisionConfig
|
from transformers import Gemma3Config, PretrainedConfig, SiglipVisionConfig
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.config import list_filtered_repo_files
|
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
287
vllm/transformers_utils/repo_utils.py
Normal file
287
vllm/transformers_utils/repo_utils.py
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Utilities for model repo interaction."""
|
||||||
|
|
||||||
|
import fnmatch
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
import huggingface_hub
|
||||||
|
from huggingface_hub import (
|
||||||
|
hf_hub_download,
|
||||||
|
try_to_load_from_cache,
|
||||||
|
)
|
||||||
|
from huggingface_hub import list_repo_files as hf_list_repo_files
|
||||||
|
from huggingface_hub.utils import (
|
||||||
|
EntryNotFoundError,
|
||||||
|
HfHubHTTPError,
|
||||||
|
LocalEntryNotFoundError,
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
|
)
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_hf_token() -> str | None:
|
||||||
|
"""
|
||||||
|
Get the HuggingFace token from environment variable.
|
||||||
|
|
||||||
|
Returns None if the token is not set, is an empty string,
|
||||||
|
or contains only whitespace.
|
||||||
|
This follows the same pattern as huggingface_hub library which
|
||||||
|
treats empty string tokens as None to avoid authentication errors.
|
||||||
|
"""
|
||||||
|
token = os.getenv("HF_TOKEN")
|
||||||
|
if token and token.strip():
|
||||||
|
return token
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
_R = TypeVar("_R")
|
||||||
|
|
||||||
|
|
||||||
|
def with_retry(
|
||||||
|
func: Callable[[], _R],
|
||||||
|
log_msg: str,
|
||||||
|
max_retries: int = 2,
|
||||||
|
retry_delay: int = 2,
|
||||||
|
) -> _R:
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
return func()
|
||||||
|
except Exception as e:
|
||||||
|
if attempt == max_retries - 1:
|
||||||
|
logger.error("%s: %s", log_msg, e)
|
||||||
|
raise
|
||||||
|
logger.error(
|
||||||
|
"%s: %s, retrying %d of %d", log_msg, e, attempt + 1, max_retries
|
||||||
|
)
|
||||||
|
time.sleep(retry_delay)
|
||||||
|
retry_delay *= 2
|
||||||
|
|
||||||
|
raise AssertionError("Should not be reached")
|
||||||
|
|
||||||
|
|
||||||
|
# @cache doesn't cache exceptions
|
||||||
|
@cache
|
||||||
|
def list_repo_files(
|
||||||
|
repo_id: str,
|
||||||
|
*,
|
||||||
|
revision: str | None = None,
|
||||||
|
repo_type: str | None = None,
|
||||||
|
token: str | bool | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
def lookup_files() -> list[str]:
|
||||||
|
# directly list files if model is local
|
||||||
|
if (local_path := Path(repo_id)).exists():
|
||||||
|
return [
|
||||||
|
str(file.relative_to(local_path))
|
||||||
|
for file in local_path.rglob("*")
|
||||||
|
if file.is_file()
|
||||||
|
]
|
||||||
|
# if model is remote, use hf_hub api to list files
|
||||||
|
try:
|
||||||
|
if envs.VLLM_USE_MODELSCOPE:
|
||||||
|
from vllm.transformers_utils.utils import modelscope_list_repo_files
|
||||||
|
|
||||||
|
return modelscope_list_repo_files(
|
||||||
|
repo_id,
|
||||||
|
revision=revision,
|
||||||
|
token=os.getenv("MODELSCOPE_API_TOKEN", None),
|
||||||
|
)
|
||||||
|
return hf_list_repo_files(
|
||||||
|
repo_id, revision=revision, repo_type=repo_type, token=token
|
||||||
|
)
|
||||||
|
except huggingface_hub.errors.OfflineModeIsEnabled:
|
||||||
|
# Don't raise in offline mode,
|
||||||
|
# all we know is that we don't have this
|
||||||
|
# file cached.
|
||||||
|
return []
|
||||||
|
|
||||||
|
return with_retry(lookup_files, "Error retrieving file list")
|
||||||
|
|
||||||
|
|
||||||
|
def list_filtered_repo_files(
|
||||||
|
model_name_or_path: str,
|
||||||
|
allow_patterns: list[str],
|
||||||
|
revision: str | None = None,
|
||||||
|
repo_type: str | None = None,
|
||||||
|
token: str | bool | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
try:
|
||||||
|
all_files = list_repo_files(
|
||||||
|
repo_id=model_name_or_path,
|
||||||
|
revision=revision,
|
||||||
|
token=token,
|
||||||
|
repo_type=repo_type,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.error(
|
||||||
|
"Error retrieving file list. Please ensure your `model_name_or_path`"
|
||||||
|
"`repo_type`, `token` and `revision` arguments are correctly set. "
|
||||||
|
"Returning an empty list."
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
file_list = []
|
||||||
|
# Filter patterns on filenames
|
||||||
|
for pattern in allow_patterns:
|
||||||
|
file_list.extend(
|
||||||
|
[
|
||||||
|
file
|
||||||
|
for file in all_files
|
||||||
|
if fnmatch.fnmatch(os.path.basename(file), pattern)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return file_list
|
||||||
|
|
||||||
|
|
||||||
|
def file_exists(
|
||||||
|
repo_id: str,
|
||||||
|
file_name: str,
|
||||||
|
*,
|
||||||
|
repo_type: str | None = None,
|
||||||
|
revision: str | None = None,
|
||||||
|
token: str | bool | None = None,
|
||||||
|
) -> bool:
|
||||||
|
file_list = list_repo_files(
|
||||||
|
repo_id, repo_type=repo_type, revision=revision, token=token
|
||||||
|
)
|
||||||
|
return file_name in file_list
|
||||||
|
|
||||||
|
|
||||||
|
# In offline mode the result can be a false negative
|
||||||
|
def file_or_path_exists(
|
||||||
|
model: str | Path, config_name: str, revision: str | None
|
||||||
|
) -> bool:
|
||||||
|
if (local_path := Path(model)).exists():
|
||||||
|
return (local_path / config_name).is_file()
|
||||||
|
|
||||||
|
# Offline mode support: Check if config file is cached already
|
||||||
|
cached_filepath = try_to_load_from_cache(
|
||||||
|
repo_id=model, filename=config_name, revision=revision
|
||||||
|
)
|
||||||
|
if isinstance(cached_filepath, str):
|
||||||
|
# The config file exists in cache- we can continue trying to load
|
||||||
|
return True
|
||||||
|
|
||||||
|
# NB: file_exists will only check for the existence of the config file on
|
||||||
|
# hf_hub. This will fail in offline mode.
|
||||||
|
|
||||||
|
# Call HF to check if the file exists
|
||||||
|
return file_exists(
|
||||||
|
str(model), config_name, revision=revision, token=_get_hf_token()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_path(model: str | Path, revision: str | None = None):
|
||||||
|
if os.path.exists(model):
|
||||||
|
return model
|
||||||
|
assert huggingface_hub.constants.HF_HUB_OFFLINE
|
||||||
|
common_kwargs = {
|
||||||
|
"local_files_only": huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||||
|
"revision": revision,
|
||||||
|
}
|
||||||
|
|
||||||
|
if envs.VLLM_USE_MODELSCOPE:
|
||||||
|
from modelscope.hub.snapshot_download import snapshot_download
|
||||||
|
|
||||||
|
return snapshot_download(model_id=model, **common_kwargs)
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
return snapshot_download(repo_id=model, **common_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_hf_file_bytes(
|
||||||
|
file_name: str, model: str | Path, revision: str | None = "main"
|
||||||
|
) -> bytes | None:
|
||||||
|
"""Get file contents from HuggingFace repository as bytes."""
|
||||||
|
file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)
|
||||||
|
|
||||||
|
if file_path is None:
|
||||||
|
hf_hub_file = hf_hub_download(
|
||||||
|
model, file_name, revision=revision, token=_get_hf_token()
|
||||||
|
)
|
||||||
|
file_path = Path(hf_hub_file)
|
||||||
|
|
||||||
|
if file_path is not None and file_path.is_file():
|
||||||
|
with open(file_path, "rb") as file:
|
||||||
|
return file.read()
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def try_get_local_file(
|
||||||
|
model: str | Path, file_name: str, revision: str | None = "main"
|
||||||
|
) -> Path | None:
|
||||||
|
file_path = Path(model) / file_name
|
||||||
|
if file_path.is_file():
|
||||||
|
return file_path
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
cached_filepath = try_to_load_from_cache(
|
||||||
|
repo_id=model, filename=file_name, revision=revision
|
||||||
|
)
|
||||||
|
if isinstance(cached_filepath, str):
|
||||||
|
return Path(cached_filepath)
|
||||||
|
except ValueError:
|
||||||
|
...
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_hf_file_to_dict(
|
||||||
|
file_name: str, model: str | Path, revision: str | None = "main"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Downloads a file from the Hugging Face Hub and returns
|
||||||
|
its contents as a dictionary.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- file_name (str): The name of the file to download.
|
||||||
|
- model (str): The name of the model on the Hugging Face Hub.
|
||||||
|
- revision (str): The specific version of the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- config_dict (dict): A dictionary containing
|
||||||
|
the contents of the downloaded file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)
|
||||||
|
|
||||||
|
if file_path is None:
|
||||||
|
try:
|
||||||
|
hf_hub_file = hf_hub_download(model, file_name, revision=revision)
|
||||||
|
except huggingface_hub.errors.OfflineModeIsEnabled:
|
||||||
|
return None
|
||||||
|
except (
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
|
EntryNotFoundError,
|
||||||
|
LocalEntryNotFoundError,
|
||||||
|
) as e:
|
||||||
|
logger.debug("File or repository not found in hf_hub_download", e)
|
||||||
|
return None
|
||||||
|
except HfHubHTTPError as e:
|
||||||
|
logger.warning(
|
||||||
|
"Cannot connect to Hugging Face Hub. Skipping file download for '%s':",
|
||||||
|
file_name,
|
||||||
|
exc_info=e,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
file_path = Path(hf_hub_file)
|
||||||
|
|
||||||
|
if file_path is not None and file_path.is_file():
|
||||||
|
with open(file_path) as file:
|
||||||
|
return json.load(file)
|
||||||
|
|
||||||
|
return None
|
||||||
@ -15,11 +15,9 @@ from typing_extensions import assert_never
|
|||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.config import (
|
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
|
||||||
get_sentence_transformer_tokenizer_config,
|
|
||||||
list_filtered_repo_files,
|
|
||||||
)
|
|
||||||
from vllm.transformers_utils.gguf_utils import get_gguf_file_path_from_hf
|
from vllm.transformers_utils.gguf_utils import get_gguf_file_path_from_hf
|
||||||
|
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||||
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
||||||
from vllm.transformers_utils.utils import (
|
from vllm.transformers_utils.utils import (
|
||||||
check_gguf_file,
|
check_gguf_file,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user