Prevent unnecessary requests to huggingface hub (#12837)

This commit is contained in:
Maximilien de Bayser 2025-02-07 02:37:41 -03:00 committed by GitHub
parent aa375dca9f
commit 6e1fc61f0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 96 additions and 40 deletions

View File

@ -4,6 +4,7 @@ import importlib
import sys import sys
import pytest import pytest
import urllib3
from vllm import LLM from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
@ -28,6 +29,15 @@ MODEL_CONFIGS = [
"tensor_parallel_size": 1, "tensor_parallel_size": 1,
"tokenizer_mode": "mistral", "tokenizer_mode": "mistral",
}, },
{
"model": "sentence-transformers/all-MiniLM-L12-v2",
"enforce_eager": True,
"gpu_memory_utilization": 0.20,
"max_model_len": 64,
"max_num_batched_tokens": 64,
"max_num_seqs": 64,
"tensor_parallel_size": 1,
},
] ]
@ -47,6 +57,16 @@ def test_offline_mode(monkeypatch):
# Set HF to offline mode and ensure we can still construct an LLM # Set HF to offline mode and ensure we can still construct an LLM
try: try:
monkeypatch.setenv("HF_HUB_OFFLINE", "1") monkeypatch.setenv("HF_HUB_OFFLINE", "1")
monkeypatch.setenv("VLLM_NO_USAGE_STATS", "1")
def disable_connect(*args, **kwargs):
raise RuntimeError("No http calls allowed")
monkeypatch.setattr(urllib3.connection.HTTPConnection, "connect",
disable_connect)
monkeypatch.setattr(urllib3.connection.HTTPSConnection, "connect",
disable_connect)
# Need to re-import huggingface_hub and friends to setup offline mode # Need to re-import huggingface_hub and friends to setup offline mode
_re_import_modules() _re_import_modules()
# Cached model files should be used in offline mode # Cached model files should be used in offline mode
@ -56,6 +76,7 @@ def test_offline_mode(monkeypatch):
# Reset the environment after the test # Reset the environment after the test
# NB: Assuming tests are run in online mode # NB: Assuming tests are run in online mode
monkeypatch.delenv("HF_HUB_OFFLINE") monkeypatch.delenv("HF_HUB_OFFLINE")
monkeypatch.delenv("VLLM_NO_USAGE_STATS")
_re_import_modules() _re_import_modules()
pass pass

View File

@ -10,7 +10,7 @@ import huggingface_hub
from huggingface_hub import (file_exists, hf_hub_download, list_repo_files, from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
try_to_load_from_cache) try_to_load_from_cache)
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
LocalEntryNotFoundError, HFValidationError, LocalEntryNotFoundError,
RepositoryNotFoundError, RepositoryNotFoundError,
RevisionNotFoundError) RevisionNotFoundError)
from torch import nn from torch import nn
@ -265,6 +265,24 @@ def get_config(
return config return config
def try_get_local_file(model: Union[str, Path],
file_name: str,
revision: Optional[str] = 'main') -> Optional[Path]:
file_path = Path(model) / file_name
if file_path.is_file():
return file_path
else:
try:
cached_filepath = try_to_load_from_cache(repo_id=model,
filename=file_name,
revision=revision)
if isinstance(cached_filepath, str):
return Path(cached_filepath)
except HFValidationError:
...
return None
def get_hf_file_to_dict(file_name: str, def get_hf_file_to_dict(file_name: str,
model: Union[str, Path], model: Union[str, Path],
revision: Optional[str] = 'main'): revision: Optional[str] = 'main'):
@ -281,21 +299,18 @@ def get_hf_file_to_dict(file_name: str,
- config_dict (dict): A dictionary containing - config_dict (dict): A dictionary containing
the contents of the downloaded file. the contents of the downloaded file.
""" """
file_path = Path(model) / file_name
if file_or_path_exists(model=model, file_path = try_get_local_file(model=model,
config_name=file_name, file_name=file_name,
revision=revision):
if not file_path.is_file():
try:
hf_hub_file = hf_hub_download(model,
file_name,
revision=revision) revision=revision)
if file_path is None and file_or_path_exists(
model=model, config_name=file_name, revision=revision):
try:
hf_hub_file = hf_hub_download(model, file_name, revision=revision)
except (RepositoryNotFoundError, RevisionNotFoundError, except (RepositoryNotFoundError, RevisionNotFoundError,
EntryNotFoundError, LocalEntryNotFoundError) as e: EntryNotFoundError, LocalEntryNotFoundError) as e:
logger.debug("File or repository not found in hf_hub_download", logger.debug("File or repository not found in hf_hub_download", e)
e)
return None return None
except HfHubHTTPError as e: except HfHubHTTPError as e:
logger.warning( logger.warning(
@ -306,8 +321,10 @@ def get_hf_file_to_dict(file_name: str,
return None return None
file_path = Path(hf_hub_file) file_path = Path(hf_hub_file)
if file_path is not None and file_path.is_file():
with open(file_path) as file: with open(file_path) as file:
return json.load(file) return json.load(file)
return None return None
@ -328,6 +345,11 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
""" """
modules_file_name = "modules.json" modules_file_name = "modules.json"
modules_dict = None
if file_or_path_exists(model=model,
config_name=modules_file_name,
revision=revision):
modules_dict = get_hf_file_to_dict(modules_file_name, model, revision) modules_dict = get_hf_file_to_dict(modules_file_name, model, revision)
if modules_dict is None: if modules_dict is None:
@ -404,17 +426,30 @@ def get_sentence_transformer_tokenizer_config(model: str,
"sentence_xlm-roberta_config.json", "sentence_xlm-roberta_config.json",
"sentence_xlnet_config.json", "sentence_xlnet_config.json",
] ]
encoder_dict = None
for config_file in sentence_transformer_config_files:
if try_get_local_file(model=model,
file_name=config_file,
revision=revision) is not None:
encoder_dict = get_hf_file_to_dict(config_file, model, revision)
if encoder_dict:
break
if not encoder_dict:
try: try:
# If model is on HuggingfaceHub, get the repo files # If model is on HuggingfaceHub, get the repo files
repo_files = list_repo_files(model, revision=revision, token=HF_TOKEN) repo_files = list_repo_files(model,
revision=revision,
token=HF_TOKEN)
except Exception as e: except Exception as e:
logger.debug("Error getting repo files", e) logger.debug("Error getting repo files", e)
repo_files = [] repo_files = []
encoder_dict = None
for config_name in sentence_transformer_config_files: for config_name in sentence_transformer_config_files:
if config_name in repo_files or Path(model).exists(): if config_name in repo_files:
encoder_dict = get_hf_file_to_dict(config_name, model, revision) encoder_dict = get_hf_file_to_dict(config_name, model,
revision)
if encoder_dict: if encoder_dict:
break break