mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 11:45:59 +08:00
[Bugfix][Rocm] Fix shared expert weight loading failure in DeepSeek-MTP (#27563)
Signed-off-by: zhyajie <yajizhan@amd.com> Co-authored-by: zhyajie <yajizhan@amd.com>
This commit is contained in:
parent
68dfe28eae
commit
8005e606bf
@ -1,15 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
@ -231,6 +233,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
|
||||
return self.model.compute_logits(hidden_states, spec_step_idx)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
rocm_aiter_moe_shared_expert_enabled = (
|
||||
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
|
||||
)
|
||||
stacked_params_mapping = [
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
@ -238,11 +243,16 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
|
||||
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
|
||||
]
|
||||
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts,
|
||||
num_experts=self.config.n_routed_experts
|
||||
+ (
|
||||
self.config.n_shared_experts
|
||||
if rocm_aiter_moe_shared_expert_enabled
|
||||
else 0
|
||||
),
|
||||
)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
@ -253,6 +263,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is None:
|
||||
continue
|
||||
is_fusion_moe_shared_experts_layer = (
|
||||
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
|
||||
)
|
||||
name = self._rewrite_spec_layer_name(spec_layer, name)
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
@ -266,6 +279,8 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if ("mlp.experts." in name) and name not in params_dict:
|
||||
continue
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
continue
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
|
||||
# QKV fusion is optional, fall back to normal
|
||||
@ -286,45 +301,105 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
# Special handling: when AITER fusion_shared_experts is enabled,
|
||||
# checkpoints may provide a single widened shared_experts tensor
|
||||
# without explicit expert indices
|
||||
# (e.g. ...mlp.shared_experts.gate_proj.weight).
|
||||
# For models with multiple shared experts, split that tensor
|
||||
# evenly into per-shared-expert slices and load them into
|
||||
# appended expert slots mlp.experts.{n_routed_experts + j}.*
|
||||
# accordingly.
|
||||
num_chunks = 1
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
|
||||
# Determine split axis based on op type
|
||||
# gate/up: ColumnParallel → split along dim 0
|
||||
# down: RowParallel → split along dim 1
|
||||
split_dim = 1 if "down_proj.weight" in name else 0
|
||||
total = loaded_weight.shape[split_dim]
|
||||
assert total % num_chunks == 0, (
|
||||
f"Shared expert weight dim {total} "
|
||||
f"not divisible by num_chunks {num_chunks}"
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
chunk_size = total // num_chunks
|
||||
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
for j in range(num_chunks):
|
||||
chunk_name = name
|
||||
weight_to_load = loaded_weight
|
||||
|
||||
# According to DeepSeek-V3 Technical Report, MTP modules
|
||||
# shares embedding layer. We only load the first weights.
|
||||
if (
|
||||
spec_layer != self.model.mtp_start_layer_idx
|
||||
and ".layers" not in name
|
||||
):
|
||||
continue
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
if split_dim == 0:
|
||||
weight_to_load = loaded_weight[
|
||||
j * chunk_size : (j + 1) * chunk_size, :
|
||||
]
|
||||
else:
|
||||
weight_to_load = loaded_weight[
|
||||
:, j * chunk_size : (j + 1) * chunk_size
|
||||
]
|
||||
# Synthesize an expert-style name so expert mapping
|
||||
# can route it
|
||||
chunk_name = name.replace(
|
||||
"mlp.shared_experts",
|
||||
f"mlp.experts.{self.config.n_routed_experts + j}",
|
||||
)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
# Use expert_params_mapping to locate the destination
|
||||
# param and delegate to its expert-aware weight_loader
|
||||
# with expert_id.
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in chunk_name:
|
||||
continue
|
||||
|
||||
# Do not modify `name` since the loop may continue here
|
||||
# Instead, create a new variable
|
||||
name_mapped = chunk_name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name_mapped]
|
||||
# We should ask the weight loader to return success or
|
||||
# not here since otherwise we may skip experts with
|
||||
# other available replicas.
|
||||
weight_loader = typing.cast(
|
||||
Callable[..., bool], param.weight_loader
|
||||
)
|
||||
success = weight_loader(
|
||||
param,
|
||||
weight_to_load,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
if not is_fusion_moe_shared_experts_layer:
|
||||
name = name_mapped
|
||||
else:
|
||||
loaded_params.add(name_mapped)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
# According to DeepSeek-V3 Technical Report, MTP modules
|
||||
# shares embedding layer. We only load the first weights.
|
||||
if (
|
||||
spec_layer != self.model.mtp_start_layer_idx
|
||||
and ".layers" not in name
|
||||
):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
if not is_fusion_moe_shared_experts_layer:
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
|
||||
|
||||
@ -1479,8 +1479,8 @@ class DeepseekV2ForCausalLM(
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
|
||||
is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
|
||||
"mlp.shared_experts" in name
|
||||
is_fusion_moe_shared_experts_layer = (
|
||||
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
|
||||
)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
@ -1495,7 +1495,7 @@ class DeepseekV2ForCausalLM(
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if ("mlp.experts." in name) and name not in params_dict:
|
||||
continue
|
||||
if is_fuse_shared_experts_layer:
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
continue
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
|
||||
@ -1531,7 +1531,7 @@ class DeepseekV2ForCausalLM(
|
||||
# appended expert slots mlp.experts.{n_routed_experts + j}.*
|
||||
# accordingly.
|
||||
num_chunks = 1
|
||||
if is_fuse_shared_experts_layer:
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
|
||||
# Determine split axis based on op type
|
||||
# gate/up: ColumnParallel → split along dim 0
|
||||
@ -1548,7 +1548,7 @@ class DeepseekV2ForCausalLM(
|
||||
chunk_name = name
|
||||
weight_to_load = loaded_weight
|
||||
|
||||
if is_fuse_shared_experts_layer:
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
if split_dim == 0:
|
||||
weight_to_load = loaded_weight[
|
||||
j * chunk_size : (j + 1) * chunk_size, :
|
||||
@ -1599,7 +1599,7 @@ class DeepseekV2ForCausalLM(
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
if not is_fuse_shared_experts_layer:
|
||||
if not is_fusion_moe_shared_experts_layer:
|
||||
name = name_mapped
|
||||
else:
|
||||
loaded_params.add(name_mapped)
|
||||
@ -1628,7 +1628,7 @@ class DeepseekV2ForCausalLM(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
if not is_fuse_shared_experts_layer:
|
||||
if not is_fusion_moe_shared_experts_layer:
|
||||
loaded_params.add(name)
|
||||
|
||||
return loaded_params
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user