[Bugfix][Rocm] Fix shared expert weight loading failure in DeepSeek-MTP (#27563)

Signed-off-by: zhyajie <yajizhan@amd.com>
Co-authored-by: zhyajie <yajizhan@amd.com>
This commit is contained in:
杰兮 2025-11-24 18:16:52 +08:00 committed by GitHub
parent 68dfe28eae
commit 8005e606bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 121 additions and 46 deletions

View File

@ -1,15 +1,17 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable import typing
from collections.abc import Callable, Iterable
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
@ -231,6 +233,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
return self.model.compute_logits(hidden_states, spec_step_idx) return self.model.compute_logits(hidden_states, spec_step_idx)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rocm_aiter_moe_shared_expert_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
)
stacked_params_mapping = [ stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
@ -238,11 +243,16 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
] ]
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts, num_experts=self.config.n_routed_experts
+ (
self.config.n_shared_experts
if rocm_aiter_moe_shared_expert_enabled
else 0
),
) )
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
@ -253,6 +263,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None: if spec_layer is None:
continue continue
is_fusion_moe_shared_experts_layer = (
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
)
name = self._rewrite_spec_layer_name(spec_layer, name) name = self._rewrite_spec_layer_name(spec_layer, name)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
@ -266,6 +279,8 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
# for mlp.experts[0].gate_gate_up_proj, which breaks load. # for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict: if ("mlp.experts." in name) and name not in params_dict:
continue continue
if is_fusion_moe_shared_experts_layer:
continue
name_mapped = name.replace(weight_name, param_name) name_mapped = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal # QKV fusion is optional, fall back to normal
@ -286,45 +301,105 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
for mapping in expert_params_mapping: # Special handling: when AITER fusion_shared_experts is enabled,
param_name, weight_name, expert_id, shard_id = mapping # checkpoints may provide a single widened shared_experts tensor
if weight_name not in name: # without explicit expert indices
continue # (e.g. ...mlp.shared_experts.gate_proj.weight).
name = name.replace(weight_name, param_name) # For models with multiple shared experts, split that tensor
# evenly into per-shared-expert slices and load them into
param = params_dict[name] # appended expert slots mlp.experts.{n_routed_experts + j}.*
weight_loader = param.weight_loader # accordingly.
weight_loader( num_chunks = 1
param, if is_fusion_moe_shared_experts_layer:
loaded_weight, num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
name, # Determine split axis based on op type
shard_id=shard_id, # gate/up: ColumnParallel → split along dim 0
expert_id=expert_id, # down: RowParallel → split along dim 1
split_dim = 1 if "down_proj.weight" in name else 0
total = loaded_weight.shape[split_dim]
assert total % num_chunks == 0, (
f"Shared expert weight dim {total} "
f"not divisible by num_chunks {num_chunks}"
) )
break chunk_size = total // num_chunks
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
name = maybe_remap_kv_scale_name(name, params_dict) for j in range(num_chunks):
if name is None: chunk_name = name
continue weight_to_load = loaded_weight
# According to DeepSeek-V3 Technical Report, MTP modules if is_fusion_moe_shared_experts_layer:
# shares embedding layer. We only load the first weights. if split_dim == 0:
if ( weight_to_load = loaded_weight[
spec_layer != self.model.mtp_start_layer_idx j * chunk_size : (j + 1) * chunk_size, :
and ".layers" not in name ]
): else:
continue weight_to_load = loaded_weight[
:, j * chunk_size : (j + 1) * chunk_size
]
# Synthesize an expert-style name so expert mapping
# can route it
chunk_name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts + j}",
)
param = params_dict[name] # Use expert_params_mapping to locate the destination
weight_loader = getattr( # param and delegate to its expert-aware weight_loader
param, "weight_loader", default_weight_loader # with expert_id.
) for mapping in expert_params_mapping:
weight_loader(param, loaded_weight) param_name, weight_name, expert_id, shard_id = mapping
loaded_params.add(name) if weight_name not in chunk_name:
continue
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = chunk_name.replace(weight_name, param_name)
param = params_dict[name_mapped]
# We should ask the weight loader to return success or
# not here since otherwise we may skip experts with
# other available replicas.
weight_loader = typing.cast(
Callable[..., bool], param.weight_loader
)
success = weight_loader(
param,
weight_to_load,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
if not is_fusion_moe_shared_experts_layer:
name = name_mapped
else:
loaded_params.add(name_mapped)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
# According to the DeepSeek-V3 Technical Report, MTP modules
# share the embedding layer. We only load the first weights.
if (
spec_layer != self.model.mtp_start_layer_idx
and ".layers" not in name
):
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
return loaded_params return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:

View File

@ -1479,8 +1479,8 @@ class DeepseekV2ForCausalLM(
if spec_layer is not None: if spec_layer is not None:
continue # skip spec decode layers for main model continue # skip spec decode layers for main model
is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and ( is_fusion_moe_shared_experts_layer = (
"mlp.shared_experts" in name rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
) )
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
@ -1495,7 +1495,7 @@ class DeepseekV2ForCausalLM(
# for mlp.experts[0].gate_gate_up_proj, which breaks load. # for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict: if ("mlp.experts." in name) and name not in params_dict:
continue continue
if is_fuse_shared_experts_layer: if is_fusion_moe_shared_experts_layer:
continue continue
name_mapped = name.replace(weight_name, param_name) name_mapped = name.replace(weight_name, param_name)
@ -1531,7 +1531,7 @@ class DeepseekV2ForCausalLM(
# appended expert slots mlp.experts.{n_routed_experts + j}.* # appended expert slots mlp.experts.{n_routed_experts + j}.*
# accordingly. # accordingly.
num_chunks = 1 num_chunks = 1
if is_fuse_shared_experts_layer: if is_fusion_moe_shared_experts_layer:
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1 num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
# Determine split axis based on op type # Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0 # gate/up: ColumnParallel → split along dim 0
@ -1548,7 +1548,7 @@ class DeepseekV2ForCausalLM(
chunk_name = name chunk_name = name
weight_to_load = loaded_weight weight_to_load = loaded_weight
if is_fuse_shared_experts_layer: if is_fusion_moe_shared_experts_layer:
if split_dim == 0: if split_dim == 0:
weight_to_load = loaded_weight[ weight_to_load = loaded_weight[
j * chunk_size : (j + 1) * chunk_size, : j * chunk_size : (j + 1) * chunk_size, :
@ -1599,7 +1599,7 @@ class DeepseekV2ForCausalLM(
return_success=True, return_success=True,
) )
if success: if success:
if not is_fuse_shared_experts_layer: if not is_fusion_moe_shared_experts_layer:
name = name_mapped name = name_mapped
else: else:
loaded_params.add(name_mapped) loaded_params.add(name_mapped)
@ -1628,7 +1628,7 @@ class DeepseekV2ForCausalLM(
param, "weight_loader", default_weight_loader param, "weight_loader", default_weight_loader
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if not is_fuse_shared_experts_layer: if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params