[NVIDIA] Auto detect modelopt quant and fix DSR1-FP4 weight loading (#22073)

Author: Po-Han Huang (NVIDIA), 2025-08-05 09:02:55 +08:00 (committed by GitHub)
Parent: c09efff976
Commit: bdcb42e45d
3 changed files with 67 additions and 15 deletions


@@ -1108,6 +1108,21 @@ class ModelConfig:
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        else:
            # Set quant_method for ModelOpt models.
            producer_name = quant_cfg.get("producer", {}).get("name")
            if producer_name == "modelopt":
                quant_algo = quant_cfg.get("quantization",
                                           {}).get("quant_algo")
                if quant_algo == "FP8":
                    quant_cfg["quant_method"] = "modelopt"
                elif quant_algo == "NVFP4":
                    quant_cfg["quant_method"] = "modelopt_fp4"
                elif quant_algo is not None:
                    raise ValueError(
                        f"Unknown ModelOpt quant algo: {quant_algo}")

        return quant_cfg

    def _verify_quantization(self) -> None:
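
For reference, a minimal standalone sketch (not the vLLM implementation) of the detection rule this hunk adds. Only the "producer" and "quantization" keys read above are assumed; the example config values are illustrative, not taken from a real checkpoint.

# Sketch: map a ModelOpt quantization config dict to a quant_method string.
from typing import Any, Optional


def detect_modelopt_quant_method(quant_cfg: dict[str, Any]) -> Optional[str]:
    # Only configs produced by ModelOpt are handled here.
    if quant_cfg.get("producer", {}).get("name") != "modelopt":
        return None
    quant_algo = quant_cfg.get("quantization", {}).get("quant_algo")
    if quant_algo == "FP8":
        return "modelopt"
    if quant_algo == "NVFP4":
        return "modelopt_fp4"
    if quant_algo is not None:
        raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}")
    return None


# Illustrative input (hypothetical values).
assert detect_modelopt_quant_method(
    {"producer": {"name": "modelopt"},
     "quantization": {"quant_algo": "NVFP4"}}) == "modelopt_fp4"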


@@ -919,9 +919,13 @@ class FusedMoE(torch.nn.Module):
        elif shard_id == "w2":
            param_data[expert_id] = loaded_weight

    def _load_w13_weight_scale(self, shard_dim: int,
                               loaded_weight: torch.Tensor,
                               param: torch.Tensor, tp_rank: int):
    def _load_combined_w13_weight_scale(self, shard_dim: int,
                                        loaded_weight: torch.Tensor,
                                        param: torch.Tensor, tp_rank: int):
        """
        Load w13 weight scales assuming that w1 weight scales and w3 weight
        scales are stored in the same loaded_weight tensor.
        """
        shard_size = param.shape[shard_dim]
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
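
A rough standalone sketch of what this narrow() does, under the assumption (hypothetical shapes, illustrative only) that the combined w1+w3 scale tensor is laid out as (num_experts, hidden_out, num_blocks) and sharded along hidden_out across tensor-parallel ranks.

# Sketch: split a combined w1+w3 scale tensor across TP ranks.
# Shapes are hypothetical; only the narrow() call mirrors the code above.
import torch

tp_size = 2
num_experts, combined_hidden_out, num_blocks = 4, 8, 3
loaded_weight = torch.randn(num_experts, combined_hidden_out, num_blocks)

shard_dim = 1
shard_size = combined_hidden_out // tp_size  # stands in for param.shape[shard_dim]

for tp_rank in range(tp_size):
    # Each rank takes its contiguous slice of the combined scale tensor.
    shard = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
    assert shard.shape == (num_experts, shard_size, num_blocks)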

@@ -1168,24 +1172,43 @@ class FusedMoE(torch.nn.Module):
            uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern(
            )

            # For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale"
            per_tensor_conditions = (
                "weight_scale_2" in weight_name if uses_weight_scale_2 else
                "weight_scale" in weight_name) or "input_scale" in weight_name

            if "w13_weight_scale" in weight_name:
                self._load_w13_weight_scale(shard_dim=shard_dim,
                                            loaded_weight=loaded_weight,
                                            param=param,
                                            tp_rank=self.tp_rank)
            elif per_tensor_conditions:
            # Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
            # weights scales.
            # Input scales are always per-tensor.
            # Weight scales: FP4 uses "weight_scale_2" and FP8 uses
            # "weight_scale" for per-tensor scales.
            is_per_tensor = ("weight_scale_2" in weight_name
                             if uses_weight_scale_2 else "weight_scale"
                             in weight_name) or "input_scale" in weight_name
            if is_per_tensor:
                self._load_per_tensor_weight_scale(
                    shard_id=shard_id,
                    param=param,
                    loaded_weight=loaded_weight,
                    expert_id=expert_id,
                )
            elif "weight" in weight_name:
                return True if return_success else None

            # If the weight is w13_weight_scale and w13_weight_scales are
            # combined into single loaded_weight, call
            # _load_combined_w13_weight_scale() to load it.
            # This is checked by comparing the hidden_out dims of the
            # loaded_weight and the param.
            if "w13_weight_scale" in weight_name:
                loaded_weight_hidden_out = loaded_weight.shape[-2]
                param_hidden_out = param.data.shape[-2] * self.tp_size
                if loaded_weight_hidden_out == param_hidden_out:
                    self._load_combined_w13_weight_scale(
                        shard_dim=shard_dim,
                        loaded_weight=loaded_weight,
                        param=param,
                        tp_rank=self.tp_rank,
                    )
                    return True if return_success else None

            # For other weights, call _load_model_weight_or_group_weight_scale()
            # to load it.
            if "weight" in weight_name:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
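
To make the combined-vs-split check concrete, here is a hedged sketch with hypothetical shapes: the checkpoint tensor is treated as "combined" when its hidden_out dim equals the local parameter's hidden_out multiplied back up by tp_size.

# Sketch of the shape comparison above (hypothetical shapes only).
import torch

tp_size = 2
local_hidden_out = 8                           # param.data.shape[-2] on this rank
param = torch.empty(4, local_hidden_out, 3)    # local (sharded) w13 weight scale
loaded_weight = torch.empty(4, local_hidden_out * tp_size, 3)  # un-sharded checkpoint tensor

is_combined = loaded_weight.shape[-2] == param.shape[-2] * tp_size
assert is_combined  # -> the _load_combined_w13_weight_scale() path is taken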


@@ -449,6 +449,20 @@ def get_config(
            model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
            config.update({"architectures": [model_type]})

    # ModelOpt 0.31.0 and after saves the quantization config in the model
    # config file.
    quantization_config = config_dict.get("quantization_config", None)

    # ModelOpt 0.29.0 and before saves the quantization config in a separate
    # "hf_quant_config.json" in the same directory as the model config file.
    if quantization_config is None \
            and file_or_path_exists(model, "hf_quant_config.json", revision):
        quantization_config = get_hf_file_to_dict("hf_quant_config.json",
                                                  model, revision)
    if quantization_config is not None:
        config.quantization_config = quantization_config

    if hf_overrides_kw:
        logger.debug("Overriding HF config with %s", hf_overrides_kw)
        config.update(hf_overrides_kw)
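
For orientation, a hedged sketch of the lookup order this hunk adds (the helper and file layout below are illustrative, not the real vLLM utilities): prefer a quantization_config embedded in config.json (ModelOpt 0.31.0 and after), otherwise fall back to a sibling hf_quant_config.json (ModelOpt 0.29.0 and before).

# Sketch of the fallback order; helper name and layout are illustrative.
import json
from pathlib import Path
from typing import Any, Optional


def load_quantization_config(model_dir: str) -> Optional[dict[str, Any]]:
    # Newer ModelOpt exports embed the quant config in config.json.
    config = json.loads((Path(model_dir) / "config.json").read_text())
    quantization_config = config.get("quantization_config")
    if quantization_config is None:
        # Older ModelOpt exports keep it in a separate hf_quant_config.json.
        hf_quant_path = Path(model_dir) / "hf_quant_config.json"
        if hf_quant_path.exists():
            quantization_config = json.loads(hf_quant_path.read_text())
    return quantization_config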