[Bugfix] Fix offline mode when using mistral_common (#9457)
parent 0c9a5258f9
commit 337ed76671
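In short, this change lets the mistral_common tokenizer loader fall back to files already present in the local Hugging Face cache when the Hub cannot be reached. The first hunk below rewrites the offline-mode test; the later hunks adjust the Mistral tokenizer loader. The scenario it targets looks roughly like this minimal sketch (my illustration, not part of the commit; model name and settings mirror the test configs further down):

# Minimal offline-mode sketch (assumes the model was downloaded into the
# local HF cache during an earlier, online run).
import os

os.environ["HF_HUB_OFFLINE"] = "1"  # read by huggingface_hub at import time

from vllm import LLM

llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1",
          tokenizer_mode="mistral",
          enforce_eager=True)
print(llm.generate("Hello"))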
@@ -1,50 +1,56 @@
 """Tests for HF_HUB_OFFLINE mode"""
 import importlib
 import sys
 import weakref
 
 import pytest
 
 from vllm import LLM
 from vllm.distributed import cleanup_dist_env_and_memory
 
-MODEL_NAME = "facebook/opt-125m"
+MODEL_CONFIGS = [
+    {
+        "model": "facebook/opt-125m",
+        "enforce_eager": True,
+        "gpu_memory_utilization": 0.20,
+        "max_model_len": 64,
+        "max_num_batched_tokens": 64,
+        "max_num_seqs": 64,
+        "tensor_parallel_size": 1,
+    },
+    {
+        "model": "mistralai/Mistral-7B-Instruct-v0.1",
+        "enforce_eager": True,
+        "gpu_memory_utilization": 0.95,
+        "max_model_len": 64,
+        "max_num_batched_tokens": 64,
+        "max_num_seqs": 64,
+        "tensor_parallel_size": 1,
+        "tokenizer_mode": "mistral",
+    },
+]
 
 
 @pytest.fixture(scope="module")
-def llm():
-    # pytest caches the fixture so we use weakref.proxy to
-    # enable garbage collection
-    llm = LLM(model=MODEL_NAME,
-              max_num_batched_tokens=4096,
-              tensor_parallel_size=1,
-              gpu_memory_utilization=0.10,
-              enforce_eager=True)
-
-    with llm.deprecate_legacy_api():
-        yield weakref.proxy(llm)
-
-        del llm
-
-    cleanup_dist_env_and_memory()
+def cache_models():
+    # Cache model files first
+    for model_config in MODEL_CONFIGS:
+        LLM(**model_config)
+        cleanup_dist_env_and_memory()
+
+    yield
 
 
 @pytest.mark.skip_global_cleanup
-def test_offline_mode(llm: LLM, monkeypatch):
-    # we use the llm fixture to ensure the model files are in-cache
-    del llm
-
+@pytest.mark.usefixtures("cache_models")
+def test_offline_mode(monkeypatch):
     # Set HF to offline mode and ensure we can still construct an LLM
     try:
         monkeypatch.setenv("HF_HUB_OFFLINE", "1")
         # Need to re-import huggingface_hub and friends to setup offline mode
         _re_import_modules()
         # Cached model files should be used in offline mode
-        LLM(model=MODEL_NAME,
-            max_num_batched_tokens=4096,
-            tensor_parallel_size=1,
-            gpu_memory_utilization=0.20,
-            enforce_eager=True)
+        for model_config in MODEL_CONFIGS:
+            LLM(**model_config)
     finally:
         # Reset the environment after the test
         # NB: Assuming tests are run in online mode
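The test relies on a _re_import_modules() helper defined outside this hunk. For orientation only, a helper of that kind might look like the sketch below (my assumption, not the code in the file): it reloads the already-imported huggingface_hub modules so the freshly set HF_HUB_OFFLINE value is reflected in their module-level constants.

# Hypothetical reload helper, for illustration; the real _re_import_modules()
# in the test file may cover additional modules (e.g. transformers).
import importlib
import sys


def _re_import_modules():
    # HF_HUB_OFFLINE is captured in module-level constants at import time,
    # so reload every huggingface_hub submodule that is already imported.
    for name in [m for m in sys.modules if m.startswith("huggingface_hub")]:
        importlib.reload(sys.modules[name])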
@@ -4,6 +4,7 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
 
+import huggingface_hub
 from huggingface_hub import HfApi, hf_hub_download
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 # yapf: disable
@@ -24,6 +25,26 @@ class Encoding:
     input_ids: List[int]
 
 
+def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
+    repo_cache = os.path.join(
+        huggingface_hub.constants.HF_HUB_CACHE,
+        huggingface_hub.constants.REPO_ID_SEPARATOR.join(
+            ["models", *repo_id.split("/")]))
+
+    if revision is None:
+        revision_file = os.path.join(repo_cache, "refs", "main")
+        if os.path.isfile(revision_file):
+            with open(revision_file) as file:
+                revision = file.read()
+
+    if revision:
+        revision_dir = os.path.join(repo_cache, "snapshots", revision)
+        if os.path.isdir(revision_dir):
+            return os.listdir(revision_dir)
+
+    return []
+
+
 def find_tokenizer_file(files: List[str]):
     file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
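For orientation, list_local_repo_files resolves the requested revision through the standard layout of the local Hugging Face cache (models--<org>--<name>/refs and /snapshots). A call against a repo that has already been downloaded would look roughly like this (repo id and output are illustrative):

# Illustrative call; assumes the repo already sits in the HF cache, e.g. under
# ~/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.1/
files = list_local_repo_files(
    repo_id="mistralai/Mistral-7B-Instruct-v0.1",
    revision=None,  # resolved via the cached refs/main file
)
print(files)  # e.g. ['config.json', 'tokenizer.model', ...], or [] if not cached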
@@ -90,9 +111,16 @@ class MistralTokenizer:
     @staticmethod
     def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                             revision: Optional[str]) -> str:
-        api = HfApi()
-        repo_info = api.model_info(tokenizer_name)
-        files = [s.rfilename for s in repo_info.siblings]
+        try:
+            hf_api = HfApi()
+            files = hf_api.list_repo_files(repo_id=tokenizer_name,
+                                           revision=revision)
+        except ConnectionError as exc:
+            files = list_local_repo_files(repo_id=tokenizer_name,
+                                          revision=revision)
+
+            if len(files) == 0:
+                raise exc
 
         filename = find_tokenizer_file(files)
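One detail worth calling out (my reading, not spelled out in the diff): with HF_HUB_OFFLINE=1 set, huggingface_hub raises OfflineModeIsEnabled from calls that would hit the network, and that exception subclasses ConnectionError, so the fallback branch above covers explicit offline mode as well as genuine network failures. A quick way to check that assumption on a given huggingface_hub version:

# Sanity check of the assumption (exact import path may vary by version).
from huggingface_hub.utils import OfflineModeIsEnabled

assert issubclass(OfflineModeIsEnabled, ConnectionError)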