commit d57dc2364e
parent e2c8f1edec

Add ModelOpt Qwen3 nvfp4 support (#20101)

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
@@ -764,36 +764,38 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
             return None
         return remapped_name
 
-    possible_scale_names = [".k_scale", ".v_scale"]
-    modelopt_scale_names = [
-        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
-    ]
-    # Also support qkv_proj scale parameters (from stacked parameter processing)
-    qkv_proj_scale_names = [
-        ".self_attn.qkv_proj.k_scale", ".self_attn.qkv_proj.v_scale"
-    ]
-    for scale_name in possible_scale_names:
-        if name.endswith(scale_name):
-            if any(mo_scale_name in name
-                   for mo_scale_name in modelopt_scale_names):
-                remapped_name = name.replace(
-                    f".self_attn.{scale_name[1]}_proj{scale_name}",
-                    f".self_attn.attn{scale_name}")
-            elif any(qkv_scale_name in name
-                     for qkv_scale_name in qkv_proj_scale_names):
-                # Handle qkv_proj scale parameters
-                remapped_name = name.replace(
-                    f".self_attn.qkv_proj{scale_name}",
-                    f".self_attn.attn{scale_name}")
-            else:
-                remapped_name = name.replace(scale_name, f".attn{scale_name}")
-            if remapped_name not in params_dict:
-                logger.warning_once(
-                    "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
-                    scale_name,
-                    name,
-                    remapped_name,
-                    scale_name,
-                )
-                return None
-            return remapped_name
+    # Define scale name mapping patterns in order of precedence
+    scale_mapping_patterns = [
+        # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.([kv])_proj\.([kv])_scale$",
+         r".self_attn.attn.\2_scale"),
+        # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
+        # Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"
+         ),
+        # Default format: .{k,v}_scale -> .attn.{k,v}_scale
+        (r"\.([kv])_scale$", r".attn.\1_scale"),
+    ]
+
+    # Check if name ends with k_scale or v_scale
+    if name.endswith((".k_scale", ".v_scale")):
+        import regex as re
+
+        for pattern, replacement in scale_mapping_patterns:
+            if re.search(pattern, name):
+                remapped_name = re.sub(pattern, replacement, name)
+                if remapped_name not in params_dict:
+                    scale_type = name.split(".")[-1]
+                    logger.warning_once(
+                        "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
+                        scale_type,
+                        name,
+                        remapped_name,
+                        scale_type,
+                    )
+                    return None
+                return remapped_name
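Note on the pattern table: entries are tried in order, so the catch-all `.{k,v}_scale` rule must stay last; otherwise a stacked `qkv_proj` name would match it first and be rewritten to the wrong target. Below is a minimal standalone sketch of the rewrite behavior; the sample checkpoint names are made up, and the stdlib `re` module stands in for the `regex` package the diff imports.

# Illustrative-only sketch: the sample names are hypothetical, and
# stdlib `re` substitutes for the `regex` package used in the diff.
import re

scale_mapping_patterns = [
    (r"\.self_attn\.([kv])_proj\.([kv])_scale$", r".self_attn.attn.\2_scale"),
    (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
    (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
    (r"\.([kv])_scale$", r".attn.\1_scale"),
]

for name in [
    "model.layers.0.self_attn.k_proj.k_scale",      # ModelOpt per-proj format
    "model.layers.0.self_attn.qkv_proj.v_scale",    # stacked qkv_proj format
    "model.layers.0.self_attn.qkqkv_proj.k_scale",  # Qwen3 MoE format
    "model.layers.0.k_scale",                       # default format
]:
    # First matching pattern wins, mirroring the loop in the diff.
    for pattern, replacement in scale_mapping_patterns:
        if re.search(pattern, name):
            print(f"{name} -> {re.sub(pattern, replacement, name)}")
            break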
@@ -408,8 +408,17 @@ class Qwen2Model(nn.Module):
                     continue
                 if is_pp_missing_parameter(name, self):
                     continue
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
                 param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, loaded_weight)
+                else:
+                    weight_loader(param, loaded_weight, shard_id)
                 break
             else:
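The `getattr` fallback matters because a kv-scale parameter may be a plain tensor with no custom `weight_loader` attribute, and the default loader takes no `shard_id`. A self-contained sketch of the dispatch idiom, with a hypothetical parameter standing in for vLLM's real parameter objects:

# Self-contained sketch of the loader dispatch above; the parameter
# and tensors are hypothetical stand-ins, not vLLM's real classes.
import torch


def default_weight_loader(param: torch.nn.Parameter,
                          loaded_weight: torch.Tensor) -> None:
    # The default loader copies the whole tensor and takes no shard_id.
    param.data.copy_(loaded_weight)


param = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
# A scale parameter may not define a custom weight_loader, so fall
# back to the default loader and call it without a shard_id.
weight_loader = getattr(param, "weight_loader", default_weight_loader)
if weight_loader == default_weight_loader:
    weight_loader(param, torch.ones(4))
else:
    weight_loader(param, torch.ones(4), "q")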
@@ -48,7 +48,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
@@ -471,11 +472,20 @@ class Qwen3MoeModel(nn.Module):
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
                 if name not in params_dict:
                     continue
 
                 param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, loaded_weight)
+                else:
+                    weight_loader(param, loaded_weight, shard_id)
                 break
             else:
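Taken together, the remap table and the two load_weights hunks let a kv-scale entry from a ModelOpt nvfp4 Qwen3-MoE checkpoint land on the attention module's parameter. A simplified end-to-end trace follows; the checkpoint name and the toy params_dict are made up, and the inline re.sub is a stand-in for the full maybe_remap_kv_scale_name pattern loop.

# Simplified trace of the new kv-scale path; the name and params_dict
# are hypothetical, and this single re.sub stands in for the table.
import re

params_dict = {"model.layers.0.self_attn.attn.k_scale": "param-object"}
name = "model.layers.0.self_attn.qkqkv_proj.k_scale"

if name.endswith("scale"):
    # Remapping the name of FP8 kv-scale (Qwen3 MoE qkqkv_proj format).
    remapped = re.sub(r"\.self_attn\.qkqkv_proj\.([kv])_scale$",
                      r".self_attn.attn.\1_scale", name)
    name = remapped if remapped in params_dict else None

if name is not None:
    print("loading into", name)
    # -> loading into model.layers.0.self_attn.attn.k_scale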