[Bugfix][Qwen3-Next] add prefixes to shared_expert in qwen3-next and mlp in qwen2moe to successfully load ignored params in quantized models (#24960)
Signed-off-by: toncao <cpatonn@gmail.com>
Co-authored-by: toncao <cpatonn@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent b98219670f
commit 027d37df38
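Why the prefixes matter: quantized checkpoints for these models commonly list some modules (for example the shared expert's projections) as ignored, i.e. kept unquantized, and vLLM decides per layer at construction time which quantization method to apply based on the prefix string the layer is built with. A layer built without a prefix can never match an ignore list written against fully qualified module names, so loading the ignored params fails. A minimal sketch of that mechanism; SketchQuantConfig and the example names are illustrative assumptions, not vLLM's actual API:

# Minimal sketch (simplified, not vLLM's real QuantizationConfig): whether a
# layer is quantized is keyed on the prefix string it is constructed with.
class SketchQuantConfig:
    def __init__(self, ignored_layers: set[str]) -> None:
        self.ignored_layers = ignored_layers

    def get_quant_method(self, prefix: str) -> str:
        # Ignored layers keep their original fp16/bf16 weights; everything
        # else expects quantized weights plus scales in the checkpoint.
        return "unquantized" if prefix in self.ignored_layers else "quantized"

cfg = SketchQuantConfig({"model.layers.0.mlp.shared_expert.down_proj"})
# With the prefixes added by this commit, the shared expert is recognized:
print(cfg.get_quant_method("model.layers.0.mlp.shared_expert.down_proj"))  # unquantized
# With the pre-fix empty prefix, the lookup misses and the loader then looks
# for quantized weights that the checkpoint does not contain:
print(cfg.get_quant_method(""))  # quantized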
vllm/model_executor/models/qwen2_moe.py
@@ -72,17 +72,20 @@ class Qwen2MoeMLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            quant_config=quant_config,
-                                           reduce_results=reduce_results)
+                                           reduce_results=reduce_results,
+                                           prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
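Qwen2MoeMLP is instantiated both as a decoder layer's dense MLP and as the sparse MoE block's shared expert (both hunks appear below), so the prefix it now forwards yields two different families of submodule names. A small illustration, where the root prefix model.layers.0 and the layer index are arbitrary example assumptions:

# Illustrative only: how the prefix threaded through Qwen2MoeMLP composes into
# the projection prefixes that quantization and weight loading see.
def mlp_projection_prefixes(prefix: str) -> list[str]:
    return [f"{prefix}.gate_up_proj", f"{prefix}.down_proj"]

# Dense decoder-layer MLP (Qwen2MoeDecoderLayer hunk further down):
print(mlp_projection_prefixes("model.layers.0.mlp"))
# Shared expert inside the sparse MoE block (Qwen2MoeSparseMoeBlock hunk below):
print(mlp_projection_prefixes("model.layers.0.mlp.shared_expert"))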
@@ -123,7 +126,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
                                      bias=False,
-                                     quant_config=None)
+                                     quant_config=None,
+                                     prefix=f"{prefix}.gate")
         if config.shared_expert_intermediate_size > 0:
             self.shared_expert = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,
@@ -132,6 +136,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 quant_config=quant_config,
                 reduce_results=self.experts.must_reduce_shared_expert_outputs(
                 ),
+                prefix=f"{prefix}.shared_expert",
             )
         else:
             self.shared_expert = None
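For context, a hypothetical quantization-config excerpt of the kind such a checkpoint might carry (the keys and regex patterns are assumptions, not copied from any released model); entries like these can only match once the shared expert's projections are built with their real dotted prefixes:

# Hypothetical quantization-config excerpt (illustrative assumption): the
# shared-expert projections and lm_head stay in full precision, so the loader
# must be able to identify them by their fully qualified names.
quantization_config = {
    "ignore": [
        "re:.*shared_expert\\.gate_up_proj$",
        "re:.*shared_expert\\.down_proj$",
        "lm_head",
    ],
}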
@@ -203,21 +208,19 @@ class Qwen2MoeAttention(nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.dual_chunk_attention_config = dual_chunk_attention_config

-        self.qkv_proj = QKVParallelLinear(
-            hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            self.total_num_kv_heads,
-            bias=True,
-            quant_config=quant_config,
-        )
+        self.qkv_proj = QKVParallelLinear(hidden_size,
+                                          self.head_dim,
+                                          self.total_num_heads,
+                                          self.total_num_kv_heads,
+                                          bias=True,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.qkv_proj")

-        self.o_proj = RowParallelLinear(
-            self.total_num_heads * self.head_dim,
-            hidden_size,
-            bias=False,
-            quant_config=quant_config,
-        )
+        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
+                                        hidden_size,
+                                        bias=False,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")

         self.rotary_emb = get_rope(
             self.head_dim,
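A side note on the attention hunk: vLLM fuses the q/k/v projections into a single qkv_proj module (and gate/up into gate_up_proj), while checkpoints and ignore lists usually use the unfused names, so prefix matching in practice also involves a fused-to-unfused name mapping. The sketch below mirrors that idea; the PACKED table and the helper are illustrative assumptions, not vLLM internals:

# Illustrative sketch: expand a fused layer prefix into the unfused,
# checkpoint-side names so it can be compared against an ignore list.
PACKED = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def unfused_names(prefix: str) -> list[str]:
    parent, _, leaf = prefix.rpartition(".")
    return [f"{parent}.{part}" if parent else part
            for part in PACKED.get(leaf, [leaf])]

print(unfused_names("model.layers.0.self_attn.qkv_proj"))
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
#  'model.layers.0.self_attn.v_proj']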
@@ -296,12 +299,11 @@ class Qwen2MoeDecoderLayer(nn.Module):
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.mlp")
         else:
-            self.mlp = Qwen2MoeMLP(
-                hidden_size=config.hidden_size,
-                intermediate_size=config.intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
-            )
+            self.mlp = Qwen2MoeMLP(hidden_size=config.hidden_size,
+                                   intermediate_size=config.intermediate_size,
+                                   hidden_act=config.hidden_act,
+                                   quant_config=quant_config,
+                                   prefix=f"{prefix}.mlp")
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
vllm/model_executor/models/qwen3_next.py
@@ -138,6 +138,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
                 quant_config=quant_config,
                 reduce_results=self.experts.must_reduce_shared_expert_outputs(
                 ),
+                prefix=f"{prefix}.shared_expert",
             )
         else:
             self.shared_expert = None