Add ModelOpt Qwen3 nvfp4 support (#20101)

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Zhiyu Cheng authored 2025-08-07 19:18:19 -07:00, committed by GitHub
parent e2c8f1edec
commit d57dc2364e
3 changed files with 58 additions and 37 deletions

vllm/model_executor/model_loader/weight_utils.py

@@ -764,39 +764,41 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
             return None
         return remapped_name
 
-    possible_scale_names = [".k_scale", ".v_scale"]
-    modelopt_scale_names = [
-        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
-    ]
-    # Also support qkv_proj scale parameters (from stacked parameter processing)
-    qkv_proj_scale_names = [
-        ".self_attn.qkv_proj.k_scale", ".self_attn.qkv_proj.v_scale"
-    ]
-    for scale_name in possible_scale_names:
-        if name.endswith(scale_name):
-            if any(mo_scale_name in name
-                   for mo_scale_name in modelopt_scale_names):
-                remapped_name = name.replace(
-                    f".self_attn.{scale_name[1]}_proj{scale_name}",
-                    f".self_attn.attn{scale_name}")
-            elif any(qkv_scale_name in name
-                     for qkv_scale_name in qkv_proj_scale_names):
-                # Handle qkv_proj scale parameters
-                remapped_name = name.replace(
-                    f".self_attn.qkv_proj{scale_name}",
-                    f".self_attn.attn{scale_name}")
-            else:
-                remapped_name = name.replace(scale_name, f".attn{scale_name}")
-            if remapped_name not in params_dict:
-                logger.warning_once(
-                    "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
-                    scale_name,
-                    name,
-                    remapped_name,
-                    scale_name,
-                )
-                return None
-            return remapped_name
+    # Define scale name mapping patterns in order of precedence
+    scale_mapping_patterns = [
+        # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.([kv])_proj\.([kv])_scale$",
+         r".self_attn.attn.\2_scale"),
+        # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
+        # Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.qkqkv_proj\.([kv])_scale$",
+         r".self_attn.attn.\1_scale"),
+        # Default format: .{k,v}_scale -> .attn.{k,v}_scale
+        (r"\.([kv])_scale$", r".attn.\1_scale"),
+    ]
+
+    # Check if name ends with k_scale or v_scale
+    if name.endswith((".k_scale", ".v_scale")):
+        import regex as re
+
+        for pattern, replacement in scale_mapping_patterns:
+            if re.search(pattern, name):
+                remapped_name = re.sub(pattern, replacement, name)
+                if remapped_name not in params_dict:
+                    scale_type = name.split(".")[-1]
+                    logger.warning_once(
+                        "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
+                        scale_type,
+                        name,
+                        remapped_name,
+                        scale_type,
+                    )
+                    return None
+                return remapped_name
 
     # If there were no matches, return the untouched param name
     return name
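
Note on the rewrite above: pattern order is load-bearing. The generic fallback `\.([kv])_scale$` also matches every name the three specific formats match, so trying it first would produce wrong targets such as `...self_attn.k_proj.attn.k_scale`. The following self-contained sketch (not part of the commit; it uses the stdlib `re` module, which behaves the same as the `regex` package for these patterns, and hypothetical checkpoint names) shows what each format remaps to:

import re

# The same four patterns, in the same precedence order as the commit.
SCALE_MAPPING_PATTERNS = [
    (r"\.self_attn\.([kv])_proj\.([kv])_scale$", r".self_attn.attn.\2_scale"),
    (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
    (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
    (r"\.([kv])_scale$", r".attn.\1_scale"),
]

def remap(name: str) -> str:
    # First matching pattern wins; unmatched names pass through unchanged.
    for pattern, replacement in SCALE_MAPPING_PATTERNS:
        if re.search(pattern, name):
            return re.sub(pattern, replacement, name)
    return name

# Hypothetical checkpoint names, one per supported format:
assert (remap("model.layers.0.self_attn.k_proj.k_scale")      # ModelOpt
        == "model.layers.0.self_attn.attn.k_scale")
assert (remap("model.layers.0.self_attn.qkv_proj.v_scale")    # stacked QKV
        == "model.layers.0.self_attn.attn.v_scale")
assert (remap("model.layers.0.self_attn.qkqkv_proj.k_scale")  # Qwen3 MoE
        == "model.layers.0.self_attn.attn.k_scale")
assert (remap("model.layers.0.k_scale")                       # default
        == "model.layers.0.attn.k_scale")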

vllm/model_executor/models/qwen2.py

@@ -408,9 +408,18 @@ class Qwen2Model(nn.Module):
                 continue
             if is_pp_missing_parameter(name, self):
                 continue
+            if name.endswith("scale"):
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
             param = params_dict[name]
-            weight_loader = param.weight_loader
-            weight_loader(param, loaded_weight, shard_id)
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            if weight_loader == default_weight_loader:
+                weight_loader(param, loaded_weight)
+            else:
+                weight_loader(param, loaded_weight, shard_id)
             break
         else:
             # Skip loading extra bias for GPTQ models.
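
The loader change in this hunk exists because the remapped kv-scale entries are plain parameters with no custom `weight_loader` attribute, whereas stacked parameters such as `qkv_proj` carry a shard-aware loader that expects a `shard_id`. A minimal sketch of that dispatch (stand-in loaders and shapes, not vLLM's actual classes):

import torch

def default_weight_loader(param, loaded_weight):
    # Fallback loader: a straight copy, no sharding logic.
    param.data.copy_(loaded_weight)

def qkv_weight_loader(param, loaded_weight, shard_id):
    # Stand-in for the shard-aware loader attached to fused qkv params:
    # copies the checkpoint tensor into the shard's slice.
    offset = {"q": 0, "k": 4, "v": 8}[shard_id]
    param.data[offset:offset + 4].copy_(loaded_weight)

def load(param, loaded_weight, shard_id):
    # Mirrors the new logic: prefer the parameter's own loader, fall back
    # to the default, and only pass shard_id to custom loaders.
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    if weight_loader == default_weight_loader:
        weight_loader(param, loaded_weight)
    else:
        weight_loader(param, loaded_weight, shard_id)

# A remapped kv-scale is a plain parameter without a weight_loader.
k_scale = torch.nn.Parameter(torch.zeros(1), requires_grad=False)
load(k_scale, torch.tensor([0.5]), shard_id="k")

# A fused qkv parameter carries its own shard-aware loader.
qkv = torch.nn.Parameter(torch.zeros(12), requires_grad=False)
qkv.weight_loader = qkv_weight_loader
load(qkv, torch.ones(4), shard_id="k")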

vllm/model_executor/models/qwen3_moe.py

@@ -48,7 +48,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

@@ -471,12 +472,21 @@ class Qwen3MoeModel(nn.Module):
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
                 if name not in params_dict:
                     continue
                 param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, loaded_weight)
+                else:
+                    weight_loader(param, loaded_weight, shard_id)
                 break
             else:
                 is_expert_weight = False
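
Taken together, the new branch in both model loaders filters on the `scale` suffix, remaps the checkpoint name, and skips the weight when the remap finds no matching model parameter. A compressed sketch of that control flow (toy `params_dict` and hypothetical names; the helper is a stand-in reduced to just the ModelOpt pattern):

import re
from typing import Optional

def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
    # Stand-in for the real helper, reduced to the ModelOpt format only.
    remapped = re.sub(r"\.self_attn\.([kv])_proj\.([kv])_scale$",
                      r".self_attn.attn.\2_scale", name)
    if remapped == name:
        return name  # not a recognized kv-scale: return untouched
    # As in the commit: a remap that misses the model means "skip it".
    return remapped if remapped in params_dict else None

# Toy model that registered only a k_scale parameter.
params_dict = {"layers.0.self_attn.attn.k_scale": "param-object"}

checkpoint_names = [
    "layers.0.self_attn.k_proj.k_scale",  # remapped, then loaded
    "layers.0.self_attn.v_proj.v_scale",  # remap target missing -> skipped
]
for name in checkpoint_names:
    if name.endswith("scale"):
        name = maybe_remap_kv_scale_name(name, params_dict)
        if name is None:
            continue
    param = params_dict[name]
    print("loaded", name, "->", param)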