[Fix] fix offline env use local model path (#22526)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
parent c6d80a7a96 · commit 38217877aa
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Tests for HF_HUB_OFFLINE mode"""
+import dataclasses
 import importlib
 import sys
 
@@ -9,6 +10,7 @@ import urllib3
 
 from vllm import LLM
 from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.engine.arg_utils import EngineArgs
 
 MODEL_CONFIGS = [
     {
@@ -108,3 +110,36 @@ def _re_import_modules():
     # Error this test if reloading a module failed
     if reload_exception is not None:
         raise reload_exception
+
+
+@pytest.mark.skip_global_cleanup
+@pytest.mark.usefixtures("cache_models")
+def test_model_from_huggingface_offline(monkeypatch: pytest.MonkeyPatch):
+    # Set HF to offline mode and ensure we can still construct an LLM
+    with monkeypatch.context() as m:
+        try:
+            m.setenv("HF_HUB_OFFLINE", "1")
+            m.setenv("VLLM_NO_USAGE_STATS", "1")
+
+            def disable_connect(*args, **kwargs):
+                raise RuntimeError("No http calls allowed")
+
+            m.setattr(
+                urllib3.connection.HTTPConnection,
+                "connect",
+                disable_connect,
+            )
+            m.setattr(
+                urllib3.connection.HTTPSConnection,
+                "connect",
+                disable_connect,
+            )
+            # Need to re-import huggingface_hub
+            # and friends to setup offline mode
+            _re_import_modules()
+            engine_args = EngineArgs(model="facebook/opt-125m")
+            LLM(**dataclasses.asdict(engine_args))
+        finally:
+            # Reset the environment after the test
+            # NB: Assuming tests are run in online mode
+            _re_import_modules()
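The re-import step is what makes the offline switch take effect. A minimal sketch of the underlying behavior (not part of the diff; module and attribute names are from huggingface_hub's public API): huggingface_hub reads HF_HUB_OFFLINE once, at import time, into huggingface_hub.constants, so flipping the environment variable afterwards changes nothing until the module is reloaded.

import importlib
import os

import huggingface_hub.constants

# Assumes HF_HUB_OFFLINE was unset when the process started.
os.environ["HF_HUB_OFFLINE"] = "1"
print(huggingface_hub.constants.HF_HUB_OFFLINE)  # False: value was baked in at import
importlib.reload(huggingface_hub.constants)      # re-evaluates the env var
print(huggingface_hub.constants.HF_HUB_OFFLINE)  # True after reload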
@@ -15,6 +15,7 @@ from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List,
                     Literal, Optional, Type, TypeVar, Union, cast, get_args,
                     get_origin)
 
+import huggingface_hub
 import regex as re
 import torch
 from pydantic import TypeAdapter, ValidationError
@@ -39,7 +40,7 @@ from vllm.plugins import load_general_plugins
 from vllm.ray.lazy_utils import is_ray_initialized
 from vllm.reasoning import ReasoningParserManager
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
-from vllm.transformers_utils.config import is_interleaved
+from vllm.transformers_utils.config import get_model_path, is_interleaved
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
                         GiB_bytes, get_ip, is_in_ray_actor)
@@ -457,6 +458,13 @@ class EngineArgs:
         # Setup plugins
         from vllm.plugins import load_general_plugins
         load_general_plugins()
+        # When HF Hub is offline, replace the model id with the local model path
+        if huggingface_hub.constants.HF_HUB_OFFLINE:
+            model_id = self.model
+            self.model = get_model_path(self.model, self.revision)
+            logger.info(
+                "HF_HUB_OFFLINE is True, replacing model_id [%s] "
+                "with model_path [%s]", model_id, self.model)
 
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
 
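The effect of this hunk, as a hedged usage sketch (assumes facebook/opt-125m is already present in the local HF cache; the exact snapshot path varies by machine):

import os

# Must be set before huggingface_hub is first imported.
os.environ["HF_HUB_OFFLINE"] = "1"

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m")
# args.model should now be a local snapshot directory, e.g.
# ~/.cache/huggingface/hub/models--facebook--opt-125m/snapshots/<hash>,
# rather than the hub repo id.
print(args.model)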
@@ -14,7 +14,7 @@ from huggingface_hub import get_safetensors_metadata, hf_hub_download
 from huggingface_hub import list_repo_files as hf_list_repo_files
 from huggingface_hub import try_to_load_from_cache
 from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
-                                   HFValidationError, LocalEntryNotFoundError,
+                                   LocalEntryNotFoundError,
                                    RepositoryNotFoundError,
                                    RevisionNotFoundError)
 from transformers import GenerationConfig, PretrainedConfig
@@ -335,6 +335,7 @@ def maybe_override_with_speculators_target_model(
         gguf_model_repo = Path(model).parent
     else:
         gguf_model_repo = None
+    kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
     config_dict, _ = PretrainedConfig.get_config_dict(
         model if gguf_model_repo is None else gguf_model_repo,
         revision=revision,
@@ -400,6 +401,7 @@ def get_config(
         raise ValueError(error_message) from e
 
     if config_format == ConfigFormat.HF:
+        kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
         config_dict, _ = PretrainedConfig.get_config_dict(
             model,
             revision=revision,
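Both kwargs["local_files_only"] additions propagate offline mode into transformers' config loading. A minimal sketch of what the flag does (assumes the config is already in the local HF cache):

from transformers import PretrainedConfig

# With local_files_only=True, config.json is resolved from the local cache
# only; no HTTP request is made, and a missing file raises an OSError
# instead of triggering a download.
config_dict, _ = PretrainedConfig.get_config_dict(
    "facebook/opt-125m",
    local_files_only=True,
)
print(config_dict["model_type"])  # "opt"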
@@ -532,7 +534,7 @@ def try_get_local_file(model: Union[str, Path],
                                                      revision=revision)
             if isinstance(cached_filepath, str):
                 return Path(cached_filepath)
-        except HFValidationError:
+        except ValueError:
             ...
     return None
 
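The broader except clause works because HFValidationError subclasses ValueError, so invalid repo ids (for example, an absolute filesystem path now passed in offline mode) are still caught along with any other ValueError:

from huggingface_hub.utils import HFValidationError

# HFValidationError is defined as a ValueError subclass in huggingface_hub,
# so `except ValueError` is a strict superset of the old clause.
assert issubclass(HFValidationError, ValueError)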
@@ -908,3 +910,20 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
                        exc_info=e)
 
     return max_position_embeddings
+
+
+def get_model_path(model: Union[str, Path], revision: Optional[str] = None):
+    if os.path.exists(model):
+        return model
+    assert huggingface_hub.constants.HF_HUB_OFFLINE
+    common_kwargs = {
+        "local_files_only": huggingface_hub.constants.HF_HUB_OFFLINE,
+        "revision": revision,
+    }
+
+    if envs.VLLM_USE_MODELSCOPE:
+        from modelscope.hub.snapshot_download import snapshot_download
+        return snapshot_download(model_id=model, **common_kwargs)
+
+    from huggingface_hub import snapshot_download
+    return snapshot_download(repo_id=model, **common_kwargs)
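A hedged usage sketch for the new helper (assumes the model was downloaded in an earlier online run; snapshot_download with local_files_only=True only resolves paths from the cache and raises LocalEntryNotFoundError when nothing is cached):

import os

# get_model_path asserts offline mode for non-local model ids, so set this
# before huggingface_hub is first imported.
os.environ["HF_HUB_OFFLINE"] = "1"

from vllm.transformers_utils.config import get_model_path

local_path = get_model_path("facebook/opt-125m")
# e.g. ~/.cache/huggingface/hub/models--facebook--opt-125m/snapshots/<hash>
print(local_path)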