Add ModelOpt Qwen3 nvfp4 support (#20101)

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Zhiyu authored 2025-08-07 19:18:19 -07:00, committed by GitHub
parent e2c8f1edec
commit d57dc2364e
3 changed files with 58 additions and 37 deletions

vllm/model_executor/model_loader/weight_utils.py

@@ -764,39 +764,41 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
                 return None
             return remapped_name
 
-    possible_scale_names = [".k_scale", ".v_scale"]
-    modelopt_scale_names = [
-        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
-    ]
+    # Define scale name mapping patterns in order of precedence
+    scale_mapping_patterns = [
+        # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.([kv])_proj\.([kv])_scale$",
+         r".self_attn.attn.\2_scale"),
+        # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
+        # Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.qkqkv_proj\.([kv])_scale$",
+         r".self_attn.attn.\1_scale"),
+        # Default format: .{k,v}_scale -> .attn.{k,v}_scale
+        (r"\.([kv])_scale$", r".attn.\1_scale"),
+    ]
-    # Also support qkv_proj scale parameters (from stacked parameter processing)
-    qkv_proj_scale_names = [
-        ".self_attn.qkv_proj.k_scale", ".self_attn.qkv_proj.v_scale"
-    ]
-    for scale_name in possible_scale_names:
-        if name.endswith(scale_name):
-            if any(mo_scale_name in name
-                   for mo_scale_name in modelopt_scale_names):
-                remapped_name = name.replace(
-                    f".self_attn.{scale_name[1]}_proj{scale_name}",
-                    f".self_attn.attn{scale_name}")
-            elif any(qkv_scale_name in name
-                     for qkv_scale_name in qkv_proj_scale_names):
-                # Handle qkv_proj scale parameters
-                remapped_name = name.replace(
-                    f".self_attn.qkv_proj{scale_name}",
-                    f".self_attn.attn{scale_name}")
-            else:
-                remapped_name = name.replace(scale_name, f".attn{scale_name}")
-            if remapped_name not in params_dict:
-                logger.warning_once(
-                    "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
-                    scale_name,
-                    name,
-                    remapped_name,
-                    scale_name,
-                )
-                return None
-            return remapped_name
+
+    # Check if name ends with k_scale or v_scale
+    if name.endswith((".k_scale", ".v_scale")):
+        import regex as re
+        for pattern, replacement in scale_mapping_patterns:
+            if re.search(pattern, name):
+                remapped_name = re.sub(pattern, replacement, name)
+                if remapped_name not in params_dict:
+                    scale_type = name.split(".")[-1]
+                    logger.warning_once(
+                        "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
+                        scale_type,
+                        name,
+                        remapped_name,
+                        scale_type,
+                    )
+                    return None
+                return remapped_name
 
     # If there were no matches, return the untouched param name
     return name
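The refactor collapses the hard-coded string substitutions into a single pattern table scanned in order of precedence, so the first matching regex wins and the bare `.{k,v}_scale` fallback only fires when no projection-specific form matched. A minimal sketch of that behavior, using the stdlib `re` module (the patterns behave the same under `regex`) and a hypothetical `remap` helper with made-up layer names:

```python
import re

# Same pattern table as the diff above, in the same precedence order.
scale_mapping_patterns = [
    (r"\.self_attn\.([kv])_proj\.([kv])_scale$", r".self_attn.attn.\2_scale"),
    (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
    (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
    (r"\.([kv])_scale$", r".attn.\1_scale"),
]

def remap(name: str) -> str:
    """First matching pattern wins, mirroring the loop in the diff."""
    for pattern, replacement in scale_mapping_patterns:
        if re.search(pattern, name):
            return re.sub(pattern, replacement, name)
    return name  # no match: return the untouched name

# Hypothetical checkpoint names -> model parameter names.
assert (remap("model.layers.0.self_attn.k_proj.k_scale")
        == "model.layers.0.self_attn.attn.k_scale")      # ModelOpt format
assert (remap("model.layers.0.self_attn.qkv_proj.v_scale")
        == "model.layers.0.self_attn.attn.v_scale")      # stacked qkv_proj
assert remap("model.layers.0.k_scale") == "model.layers.0.attn.k_scale"
```

Note that `qkv_proj` does not collide with the first pattern: `([kv])_proj` requires a single `k` or `v` between the dot and `_proj`, so stacked names fall through to their own rule.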

vllm/model_executor/models/qwen2.py

@@ -408,9 +408,18 @@ class Qwen2Model(nn.Module):
                     continue
                 if is_pp_missing_parameter(name, self):
                     continue
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
                 param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, loaded_weight)
+                else:
+                    weight_loader(param, loaded_weight, shard_id)
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
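The loader change is what lets the remapped scales actually load: a parameter such as `attn.k_scale` created by a quantization method is typically a plain parameter with no custom `weight_loader` attribute, so the old unconditional `param.weight_loader(param, loaded_weight, shard_id)` would fail on it. A stand-in sketch of the dispatch, with dummy classes in place of vLLM's real parameter types:

```python
def default_weight_loader(param, loaded_weight):
    # vLLM's fallback simply copies the tensor; it takes no shard id.
    param.data = loaded_weight

class PlainParam:
    """Stand-in for e.g. an attn.k_scale parameter (no custom loader)."""
    data = None

class ShardedParam(PlainParam):
    """Stand-in for a stacked weight (e.g. qkv_proj) with a custom loader."""
    @staticmethod
    def weight_loader(param, loaded_weight, shard_id):
        param.data = (shard_id, loaded_weight)

for param in (PlainParam(), ShardedParam()):
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    if weight_loader == default_weight_loader:
        weight_loader(param, 1.0)        # scale: plain copy
    else:
        weight_loader(param, 1.0, "k")   # stacked weight: route by shard
```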

vllm/model_executor/models/qwen3_moe.py

@@ -48,7 +48,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
@@ -471,12 +472,21 @@ class Qwen3MoeModel(nn.Module):
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, loaded_weight)
+                else:
+                    weight_loader(param, loaded_weight, shard_id)
                 break
             else:
                 is_expert_weight = False
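End to end, the Qwen3 MoE path now remaps a ModelOpt NVFP4 checkpoint name, guards against parameters the model does not have, and dispatches to the right loader. A sketch of that flow, reusing the hypothetical `remap`, `PlainParam`, and `default_weight_loader` from the sketches above:

```python
# Hypothetical two-entry params_dict standing in for the real model.
params_dict = {
    "model.layers.0.self_attn.attn.k_scale": PlainParam(),
    "model.layers.0.self_attn.attn.v_scale": PlainParam(),
}

# ModelOpt stores the scale under the projection layer...
name = remap("model.layers.0.self_attn.k_proj.k_scale")
# ...but the model registers it on the attention module.
assert name == "model.layers.0.self_attn.attn.k_scale"

if name in params_dict:  # the new guard skips names the model lacks
    param = params_dict[name]
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, 0.5)  # scales use the default, shard-free loader
```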