Filter safetensors files to download if .safetensors.index.json exists (#30537)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-12-18 09:51:17 -05:00 committed by GitHub
parent 96bf50a2c0
commit 100f93d2be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -23,6 +23,7 @@ import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load, load_file, safe_open, save_file
from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm import envs
from vllm.config import ModelConfig
@ -448,12 +449,31 @@ def download_weights_from_hf(
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# Use the first pattern found in the HF repo's files.
for pattern in allow_patterns:
matching = fnmatch.filter(file_list, pattern)
if len(matching) > 0:
allow_patterns = [pattern]
break
# If downloading safetensors and an index file exists, use the
# specific file names from the index to avoid downloading
# unnecessary files (e.g., from subdirectories like "original/").
index_file = f"{model_name_or_path}/{SAFE_WEIGHTS_INDEX_NAME}"
if "*.safetensors" in allow_patterns and index_file in file_list:
index_path = hf_hub_download(
repo_id=model_name_or_path,
filename=SAFE_WEIGHTS_INDEX_NAME,
cache_dir=cache_dir,
revision=revision,
)
with open(index_path) as f:
weight_map = json.load(f)["weight_map"]
if weight_map:
# Extra [] so that weight_map files are treated as a
# single allow_pattern in the loop below
allow_patterns = [list(set(weight_map.values()))] # type: ignore[list-item]
else:
allow_patterns = ["*.safetensors"]
else:
# Use the first pattern found in the HF repo's files.
for pattern in allow_patterns:
if fnmatch.filter(file_list, pattern):
allow_patterns = [pattern]
break
except Exception as e:
logger.warning(
"Failed to get file list for '%s'. Trying each pattern in "
@ -480,6 +500,9 @@ def download_weights_from_hf(
)
# If we have downloaded weights for this allow_pattern,
# we don't need to check the rest.
# allow_pattern can be a list (from weight_map) or str (glob)
if isinstance(allow_pattern, list):
break
if any(Path(hf_folder).glob(allow_pattern)):
break
time_taken = time.perf_counter() - start_time