mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 00:44:57 +08:00
Simplify weight loading logic (#2133)
This commit is contained in:
parent
2acd76f346
commit
eed74a558f
@ -122,15 +122,10 @@ class ModelConfig:
|
||||
|
||||
# TODO: Remove this check once HF updates the pt weights of Mixtral.
|
||||
architectures = getattr(self.hf_config, "architectures", [])
|
||||
if "MixtralForCausalLM" in architectures:
|
||||
if load_format == "pt":
|
||||
raise ValueError(
|
||||
"Currently, the 'pt' format is not supported for Mixtral. "
|
||||
"Please use the 'safetensors' format instead. ")
|
||||
elif load_format == "auto":
|
||||
# Do not fall back to pt weights.
|
||||
load_format = "safetensors"
|
||||
|
||||
if "MixtralForCausalLM" in architectures and load_format == "pt":
|
||||
raise ValueError(
|
||||
"Currently, the 'pt' format is not supported for Mixtral. "
|
||||
"Please use the 'safetensors' format instead. ")
|
||||
self.load_format = load_format
|
||||
|
||||
def _verify_tokenizer_mode(self) -> None:
|
||||
|
||||
@ -412,7 +412,11 @@ class MixtralForCausalLM(nn.Module):
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
model_name_or_path,
|
||||
cache_dir,
|
||||
load_format,
|
||||
revision,
|
||||
fall_back_to_pt=False):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
|
||||
@ -125,15 +125,29 @@ def get_quant_config(
|
||||
def prepare_hf_model_weights(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_safetensors: bool = False,
|
||||
load_format: str = "auto",
|
||||
fall_back_to_pt: bool = True,
|
||||
revision: Optional[str] = None,
|
||||
) -> Tuple[str, List[str], bool]:
|
||||
# Download model weights from huggingface.
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
use_safetensors = False
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
allow_patterns = ["*.safetensors"
|
||||
] if use_safetensors else ["*.bin", "*.pt"]
|
||||
if load_format == "auto":
|
||||
allow_patterns = ["*.safetensors", "*.bin"]
|
||||
elif load_format == "safetensors":
|
||||
use_safetensors = True
|
||||
allow_patterns = ["*.safetensors"]
|
||||
elif load_format == "pt":
|
||||
allow_patterns = ["*.pt"]
|
||||
elif load_format == "npcache":
|
||||
allow_patterns = ["*.bin"]
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
|
||||
if fall_back_to_pt:
|
||||
allow_patterns += [".pt"]
|
||||
|
||||
if not is_local:
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
@ -148,6 +162,10 @@ def prepare_hf_model_weights(
|
||||
hf_weights_files: List[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
if len(hf_weights_files) > 0:
|
||||
if pattern == "*.safetensors":
|
||||
use_safetensors = True
|
||||
break
|
||||
if not use_safetensors:
|
||||
# Exclude files that are not needed for inference.
|
||||
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
|
||||
@ -163,13 +181,6 @@ def prepare_hf_model_weights(
|
||||
if not any(f.endswith(x) for x in blacklist)
|
||||
]
|
||||
|
||||
if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
|
||||
return prepare_hf_model_weights(model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
use_safetensors=False,
|
||||
fall_back_to_pt=False,
|
||||
revision=revision)
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{model_name_or_path}`")
|
||||
@ -182,30 +193,16 @@ def hf_model_weights_iterator(
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
fall_back_to_pt: Optional[bool] = True,
|
||||
) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||
use_safetensors = False
|
||||
use_np_cache = False
|
||||
fall_back_to_pt = False
|
||||
if load_format == "auto":
|
||||
use_safetensors = True
|
||||
fall_back_to_pt = True
|
||||
elif load_format == "safetensors":
|
||||
use_safetensors = True
|
||||
elif load_format == "pt":
|
||||
pass
|
||||
elif load_format == "npcache":
|
||||
use_np_cache = True
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
|
||||
hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
use_safetensors=use_safetensors,
|
||||
load_format=load_format,
|
||||
fall_back_to_pt=fall_back_to_pt,
|
||||
revision=revision)
|
||||
|
||||
if use_np_cache:
|
||||
if load_format == "npcache":
|
||||
# Currently np_cache only support *.bin checkpoints
|
||||
assert use_safetensors is False
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user