[NVIDIA] Auto detect modelopt quant and fix DSR1-FP4 weight loading (#22073)
This commit is contained in:
parent c09efff976
commit bdcb42e45d
@@ -1108,6 +1108,21 @@ class ModelConfig:
         if quant_cfg is None:
             # compressed-tensors uses a "compression_config" key
             quant_cfg = getattr(self.hf_config, "compression_config", None)
+
+        else:
+            # Set quant_method for ModelOpt models.
+            producer_name = quant_cfg.get("producer", {}).get("name")
+            if producer_name == "modelopt":
+                quant_algo = quant_cfg.get("quantization",
+                                           {}).get("quant_algo")
+                if quant_algo == "FP8":
+                    quant_cfg["quant_method"] = "modelopt"
+                elif quant_algo == "NVFP4":
+                    quant_cfg["quant_method"] = "modelopt_fp4"
+                elif quant_algo is not None:
+                    raise ValueError(
+                        f"Unknown ModelOpt quant algo: {quant_algo}")
+
         return quant_cfg

     def _verify_quantization(self) -> None:
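As a reading aid, here is a minimal standalone sketch of the detection added in this hunk, run against an illustrative quantization config dict (the sample producer/version/algo values are hypothetical, not taken from a real checkpoint):

# Sketch: how the producer name and quant_algo map to a vLLM quant_method.
def detect_modelopt_quant_method(quant_cfg: dict) -> dict:
    producer_name = quant_cfg.get("producer", {}).get("name")
    if producer_name == "modelopt":
        quant_algo = quant_cfg.get("quantization", {}).get("quant_algo")
        if quant_algo == "FP8":
            quant_cfg["quant_method"] = "modelopt"
        elif quant_algo == "NVFP4":
            quant_cfg["quant_method"] = "modelopt_fp4"
        elif quant_algo is not None:
            raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}")
    return quant_cfg


# Illustrative config shaped like a ModelOpt-produced quantization_config.
sample_cfg = {
    "producer": {"name": "modelopt", "version": "0.31.0"},
    "quantization": {"quant_algo": "NVFP4"},
}
print(detect_modelopt_quant_method(sample_cfg)["quant_method"])  # modelopt_fp4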
@@ -919,9 +919,13 @@ class FusedMoE(torch.nn.Module):
         elif shard_id == "w2":
             param_data[expert_id] = loaded_weight

-    def _load_w13_weight_scale(self, shard_dim: int,
-                               loaded_weight: torch.Tensor,
-                               param: torch.Tensor, tp_rank: int):
+    def _load_combined_w13_weight_scale(self, shard_dim: int,
+                                        loaded_weight: torch.Tensor,
+                                        param: torch.Tensor, tp_rank: int):
+        """
+        Load w13 weight scales assuming that w1 weight scales and w3 weight
+        scales are stored in the same loaded_weight tensor.
+        """
         shard_size = param.shape[shard_dim]
         loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                              shard_size)
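The renamed helper's core operation is a Tensor.narrow() slice that picks the current tensor-parallel rank's shard out of the combined w1/w3 scale tensor. A small sketch of that slicing with made-up shapes (the sizes and the dim-1 layout here are assumptions for illustration only):

import torch

tp_size, tp_rank = 2, 1
num_experts, hidden_out, groups = 8, 64, 4  # made-up shapes
shard_dim = 1

# Checkpoint-side tensor: w1 and w3 scales stored in one combined tensor.
combined_w13_scale = torch.randn(num_experts, 2 * hidden_out, groups)

# Rank-local parameter covers 1/tp_size of the combined dimension.
param = torch.empty(num_experts, 2 * hidden_out // tp_size, groups)

# Same arithmetic as _load_combined_w13_weight_scale(): offset by
# shard_size * tp_rank, take shard_size elements along shard_dim.
shard_size = param.shape[shard_dim]
shard = combined_w13_scale.narrow(shard_dim, shard_size * tp_rank, shard_size)
param.copy_(shard)
print(shard.shape)  # torch.Size([8, 64, 4])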
@@ -1168,24 +1172,43 @@ class FusedMoE(torch.nn.Module):
             uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern(
             )

-            # For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale"
-            per_tensor_conditions = (
-                "weight_scale_2" in weight_name if uses_weight_scale_2 else
-                "weight_scale" in weight_name) or "input_scale" in weight_name
-
-            if "w13_weight_scale" in weight_name:
-                self._load_w13_weight_scale(shard_dim=shard_dim,
-                                            loaded_weight=loaded_weight,
-                                            param=param,
-                                            tp_rank=self.tp_rank)
-            elif per_tensor_conditions:
+            # Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
+            # weights scales.
+            # Input scales are always per-tensor.
+            # Weight scales: FP4 uses "weight_scale_2" and FP8 uses
+            # "weight_scale" for per-tensor scales.
+            is_per_tensor = ("weight_scale_2" in weight_name
+                             if uses_weight_scale_2 else "weight_scale"
+                             in weight_name) or "input_scale" in weight_name
+            if is_per_tensor:
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
                     loaded_weight=loaded_weight,
                     expert_id=expert_id,
                 )
-            elif "weight" in weight_name:
+                return True if return_success else None
+
+            # If the weight is w13_weight_scale and w13_weight_scales are
+            # combined into single loaded_weight, call
+            # _load_combined_w13_weight_scale() to load it.
+            # This is checked by comparing the hidden_out dims of the
+            # loaded_weight and the param.
+            if "w13_weight_scale" in weight_name:
+                loaded_weight_hidden_out = loaded_weight.shape[-2]
+                param_hidden_out = param.data.shape[-2] * self.tp_size
+                if loaded_weight_hidden_out == param_hidden_out:
+                    self._load_combined_w13_weight_scale(
+                        shard_dim=shard_dim,
+                        loaded_weight=loaded_weight,
+                        param=param,
+                        tp_rank=self.tp_rank,
+                    )
+                    return True if return_success else None
+
+            # For other weights, call _load_model_weight_or_group_weight_scale()
+            # to load it.
+            if "weight" in weight_name:
                 self._load_model_weight_or_group_weight_scale(
                     shard_id=shard_id,
                     shard_dim=shard_dim,
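The new w13_weight_scale branch decides whether the checkpoint tensor is the combined, unsharded w1+w3 scale by comparing hidden-out dimensions. A toy version of that check with hypothetical shapes:

import torch

tp_size = 4

# Hypothetical rank-local w13 scale parameter; dim -2 holds 1/tp_size of the
# full combined hidden-out dimension.
param = torch.empty(32, 512 // tp_size, 16)


def takes_combined_path(loaded_weight: torch.Tensor) -> bool:
    # Mirrors the comparison in the hunk above.
    loaded_weight_hidden_out = loaded_weight.shape[-2]
    param_hidden_out = param.shape[-2] * tp_size
    return loaded_weight_hidden_out == param_hidden_out


print(takes_combined_path(torch.empty(32, 512, 16)))  # True: combined loader
print(takes_combined_path(torch.empty(32, 128, 16)))  # False: falls through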
@@ -449,6 +449,20 @@ def get_config(
             model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
             config.update({"architectures": [model_type]})

+    # ModelOpt 0.31.0 and after saves the quantization config in the model
+    # config file.
+    quantization_config = config_dict.get("quantization_config", None)
+
+    # ModelOpt 0.29.0 and before saves the quantization config in a separate
+    # "hf_quant_config.json" in the same directory as the model config file.
+    if quantization_config is None \
+        and file_or_path_exists(model, "hf_quant_config.json", revision):
+        quantization_config = get_hf_file_to_dict("hf_quant_config.json",
+                                                   model, revision)
+
+    if quantization_config is not None:
+        config.quantization_config = quantization_config
+
     if hf_overrides_kw:
         logger.debug("Overriding HF config with %s", hf_overrides_kw)
         config.update(hf_overrides_kw)
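Outside vLLM, the same two-step lookup can be reproduced for a locally downloaded checkpoint with plain transformers and json. The directory path below is hypothetical, and the repo helpers file_or_path_exists / get_hf_file_to_dict are replaced by direct file access, so this is only a sketch of the behavior, not the function used in the diff:

import json
from pathlib import Path

from transformers import AutoConfig

model_dir = Path("/models/DeepSeek-R1-FP4")  # hypothetical local checkpoint
config = AutoConfig.from_pretrained(model_dir)

# ModelOpt 0.31.0 and after: quantization config lives in config.json.
quantization_config = getattr(config, "quantization_config", None)

# ModelOpt 0.29.0 and before: quantization config lives in a separate
# hf_quant_config.json next to the model config file.
hf_quant_file = model_dir / "hf_quant_config.json"
if quantization_config is None and hf_quant_file.is_file():
    quantization_config = json.loads(hf_quant_file.read_text())

if quantization_config is not None:
    config.quantization_config = quantization_config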