mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 21:35:00 +08:00
[Core] Support offline use of local cache for models (#4374)
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com> Co-authored-by: Travis Johnson <tjohnson31415@gmail.com>
This commit is contained in:
parent
81661da7b2
commit
d6e520e170
@ -1,9 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
import huggingface_hub.constants
|
import huggingface_hub.constants
|
||||||
import pytest
|
import pytest
|
||||||
|
from huggingface_hub.utils import LocalEntryNotFoundError
|
||||||
|
|
||||||
from vllm.model_executor.model_loader.weight_utils import enable_hf_transfer
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
download_weights_from_hf, enable_hf_transfer)
|
||||||
|
|
||||||
|
|
||||||
def test_hf_transfer_auto_activation():
|
def test_hf_transfer_auto_activation():
|
||||||
@ -22,5 +25,30 @@ def test_hf_transfer_auto_activation():
|
|||||||
HF_TRANFER_ACTIVE)
|
HF_TRANFER_ACTIVE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_weights_from_hf():
    """Exercise ``download_weights_from_hf`` against a fresh cache directory.

    The test toggles ``huggingface_hub.constants.HF_HUB_OFFLINE`` and checks
    three stages in order, all using the same temporary ``cache_dir``:

    1. Offline with an empty cache: ``LocalEntryNotFoundError`` is raised.
    2. Offline disabled: the weights are fetched into the cache.
    3. Offline again: the call succeeds and returns a non-``None`` value.

    The offline flag is a process-wide module attribute, so its original
    value is saved up front and restored in ``finally``; otherwise a failure
    (or even a successful run) would leave offline mode switched on for every
    test that runs afterwards in the same process.
    """
    model_name = "facebook/opt-125m"
    weight_patterns = ["*.safetensors", "*.bin"]

    # Remember the flag so the mutations below do not leak out of this test.
    original_offline = huggingface_hub.constants.HF_HUB_OFFLINE
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            # assert LocalEntryNotFoundError error is thrown
            # if offline is set and model is not cached
            huggingface_hub.constants.HF_HUB_OFFLINE = True
            with pytest.raises(LocalEntryNotFoundError):
                download_weights_from_hf(model_name,
                                         allow_patterns=weight_patterns,
                                         cache_dir=tmpdir)

            # download the model
            huggingface_hub.constants.HF_HUB_OFFLINE = False
            download_weights_from_hf(model_name,
                                     allow_patterns=weight_patterns,
                                     cache_dir=tmpdir)

            # now it should work offline
            huggingface_hub.constants.HF_HUB_OFFLINE = True
            assert download_weights_from_hf(
                model_name,
                allow_patterns=weight_patterns,
                cache_dir=tmpdir) is not None
    finally:
        # Always put the global flag back, even when an assertion fails.
        huggingface_hub.constants.HF_HUB_OFFLINE = original_offline
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_hf_transfer_auto_activation()
|
test_hf_transfer_auto_activation()
|
||||||
|
test_download_weights_from_hf()
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import os
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
|
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import huggingface_hub
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
@ -131,7 +132,9 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
model_path = snapshot_download(
|
model_path = snapshot_download(
|
||||||
model_id=model,
|
model_id=model,
|
||||||
cache_dir=self.load_config.download_dir,
|
cache_dir=self.load_config.download_dir,
|
||||||
revision=revision)
|
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
model_path = model
|
model_path = model
|
||||||
return model_path
|
return model_path
|
||||||
|
|||||||
@ -127,11 +127,14 @@ def get_quant_config(model_config: ModelConfig,
|
|||||||
if not is_local:
|
if not is_local:
|
||||||
# Download the config files.
|
# Download the config files.
|
||||||
with get_lock(model_name_or_path, load_config.download_dir):
|
with get_lock(model_name_or_path, load_config.download_dir):
|
||||||
hf_folder = snapshot_download(model_name_or_path,
|
hf_folder = snapshot_download(
|
||||||
revision=model_config.revision,
|
model_name_or_path,
|
||||||
allow_patterns="*.json",
|
revision=model_config.revision,
|
||||||
cache_dir=load_config.download_dir,
|
allow_patterns="*.json",
|
||||||
tqdm_class=DisabledTqdm)
|
cache_dir=load_config.download_dir,
|
||||||
|
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||||
|
tqdm_class=DisabledTqdm,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
hf_folder = model_name_or_path
|
hf_folder = model_name_or_path
|
||||||
|
|
||||||
@ -161,12 +164,14 @@ def get_quant_config(model_config: ModelConfig,
|
|||||||
return quant_cls.from_config(config)
|
return quant_cls.from_config(config)
|
||||||
|
|
||||||
|
|
||||||
def download_weights_from_hf(model_name_or_path: str,
|
def download_weights_from_hf(
|
||||||
cache_dir: Optional[str],
|
model_name_or_path: str,
|
||||||
allow_patterns: List[str],
|
cache_dir: Optional[str],
|
||||||
revision: Optional[str] = None) -> str:
|
allow_patterns: List[str],
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
"""Download model weights from Hugging Face Hub.
|
"""Download model weights from Hugging Face Hub.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name_or_path (str): The model name or path.
|
model_name_or_path (str): The model name or path.
|
||||||
cache_dir (Optional[str]): The cache directory to store the model
|
cache_dir (Optional[str]): The cache directory to store the model
|
||||||
@ -179,26 +184,30 @@ def download_weights_from_hf(model_name_or_path: str,
|
|||||||
Returns:
|
Returns:
|
||||||
str: The path to the downloaded model weights.
|
str: The path to the downloaded model weights.
|
||||||
"""
|
"""
|
||||||
# Before we download we look at that is available:
|
if not huggingface_hub.constants.HF_HUB_OFFLINE:
|
||||||
fs = HfFileSystem()
|
# Before we download we look at that is available:
|
||||||
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
fs = HfFileSystem()
|
||||||
|
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
||||||
|
|
||||||
# depending on what is available we download different things
|
# depending on what is available we download different things
|
||||||
for pattern in allow_patterns:
|
for pattern in allow_patterns:
|
||||||
matching = fnmatch.filter(file_list, pattern)
|
matching = fnmatch.filter(file_list, pattern)
|
||||||
if len(matching) > 0:
|
if len(matching) > 0:
|
||||||
allow_patterns = [pattern]
|
allow_patterns = [pattern]
|
||||||
break
|
break
|
||||||
|
|
||||||
logger.info("Using model weights format %s", allow_patterns)
|
logger.info("Using model weights format %s", allow_patterns)
|
||||||
# Use file lock to prevent multiple processes from
|
# Use file lock to prevent multiple processes from
|
||||||
# downloading the same model weights at the same time.
|
# downloading the same model weights at the same time.
|
||||||
with get_lock(model_name_or_path, cache_dir):
|
with get_lock(model_name_or_path, cache_dir):
|
||||||
hf_folder = snapshot_download(model_name_or_path,
|
hf_folder = snapshot_download(
|
||||||
allow_patterns=allow_patterns,
|
model_name_or_path,
|
||||||
cache_dir=cache_dir,
|
allow_patterns=allow_patterns,
|
||||||
tqdm_class=DisabledTqdm,
|
cache_dir=cache_dir,
|
||||||
revision=revision)
|
tqdm_class=DisabledTqdm,
|
||||||
|
revision=revision,
|
||||||
|
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||||
|
)
|
||||||
return hf_folder
|
return hf_folder
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import huggingface_hub
|
||||||
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||||
PreTrainedTokenizerFast)
|
PreTrainedTokenizerFast)
|
||||||
|
|
||||||
@ -76,6 +77,7 @@ def get_tokenizer(
|
|||||||
model_id=tokenizer_name,
|
model_id=tokenizer_name,
|
||||||
cache_dir=download_dir,
|
cache_dir=download_dir,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
|
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||||
# Ignore weights - we only need the tokenizer.
|
# Ignore weights - we only need the tokenizer.
|
||||||
ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"])
|
ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"])
|
||||||
tokenizer_name = tokenizer_path
|
tokenizer_name = tokenizer_path
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user