[Bugfix][Rocm] Fix shared expert weight loading failure in DeepSeek-MTP (#27563)

Signed-off-by: zhyajie <yajizhan@amd.com>
Co-authored-by: zhyajie <yajizhan@amd.com>
杰兮 2025-11-24 18:16:52 +08:00 committed by GitHub
parent 68dfe28eae
commit 8005e606bf
2 changed files with 121 additions and 46 deletions
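The fix mirrors the shared-expert handling already present in DeepseekV2ForCausalLM: when AITER fused shared experts are enabled on ROCm, the checkpoint ships a single widened mlp.shared_experts.* tensor, which the loader must split into per-expert slices and route into the expert slots appended after the routed experts. Below is a minimal standalone sketch of that split-and-rename scheme; it is plain PyTorch for illustration, and the layer name and shapes are hypothetical:

import torch

# Hypothetical sizes, for illustration only.
n_routed_experts = 4   # config.n_routed_experts
n_shared_experts = 2   # config.n_shared_experts (num_chunks in the patch)
hidden, inter = 8, 6   # hidden size, per-expert intermediate size

# The checkpoint provides one widened shared-experts tensor, e.g. a
# ColumnParallel gate_proj of shape [n_shared_experts * inter, hidden].
name = "model.layers.0.mlp.shared_experts.gate_proj.weight"
loaded_weight = torch.randn(n_shared_experts * inter, hidden)

# gate/up are ColumnParallel -> split along dim 0; down is RowParallel -> dim 1.
split_dim = 1 if "down_proj.weight" in name else 0
total = loaded_weight.shape[split_dim]
assert total % n_shared_experts == 0
chunk_size = total // n_shared_experts

for j in range(n_shared_experts):
    # Slice out expert j's chunk along the split axis.
    chunk = loaded_weight.narrow(split_dim, j * chunk_size, chunk_size)
    # Rename it into an appended expert slot so an expert-aware
    # weight loader can route it by expert_id.
    chunk_name = name.replace(
        "mlp.shared_experts", f"mlp.experts.{n_routed_experts + j}"
    )
    print(chunk_name, tuple(chunk.shape))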

vllm/model_executor/models/deepseek_mtp.py

@@ -1,15 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 
 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
 
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -231,6 +233,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
         return self.model.compute_logits(hidden_states, spec_step_idx)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        rocm_aiter_moe_shared_expert_enabled = (
+            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
+        )
         stacked_params_mapping = [
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
@@ -238,11 +243,16 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
             ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts
+            + (
+                self.config.n_shared_experts
+                if rocm_aiter_moe_shared_expert_enabled
+                else 0
+            ),
         )
 
         params_dict = dict(self.named_parameters())
@@ -253,6 +263,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is None:
                 continue
+            is_fusion_moe_shared_experts_layer = (
+                rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
+            )
             name = self._rewrite_spec_layer_name(spec_layer, name)
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -266,6 +279,8 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                 if ("mlp.experts." in name) and name not in params_dict:
                     continue
+                if is_fusion_moe_shared_experts_layer:
+                    continue
 
                 name_mapped = name.replace(weight_name, param_name)
                 # QKV fusion is optional, fall back to normal
@@ -286,45 +301,105 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
-                        continue
-                    # According to DeepSeek-V3 Technical Report, MTP modules
-                    # shares embedding layer. We only load the first weights.
-                    if (
-                        spec_layer != self.model.mtp_start_layer_idx
-                        and ".layers" not in name
-                    ):
-                        continue
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
-                loaded_params.add(name)
+                # Special handling: when AITER fusion_shared_experts is enabled,
+                # checkpoints may provide a single widened shared_experts tensor
+                # without explicit expert indices
+                # (e.g. ...mlp.shared_experts.gate_proj.weight).
+                # For models with multiple shared experts, split that tensor
+                # evenly into per-shared-expert slices and load them into
+                # appended expert slots mlp.experts.{n_routed_experts + j}.*
+                # accordingly.
+                num_chunks = 1
+                if is_fusion_moe_shared_experts_layer:
+                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+                    # Determine split axis based on op type
+                    # gate/up: ColumnParallel → split along dim 0
+                    # down: RowParallel → split along dim 1
+                    split_dim = 1 if "down_proj.weight" in name else 0
+                    total = loaded_weight.shape[split_dim]
+                    assert total % num_chunks == 0, (
+                        f"Shared expert weight dim {total} "
+                        f"not divisible by num_chunks {num_chunks}"
+                    )
+                    chunk_size = total // num_chunks
+
+                for j in range(num_chunks):
+                    chunk_name = name
+                    weight_to_load = loaded_weight
+                    if is_fusion_moe_shared_experts_layer:
+                        if split_dim == 0:
+                            weight_to_load = loaded_weight[
+                                j * chunk_size : (j + 1) * chunk_size, :
+                            ]
+                        else:
+                            weight_to_load = loaded_weight[
+                                :, j * chunk_size : (j + 1) * chunk_size
+                            ]
+                        # Synthesize an expert-style name so expert mapping
+                        # can route it
+                        chunk_name = name.replace(
+                            "mlp.shared_experts",
+                            f"mlp.experts.{self.config.n_routed_experts + j}",
+                        )
+
+                    # Use expert_params_mapping to locate the destination
+                    # param and delegate to its expert-aware weight_loader
+                    # with expert_id.
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if weight_name not in chunk_name:
+                            continue
+                        # Do not modify `name` since the loop may continue here
+                        # Instead, create a new variable
+                        name_mapped = chunk_name.replace(weight_name, param_name)
+                        param = params_dict[name_mapped]
+                        # We should ask the weight loader to return success or
+                        # not here since otherwise we may skip experts with
+                        # other available replicas.
+                        weight_loader = typing.cast(
+                            Callable[..., bool], param.weight_loader
+                        )
+                        success = weight_loader(
+                            param,
+                            weight_to_load,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                            return_success=True,
+                        )
+                        if success:
+                            if not is_fusion_moe_shared_experts_layer:
+                                name = name_mapped
+                            else:
+                                loaded_params.add(name_mapped)
+                            break
+                    else:
+                        # Skip loading extra bias for GPTQ models.
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+                        name = maybe_remap_kv_scale_name(name, params_dict)
+                        if name is None:
+                            continue
+                        # According to DeepSeek-V3 Technical Report, MTP modules
+                        # shares embedding layer. We only load the first weights.
+                        if (
+                            spec_layer != self.model.mtp_start_layer_idx
+                            and ".layers" not in name
+                        ):
+                            continue
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                if not is_fusion_moe_shared_experts_layer:
+                    loaded_params.add(name)
         return loaded_params
 
     def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
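Note the interaction between the two changes above: the chunks synthesized under mlp.experts.{n_routed_experts + j} can only be routed because expert_params_mapping is now built by SharedFusedMoE.make_expert_params_mapping over the widened id range n_routed_experts + n_shared_experts. A hedged sketch of that routing follows; the (param_name, weight_name, expert_id, shard_id) tuple layout is taken from the diff, while the concrete mapping strings are hypothetical stand-ins for what make_expert_params_mapping generates:

n_routed_experts, n_shared_experts = 4, 2

# Hypothetical stand-ins for SharedFusedMoE.make_expert_params_mapping output.
expert_params_mapping = [
    (
        "experts.w2_weight" if proj == "down_proj" else "experts.w13_weight",
        f"experts.{eid}.{proj}",
        eid,
        {"gate_proj": "w1", "up_proj": "w3", "down_proj": "w2"}[proj],
    )
    for eid in range(n_routed_experts + n_shared_experts)  # widened id range
    for proj in ("gate_proj", "up_proj", "down_proj")
]

# A shared-expert chunk renamed into appended slot 5 (= n_routed_experts + 1):
chunk_name = "model.layers.0.mlp.experts.5.gate_proj.weight"
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
    if weight_name in chunk_name:
        print(f"matched {weight_name!r}: expert_id={expert_id}, shard_id={shard_id!r}")
        break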

vllm/model_executor/models/deepseek_v2.py

@@ -1479,8 +1479,8 @@ class DeepseekV2ForCausalLM(
             if spec_layer is not None:
                 continue  # skip spec decode layers for main model
 
-            is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
-                "mlp.shared_experts" in name
+            is_fusion_moe_shared_experts_layer = (
+                rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
             )
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -1495,7 +1495,7 @@ class DeepseekV2ForCausalLM(
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                 if ("mlp.experts." in name) and name not in params_dict:
                     continue
-                if is_fuse_shared_experts_layer:
+                if is_fusion_moe_shared_experts_layer:
                     continue
 
                 name_mapped = name.replace(weight_name, param_name)
@@ -1531,7 +1531,7 @@ class DeepseekV2ForCausalLM(
                 # appended expert slots mlp.experts.{n_routed_experts + j}.*
                 # accordingly.
                 num_chunks = 1
-                if is_fuse_shared_experts_layer:
+                if is_fusion_moe_shared_experts_layer:
                     num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
                     # Determine split axis based on op type
                     # gate/up: ColumnParallel → split along dim 0
@@ -1548,7 +1548,7 @@ class DeepseekV2ForCausalLM(
                     chunk_name = name
                     weight_to_load = loaded_weight
 
-                    if is_fuse_shared_experts_layer:
+                    if is_fusion_moe_shared_experts_layer:
                         if split_dim == 0:
                             weight_to_load = loaded_weight[
                                 j * chunk_size : (j + 1) * chunk_size, :
@@ -1599,7 +1599,7 @@ class DeepseekV2ForCausalLM(
                             return_success=True,
                         )
                         if success:
-                            if not is_fuse_shared_experts_layer:
+                            if not is_fusion_moe_shared_experts_layer:
                                 name = name_mapped
                             else:
                                 loaded_params.add(name_mapped)
@@ -1628,7 +1628,7 @@ class DeepseekV2ForCausalLM(
                             param, "weight_loader", default_weight_loader
                         )
                         weight_loader(param, loaded_weight)
-                if not is_fuse_shared_experts_layer:
+                if not is_fusion_moe_shared_experts_layer:
                     loaded_params.add(name)
 
         return loaded_params
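Both files lean on the return_success=True contract of the expert-aware weight loader: under expert parallelism a loader may legitimately decline a weight whose expert lives on another rank, and load_weights must then fall through (the for/else branches above) rather than fail. A minimal sketch of that contract, with hypothetical rank ownership and no real vLLM APIs:

# Experts owned by this hypothetical rank under expert parallelism.
OWNED = range(3, 6)

def expert_weight_loader(expert_id: int, return_success: bool = False) -> bool:
    """Pretend loader: accepts only locally owned experts."""
    ok = expert_id in OWNED
    if not ok and not return_success:
        raise KeyError(f"expert {expert_id} is not on this rank")
    # ... real code would copy the tensor slice into the local buffer here ...
    return ok

for expert_id in (2, 4):
    if expert_weight_loader(expert_id, return_success=True):
        print(f"expert {expert_id}: loaded locally")
    else:
        print(f"expert {expert_id}: skipped, owned by another rank")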