Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:45:00 +08:00)
[NVIDIA] Auto detect modelopt quant and fix DSR1-FP4 weight loading (#22073)
commit bdcb42e45d (parent c09efff976)
@@ -1108,6 +1108,21 @@ class ModelConfig:
         if quant_cfg is None:
             # compressed-tensors uses a "compression_config" key
             quant_cfg = getattr(self.hf_config, "compression_config", None)
+        else:
+            # Set quant_method for ModelOpt models.
+            producer_name = quant_cfg.get("producer", {}).get("name")
+            if producer_name == "modelopt":
+                quant_algo = quant_cfg.get("quantization",
+                                           {}).get("quant_algo")
+                if quant_algo == "FP8":
+                    quant_cfg["quant_method"] = "modelopt"
+                elif quant_algo == "NVFP4":
+                    quant_cfg["quant_method"] = "modelopt_fp4"
+                elif quant_algo is not None:
+                    raise ValueError(
+                        f"Unknown ModelOpt quant algo: {quant_algo}")
 
         return quant_cfg
 
     def _verify_quantization(self) -> None:
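For reference, a minimal standalone sketch of the detection added above: given an hf_quant_config-style dict, it maps ModelOpt's FP8/NVFP4 quant_algo onto the matching vLLM quant_method. The helper name detect_modelopt_quant_method and the example dict are illustrative, not part of vLLM's API.

# Sketch of the ModelOpt quant_method detection above; helper name is made up.
def detect_modelopt_quant_method(quant_cfg: dict) -> dict:
    producer_name = quant_cfg.get("producer", {}).get("name")
    if producer_name == "modelopt":
        quant_algo = quant_cfg.get("quantization", {}).get("quant_algo")
        if quant_algo == "FP8":
            quant_cfg["quant_method"] = "modelopt"
        elif quant_algo == "NVFP4":
            quant_cfg["quant_method"] = "modelopt_fp4"
        elif quant_algo is not None:
            raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}")
    return quant_cfg

# Example: an hf_quant_config.json-style payload produced by ModelOpt.
cfg = {"producer": {"name": "modelopt"}, "quantization": {"quant_algo": "NVFP4"}}
assert detect_modelopt_quant_method(cfg)["quant_method"] == "modelopt_fp4"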
@@ -919,9 +919,13 @@ class FusedMoE(torch.nn.Module):
         elif shard_id == "w2":
             param_data[expert_id] = loaded_weight
 
-    def _load_w13_weight_scale(self, shard_dim: int,
-                               loaded_weight: torch.Tensor,
-                               param: torch.Tensor, tp_rank: int):
+    def _load_combined_w13_weight_scale(self, shard_dim: int,
+                                        loaded_weight: torch.Tensor,
+                                        param: torch.Tensor, tp_rank: int):
+        """
+        Load w13 weight scales assuming that w1 weight scales and w3 weight
+        scales are stored in the same loaded_weight tensor.
+        """
         shard_size = param.shape[shard_dim]
         loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                              shard_size)
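As a rough illustration of the narrow()-based sharding used by the renamed helper, assuming made-up shapes and a tensor-parallel size of 2: each rank takes one contiguous slice of the combined w1/w3 scale tensor along shard_dim.

# Illustrative only: shapes and tp_size are invented for this sketch.
import torch

tp_size, shard_dim = 2, 1
combined = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)  # full w13 scales
shard_size = combined.shape[shard_dim] // tp_size                  # per-rank width

for tp_rank in range(tp_size):
    # Same slicing pattern as loaded_weight.narrow(...) in the hunk above.
    shard = combined.narrow(shard_dim, shard_size * tp_rank, shard_size)
    assert shard.shape[shard_dim] == shard_size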
@@ -1168,24 +1172,43 @@ class FusedMoE(torch.nn.Module):
             uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern(
             )
 
-            # For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale"
-            per_tensor_conditions = (
-                "weight_scale_2" in weight_name if uses_weight_scale_2 else
-                "weight_scale" in weight_name) or "input_scale" in weight_name
-
-            if "w13_weight_scale" in weight_name:
-                self._load_w13_weight_scale(shard_dim=shard_dim,
-                                            loaded_weight=loaded_weight,
-                                            param=param,
-                                            tp_rank=self.tp_rank)
-            elif per_tensor_conditions:
+            # Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
+            # weights scales.
+            # Input scales are always per-tensor.
+            # Weight scales: FP4 uses "weight_scale_2" and FP8 uses
+            # "weight_scale" for per-tensor scales.
+            is_per_tensor = ("weight_scale_2" in weight_name
+                             if uses_weight_scale_2 else "weight_scale"
+                             in weight_name) or "input_scale" in weight_name
+            if is_per_tensor:
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
                     loaded_weight=loaded_weight,
                     expert_id=expert_id,
                 )
-            elif "weight" in weight_name:
+                return True if return_success else None
+
+            # If the weight is w13_weight_scale and w13_weight_scales are
+            # combined into single loaded_weight, call
+            # _load_combined_w13_weight_scale() to load it.
+            # This is checked by comparing the hidden_out dims of the
+            # loaded_weight and the param.
+            if "w13_weight_scale" in weight_name:
+                loaded_weight_hidden_out = loaded_weight.shape[-2]
+                param_hidden_out = param.data.shape[-2] * self.tp_size
+                if loaded_weight_hidden_out == param_hidden_out:
+                    self._load_combined_w13_weight_scale(
+                        shard_dim=shard_dim,
+                        loaded_weight=loaded_weight,
+                        param=param,
+                        tp_rank=self.tp_rank,
+                    )
+                    return True if return_success else None
+
+            # For other weights, call _load_model_weight_or_group_weight_scale()
+            # to load it.
+            if "weight" in weight_name:
                 self._load_model_weight_or_group_weight_scale(
                     shard_id=shard_id,
                     shard_dim=shard_dim,
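A small sketch of the shape check introduced above, with illustrative shapes: a checkpoint is treated as storing combined w1/w3 scales when the loaded tensor's hidden_out dimension equals the sharded param's hidden_out multiplied by tp_size.

# Illustrative shapes only; tp_size and the tensor sizes are invented.
import torch

tp_size = 2
# This rank's param shard: (num_experts, combined_hidden_out // tp_size, blocks)
param = torch.empty(8, 4, 16)
# Checkpoint tensor holding w1 and w3 scales combined for the full model.
loaded_weight = torch.empty(8, 8, 16)

loaded_weight_hidden_out = loaded_weight.shape[-2]
param_hidden_out = param.shape[-2] * tp_size
is_combined_w13 = loaded_weight_hidden_out == param_hidden_out
assert is_combined_w13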
@@ -449,6 +449,20 @@ def get_config(
             model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
             config.update({"architectures": [model_type]})
 
+    # ModelOpt 0.31.0 and after saves the quantization config in the model
+    # config file.
+    quantization_config = config_dict.get("quantization_config", None)
+
+    # ModelOpt 0.29.0 and before saves the quantization config in a separate
+    # "hf_quant_config.json" in the same directory as the model config file.
+    if quantization_config is None \
+            and file_or_path_exists(model, "hf_quant_config.json", revision):
+        quantization_config = get_hf_file_to_dict("hf_quant_config.json",
+                                                  model, revision)
+
+    if quantization_config is not None:
+        config.quantization_config = quantization_config
+
     if hf_overrides_kw:
         logger.debug("Overriding HF config with %s", hf_overrides_kw)
         config.update(hf_overrides_kw)
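A minimal sketch of the fallback order added above, assuming a local checkpoint directory and plain json/os instead of vLLM's hub-aware helpers: prefer the quantization_config embedded in config.json (ModelOpt 0.31.0 and after), otherwise read the sidecar hf_quant_config.json (ModelOpt 0.29.0 and before).

# Sketch only: load_quantization_config is an illustrative helper, and it
# assumes the checkpoint lives in a local directory on disk.
import json
import os

def load_quantization_config(model_dir: str):
    with open(os.path.join(model_dir, "config.json")) as f:
        config_dict = json.load(f)
    # Newer ModelOpt exports embed the quantization config in config.json.
    quantization_config = config_dict.get("quantization_config", None)

    # Older ModelOpt exports ship a separate hf_quant_config.json alongside it.
    sidecar = os.path.join(model_dir, "hf_quant_config.json")
    if quantization_config is None and os.path.exists(sidecar):
        with open(sidecar) as f:
            quantization_config = json.load(f)
    return quantization_config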