mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-03 14:41:58 +08:00
Filter safetensors files to download if .safetensors.index.json exists (#30537)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
96bf50a2c0
commit
100f93d2be
@ -23,6 +23,7 @@ import torch
|
||||
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
|
||||
from safetensors.torch import load, load_file, safe_open, save_file
|
||||
from tqdm.auto import tqdm
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import ModelConfig
|
||||
@ -448,12 +449,31 @@ def download_weights_from_hf(
|
||||
fs = HfFileSystem()
|
||||
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
||||
|
||||
# Use the first pattern found in the HF repo's files.
|
||||
for pattern in allow_patterns:
|
||||
matching = fnmatch.filter(file_list, pattern)
|
||||
if len(matching) > 0:
|
||||
allow_patterns = [pattern]
|
||||
break
|
||||
# If downloading safetensors and an index file exists, use the
|
||||
# specific file names from the index to avoid downloading
|
||||
# unnecessary files (e.g., from subdirectories like "original/").
|
||||
index_file = f"{model_name_or_path}/{SAFE_WEIGHTS_INDEX_NAME}"
|
||||
if "*.safetensors" in allow_patterns and index_file in file_list:
|
||||
index_path = hf_hub_download(
|
||||
repo_id=model_name_or_path,
|
||||
filename=SAFE_WEIGHTS_INDEX_NAME,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
)
|
||||
with open(index_path) as f:
|
||||
weight_map = json.load(f)["weight_map"]
|
||||
if weight_map:
|
||||
# Extra [] so that weight_map files are treated as a
|
||||
# single allow_pattern in the loop below
|
||||
allow_patterns = [list(set(weight_map.values()))] # type: ignore[list-item]
|
||||
else:
|
||||
allow_patterns = ["*.safetensors"]
|
||||
else:
|
||||
# Use the first pattern found in the HF repo's files.
|
||||
for pattern in allow_patterns:
|
||||
if fnmatch.filter(file_list, pattern):
|
||||
allow_patterns = [pattern]
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to get file list for '%s'. Trying each pattern in "
|
||||
@ -480,6 +500,9 @@ def download_weights_from_hf(
|
||||
)
|
||||
# If we have downloaded weights for this allow_pattern,
|
||||
# we don't need to check the rest.
|
||||
# allow_pattern can be a list (from weight_map) or str (glob)
|
||||
if isinstance(allow_pattern, list):
|
||||
break
|
||||
if any(Path(hf_folder).glob(allow_pattern)):
|
||||
break
|
||||
time_taken = time.perf_counter() - start_time
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user