Simplify weight loading logic (#2133)

parent 2acd76f346
commit eed74a558f
vllm/config.py
@@ -122,15 +122,10 @@ class ModelConfig:
         # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
-        if "MixtralForCausalLM" in architectures:
-            if load_format == "pt":
-                raise ValueError(
-                    "Currently, the 'pt' format is not supported for Mixtral. "
-                    "Please use the 'safetensors' format instead. ")
-            elif load_format == "auto":
-                # Do not fall back to pt weights.
-                load_format = "safetensors"
+        if "MixtralForCausalLM" in architectures and load_format == "pt":
+            raise ValueError(
+                "Currently, the 'pt' format is not supported for Mixtral. "
+                "Please use the 'safetensors' format instead. ")
 
         self.load_format = load_format
 
     def _verify_tokenizer_mode(self) -> None:
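Why the "auto" branch disappears: the format decision for Mixtral no longer lives in ModelConfig; the model's weight loader simply opts out of the .pt fallback (see the next hunk). A minimal standalone sketch of the new check, using a made-up stand-in for hf_config:

class FakeHFConfig:
    architectures = ["MixtralForCausalLM"]

def verify_load_format(hf_config, load_format: str) -> str:
    # Reject an explicit request for .pt weights on Mixtral; "auto" now
    # passes through untouched because the loader avoids .pt by itself.
    architectures = getattr(hf_config, "architectures", [])
    if "MixtralForCausalLM" in architectures and load_format == "pt":
        raise ValueError(
            "Currently, the 'pt' format is not supported for Mixtral. "
            "Please use the 'safetensors' format instead. ")
    return load_format

assert verify_load_format(FakeHFConfig(), "auto") == "auto"
try:
    verify_load_format(FakeHFConfig(), "pt")
except ValueError:
    pass  # explicit 'pt' is still rejected for Mixtral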
vllm/model_executor/models/mixtral.py
@@ -412,7 +412,11 @@ class MixtralForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
+                model_name_or_path,
+                cache_dir,
+                load_format,
+                revision,
+                fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
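The new keyword argument is the heart of the commit: a model whose .pt checkpoints are known to be broken (here Mixtral) opts out of the .pt fallback at the call site instead of patching load_format in the config. A hedged usage sketch (the model path is illustrative; the import matches vLLM's layout at the time of this commit):

from vllm.model_executor.weight_utils import hf_model_weights_iterator

for name, loaded_weight in hf_model_weights_iterator(
        "mistralai/Mixtral-8x7B-v0.1",  # illustrative model path
        None,    # cache_dir
        "auto",  # load_format
        None,    # revision
        fall_back_to_pt=False):  # never glob *.pt for this model
    ...  # copy loaded_weight into the parameter named `name`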
vllm/model_executor/weight_utils.py
@@ -125,15 +125,29 @@ def get_quant_config(
 def prepare_hf_model_weights(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
-    use_safetensors: bool = False,
+    load_format: str = "auto",
     fall_back_to_pt: bool = True,
     revision: Optional[str] = None,
 ) -> Tuple[str, List[str], bool]:
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
+    use_safetensors = False
     # Some quantized models use .pt files for storing the weights.
-    allow_patterns = ["*.safetensors"
-                      ] if use_safetensors else ["*.bin", "*.pt"]
+    if load_format == "auto":
+        allow_patterns = ["*.safetensors", "*.bin"]
+    elif load_format == "safetensors":
+        use_safetensors = True
+        allow_patterns = ["*.safetensors"]
+    elif load_format == "pt":
+        allow_patterns = ["*.pt"]
+    elif load_format == "npcache":
+        allow_patterns = ["*.bin"]
+    else:
+        raise ValueError(f"Unknown load_format: {load_format}")
+
+    if fall_back_to_pt:
+        allow_patterns += ["*.pt"]
+
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
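The dispatch above replaces a boolean parameter with the load_format string, so prepare_hf_model_weights now owns the format policy. A self-contained sketch of just the pattern-selection logic (transcribed from the hunk, not imported from vLLM):

from typing import List, Tuple

def select_patterns(load_format: str,
                    fall_back_to_pt: bool) -> Tuple[List[str], bool]:
    use_safetensors = False
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        use_safetensors = True
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npcache":
        allow_patterns = ["*.bin"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")
    if fall_back_to_pt:
        allow_patterns += ["*.pt"]
    return allow_patterns, use_safetensors

# "auto" with the fallback enabled tries safetensors, then .bin, then .pt;
# use_safetensors stays False here and is decided later by which pattern
# actually matches files (next hunk):
assert select_patterns("auto", True) == (["*.safetensors", "*.bin", "*.pt"], False)
# An explicit "safetensors" request with the fallback disabled never falls back:
assert select_patterns("safetensors", False) == (["*.safetensors"], True)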
@@ -148,6 +162,10 @@ def prepare_hf_model_weights(
     hf_weights_files: List[str] = []
     for pattern in allow_patterns:
         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+        if len(hf_weights_files) > 0:
+            if pattern == "*.safetensors":
+                use_safetensors = True
+            break
     if not use_safetensors:
         # Exclude files that are not needed for inference.
         # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
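First-match-wins: patterns are globbed in priority order and the loop stops at the first pattern that finds files, setting use_safetensors when that pattern was *.safetensors. A small sketch of that semantics using fnmatch over an in-memory file list (file names made up):

import fnmatch
from typing import List, Tuple

def pick_format(files: List[str],
                allow_patterns: List[str]) -> Tuple[List[str], bool]:
    hf_weights_files: List[str] = []
    use_safetensors = False
    for pattern in allow_patterns:
        hf_weights_files += [f for f in files if fnmatch.fnmatch(f, pattern)]
        if len(hf_weights_files) > 0:
            if pattern == "*.safetensors":
                use_safetensors = True
            break  # earlier patterns take priority; stop at the first hit
    return hf_weights_files, use_safetensors

# When both formats are present, "auto" now prefers safetensors and never
# even globs the lower-priority .bin/.pt patterns:
files = ["model-00001.safetensors", "pytorch_model.bin"]
assert pick_format(files, ["*.safetensors", "*.bin", "*.pt"]) == (
    ["model-00001.safetensors"], True)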
@@ -163,13 +181,6 @@ def prepare_hf_model_weights(
         if not any(f.endswith(x) for x in blacklist)
     ]
 
-    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
-        return prepare_hf_model_weights(model_name_or_path,
-                                        cache_dir=cache_dir,
-                                        use_safetensors=False,
-                                        fall_back_to_pt=False,
-                                        revision=revision)
-
    if len(hf_weights_files) == 0:
         raise RuntimeError(
             f"Cannot find any model weights with `{model_name_or_path}`")
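With the ordered allow_patterns scan above, this recursive retry becomes dead code: falling back from safetensors to .pt is now expressed as a single pass over ["*.safetensors", "*.bin", "*.pt"], so the function never needs to call itself a second time with use_safetensors=False.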
@@ -182,30 +193,16 @@ def hf_model_weights_iterator(
     cache_dir: Optional[str] = None,
     load_format: str = "auto",
     revision: Optional[str] = None,
+    fall_back_to_pt: Optional[bool] = True,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
-    use_safetensors = False
-    use_np_cache = False
-    fall_back_to_pt = False
-    if load_format == "auto":
-        use_safetensors = True
-        fall_back_to_pt = True
-    elif load_format == "safetensors":
-        use_safetensors = True
-    elif load_format == "pt":
-        pass
-    elif load_format == "npcache":
-        use_np_cache = True
-    else:
-        raise ValueError(f"Unknown load_format: {load_format}")
-
     hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
         model_name_or_path,
         cache_dir=cache_dir,
-        use_safetensors=use_safetensors,
+        load_format=load_format,
         fall_back_to_pt=fall_back_to_pt,
         revision=revision)
 
-    if use_np_cache:
+    if load_format == "npcache":
         # Currently np_cache only support *.bin checkpoints
         assert use_safetensors is False
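The iterator no longer recomputes flags from load_format; everything is delegated to prepare_hf_model_weights, and fall_back_to_pt becomes a caller decision. A hedged sketch (not vLLM code) transcribing the deleted flag table to check the one flag the iterator still needs:

def old_flags(load_format: str):
    # Flag computation deleted by this hunk, transcribed verbatim.
    use_safetensors, use_np_cache, fall_back_to_pt = False, False, False
    if load_format == "auto":
        use_safetensors, fall_back_to_pt = True, True
    elif load_format == "safetensors":
        use_safetensors = True
    elif load_format == "pt":
        pass
    elif load_format == "npcache":
        use_np_cache = True
    else:
        raise ValueError(f"Unknown load_format: {load_format}")
    return use_safetensors, use_np_cache, fall_back_to_pt

# The only flag the iterator still branches on is the numpy cache, and
# the new test `load_format == "npcache"` reproduces it exactly:
for fmt in ("auto", "safetensors", "pt", "npcache"):
    assert old_flags(fmt)[1] == (fmt == "npcache")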