diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 610e6a620ade2..0c5961561a7d9 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -23,6 +23,7 @@ import torch from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from safetensors.torch import load, load_file, safe_open, save_file from tqdm.auto import tqdm +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm import envs from vllm.config import ModelConfig @@ -448,12 +449,31 @@ def download_weights_from_hf( fs = HfFileSystem() file_list = fs.ls(model_name_or_path, detail=False, revision=revision) - # Use the first pattern found in the HF repo's files. - for pattern in allow_patterns: - matching = fnmatch.filter(file_list, pattern) - if len(matching) > 0: - allow_patterns = [pattern] - break + # If downloading safetensors and an index file exists, use the + # specific file names from the index to avoid downloading + # unnecessary files (e.g., from subdirectories like "original/"). + index_file = f"{model_name_or_path}/{SAFE_WEIGHTS_INDEX_NAME}" + if "*.safetensors" in allow_patterns and index_file in file_list: + index_path = hf_hub_download( + repo_id=model_name_or_path, + filename=SAFE_WEIGHTS_INDEX_NAME, + cache_dir=cache_dir, + revision=revision, + ) + with open(index_path) as f: + weight_map = json.load(f)["weight_map"] + if weight_map: + # Extra [] so that weight_map files are treated as a + # single allow_pattern in the loop below + allow_patterns = [list(set(weight_map.values()))] # type: ignore[list-item] + else: + allow_patterns = ["*.safetensors"] + else: + # Use the first pattern found in the HF repo's files. + for pattern in allow_patterns: + if fnmatch.filter(file_list, pattern): + allow_patterns = [pattern] + break except Exception as e: logger.warning( "Failed to get file list for '%s'. Trying each pattern in " @@ -480,6 +500,9 @@ def download_weights_from_hf( ) # If we have downloaded weights for this allow_pattern, # we don't need to check the rest. + # allow_pattern can be a list (from weight_map) or str (glob) + if isinstance(allow_pattern, list): + break if any(Path(hf_folder).glob(allow_pattern)): break time_taken = time.perf_counter() - start_time