[Core] Support offline use of local cache for models (#4374)

Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Travis Johnson <tjohnson31415@gmail.com>
Prashant Gupta authored on 2024-04-27 09:59:55 -07:00, committed by GitHub
parent 81661da7b2
commit d6e520e170
4 changed files with 68 additions and 26 deletions
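This change threads huggingface_hub's HF_HUB_OFFLINE flag through vLLM's download paths as local_files_only, so a model that is already in the local cache can be loaded with no network access. A minimal usage sketch (not part of the diff below; the model name is illustrative and must already be cached):

# Sketch: serve a cached model fully offline.
# HF_HUB_OFFLINE must be set before huggingface_hub is imported,
# since the flag is read once at import time.
import os

os.environ["HF_HUB_OFFLINE"] = "1"

from vllm import LLM  # noqa: E402

llm = LLM(model="facebook/opt-125m")  # resolved from the local HF cache only
print(llm.generate("Hello, my name is"))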


@@ -1,9 +1,12 @@
 import os
+import tempfile

 import huggingface_hub.constants
 import pytest
+from huggingface_hub.utils import LocalEntryNotFoundError

-from vllm.model_executor.model_loader.weight_utils import enable_hf_transfer
+from vllm.model_executor.model_loader.weight_utils import (
+    download_weights_from_hf, enable_hf_transfer)


 def test_hf_transfer_auto_activation():
@@ -22,5 +25,30 @@ def test_hf_transfer_auto_activation():
                         HF_TRANFER_ACTIVE)


+def test_download_weights_from_hf():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # assert LocalEntryNotFoundError is raised
+        # if offline is set and the model is not cached
+        huggingface_hub.constants.HF_HUB_OFFLINE = True
+        with pytest.raises(LocalEntryNotFoundError):
+            download_weights_from_hf("facebook/opt-125m",
+                                     allow_patterns=["*.safetensors", "*.bin"],
+                                     cache_dir=tmpdir)
+
+        # download the model
+        huggingface_hub.constants.HF_HUB_OFFLINE = False
+        download_weights_from_hf("facebook/opt-125m",
+                                 allow_patterns=["*.safetensors", "*.bin"],
+                                 cache_dir=tmpdir)
+
+        # now it should work offline
+        huggingface_hub.constants.HF_HUB_OFFLINE = True
+        assert download_weights_from_hf(
+            "facebook/opt-125m",
+            allow_patterns=["*.safetensors", "*.bin"],
+            cache_dir=tmpdir) is not None
+
+
 if __name__ == "__main__":
     test_hf_transfer_auto_activation()
+    test_download_weights_from_hf()
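The test toggles huggingface_hub.constants.HF_HUB_OFFLINE directly rather than the environment variable because the constant is read at call time inside download_weights_from_hf. The same behavior can be reproduced with huggingface_hub alone (a sketch, not part of the commit):

# Sketch: with local_files_only=True, snapshot_download resolves the
# snapshot from the cache and raises LocalEntryNotFoundError when the
# repo was never downloaded into this cache_dir.
import tempfile

from huggingface_hub import snapshot_download
from huggingface_hub.utils import LocalEntryNotFoundError

with tempfile.TemporaryDirectory() as tmpdir:  # empty cache
    try:
        snapshot_download("facebook/opt-125m",
                          cache_dir=tmpdir,
                          allow_patterns=["*.json"],
                          local_files_only=True)
    except LocalEntryNotFoundError as err:
        print(f"offline and not cached: {err}")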


@@ -5,6 +5,7 @@ import os
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Generator, List, Optional, Tuple, Type

+import huggingface_hub
 import torch
 from torch import nn

@@ -131,7 +132,9 @@ class DefaultModelLoader(BaseModelLoader):
             model_path = snapshot_download(
                 model_id=model,
                 cache_dir=self.load_config.download_dir,
-                revision=revision)
+                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+                revision=revision,
+            )
         else:
             model_path = model
         return model_path
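With this change, the loader forwards the offline flag into snapshot_download, which then resolves the snapshot purely from the cache instead of contacting the Hub. The download-once, resolve-offline pattern looks like this with huggingface_hub directly (a sketch; the cache path and model are illustrative):

from huggingface_hub import snapshot_download

# First run, online: populates the cache and returns the snapshot path.
path = snapshot_download("facebook/opt-125m", cache_dir="/models/cache")

# Later runs, offline: the identical call is served from the cache.
path = snapshot_download("facebook/opt-125m",
                         cache_dir="/models/cache",
                         local_files_only=True)
print(path)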


@@ -127,11 +127,14 @@ def get_quant_config(model_config: ModelConfig,
     if not is_local:
         # Download the config files.
         with get_lock(model_name_or_path, load_config.download_dir):
-            hf_folder = snapshot_download(model_name_or_path,
-                                          revision=model_config.revision,
-                                          allow_patterns="*.json",
-                                          cache_dir=load_config.download_dir,
-                                          tqdm_class=DisabledTqdm)
+            hf_folder = snapshot_download(
+                model_name_or_path,
+                revision=model_config.revision,
+                allow_patterns="*.json",
+                cache_dir=load_config.download_dir,
+                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+                tqdm_class=DisabledTqdm,
+            )
     else:
         hf_folder = model_name_or_path
@@ -161,12 +164,14 @@ def get_quant_config(model_config: ModelConfig,
     return quant_cls.from_config(config)


-def download_weights_from_hf(model_name_or_path: str,
-                             cache_dir: Optional[str],
-                             allow_patterns: List[str],
-                             revision: Optional[str] = None) -> str:
+def download_weights_from_hf(
+    model_name_or_path: str,
+    cache_dir: Optional[str],
+    allow_patterns: List[str],
+    revision: Optional[str] = None,
+) -> str:
     """Download model weights from Hugging Face Hub.

     Args:
         model_name_or_path (str): The model name or path.
         cache_dir (Optional[str]): The cache directory to store the model
@@ -179,26 +184,30 @@ def download_weights_from_hf(model_name_or_path: str,
     Returns:
         str: The path to the downloaded model weights.
     """
-    # Before we download, we look at what is available:
-    fs = HfFileSystem()
-    file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
-
-    # depending on what is available we download different things
-    for pattern in allow_patterns:
-        matching = fnmatch.filter(file_list, pattern)
-        if len(matching) > 0:
-            allow_patterns = [pattern]
-            break
+    if not huggingface_hub.constants.HF_HUB_OFFLINE:
+        # Before we download, we look at what is available:
+        fs = HfFileSystem()
+        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
+
+        # depending on what is available we download different things
+        for pattern in allow_patterns:
+            matching = fnmatch.filter(file_list, pattern)
+            if len(matching) > 0:
+                allow_patterns = [pattern]
+                break

     logger.info("Using model weights format %s", allow_patterns)
     # Use file lock to prevent multiple processes from
     # downloading the same model weights at the same time.
     with get_lock(model_name_or_path, cache_dir):
-        hf_folder = snapshot_download(model_name_or_path,
-                                      allow_patterns=allow_patterns,
-                                      cache_dir=cache_dir,
-                                      tqdm_class=DisabledTqdm,
-                                      revision=revision)
+        hf_folder = snapshot_download(
+            model_name_or_path,
+            allow_patterns=allow_patterns,
+            cache_dir=cache_dir,
+            tqdm_class=DisabledTqdm,
+            revision=revision,
+            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+        )
     return hf_folder
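Skipping the HfFileSystem listing when offline matters because fs.ls always hits the network; the loop it feeds exists only to narrow allow_patterns to the first weights format actually present in the repo. That narrowing step in isolation (the file listing below is made up):

import fnmatch

# Keep only the first pattern with any match, so e.g. *.bin duplicates
# are not downloaded when safetensors weights are available.
file_list = ["opt-125m/config.json",
             "opt-125m/model.safetensors",
             "opt-125m/pytorch_model.bin"]  # made-up listing
allow_patterns = ["*.safetensors", "*.bin"]

for pattern in allow_patterns:
    if fnmatch.filter(file_list, pattern):
        allow_patterns = [pattern]
        break

assert allow_patterns == ["*.safetensors"]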


@@ -1,6 +1,7 @@
 import os
 from typing import Optional, Union

+import huggingface_hub
 from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)

@@ -76,6 +77,7 @@ def get_tokenizer(
                 model_id=tokenizer_name,
                 cache_dir=download_dir,
                 revision=revision,
+                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                 # Ignore weights - we only need the tokenizer.
                 ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"])
             tokenizer_name = tokenizer_path
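The tokenizer path gets the same flag; the snapshot_download used here accepts an ignore_file_pattern to skip weight files. With huggingface_hub's own snapshot_download the equivalent parameter is ignore_patterns (a sketch, not the code above):

import huggingface_hub
from huggingface_hub import snapshot_download

# Fetch only tokenizer/config files, honoring offline mode.
tokenizer_path = snapshot_download(
    "facebook/opt-125m",
    ignore_patterns=["*.pt", "*.safetensors", "*.bin"],
    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
)
print(tokenizer_path)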