[ROCm][FEAT] Fuse DeepSeek shared experts into AITER fused_moe ops (#24097)

Signed-off-by: chenjun <junchen2@amd.com>
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Co-authored-by: valarLip <103567126+valarLip@users.noreply.github.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
kliuae 2025-10-16 10:41:34 +08:00 committed by GitHub
parent 0ecc553ee6
commit 1317034379
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 347 additions and 83 deletions

View File

@ -85,7 +85,7 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size):
else:
expected_test_local = base_experts
test_local_experts, test_expert_map = determine_expert_map(
test_local_experts, test_expert_map, _ = determine_expert_map(
ep_size=test_ep_size,
ep_rank=ep_rank,
global_num_experts=test_global_experts,
@ -116,7 +116,7 @@ def test_expert_placement_edge_cases(expert_placement_strategy, world_size):
"""Test edge cases for round_robin expert placement."""
# Test case 1: ep_size = 1 (should return None for expert_map)
local_num_experts, expert_map = determine_expert_map(
local_num_experts, expert_map, _ = determine_expert_map(
ep_size=1,
ep_rank=0,
global_num_experts=8,
@ -217,7 +217,7 @@ def test_determine_expert_map_comprehensive():
expected_local,
expected_map_pattern,
) in test_cases:
local_num_experts, expert_map = determine_expert_map(
local_num_experts, expert_map, _ = determine_expert_map(
ep_size=ep_size,
ep_rank=ep_rank,
global_num_experts=global_num_experts,

View File

@ -217,7 +217,7 @@ def test_moe_permute_unpermute(
expert_map = None
n_local_expert = n_expert
if ep_size != 1:
n_local_expert, expert_map = determine_expert_map(ep_size, ep_rank, n_expert)
n_local_expert, expert_map, _ = determine_expert_map(ep_size, ep_rank, n_expert)
expert_map = expert_map.cuda()
start_expert = n_local_expert * ep_rank
current_platform.seed_everything(0)

View File

@ -113,6 +113,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
@ -914,6 +915,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower()
in ("true", "1")
),
# Whether to use aiter fusion shared experts ops.
# By default is enabled.
"VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower()
in ("true", "1")
),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")

View File

@ -5,6 +5,7 @@ from abc import abstractmethod
from collections.abc import Callable, Iterable
from contextlib import nullcontext
from enum import Enum
from functools import partial
from typing import Literal, get_args, overload
import torch
@ -12,7 +13,7 @@ import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed import (
get_dp_group,
@ -39,6 +40,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
is_rocm_aiter_fusion_shared_expert_enabled,
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
@ -87,7 +90,7 @@ else:
if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk as grouped_topk,
rocm_aiter_grouped_topk as grouped_topk_aiter,
)
else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
@ -634,6 +637,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
num_fused_shared_experts=layer.num_fused_shared_experts,
)
if self.rocm_aiter_moe_enabled:
@ -860,7 +864,8 @@ def determine_expert_map(
ep_rank: int,
global_num_experts: int,
expert_placement_strategy: ExpertPlacementStrategy = "linear",
) -> tuple[int, torch.Tensor | None]:
num_fused_shared_experts: int = 0,
) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
"""
Calculates how many experts should be assigned to each rank for EP and
creates a mapping from global to local expert index. Experts are
@ -882,10 +887,16 @@ def determine_expert_map(
(global_num_experts,) mapping from global to local index.
Contains -1 for experts not assigned to the current rank.
Returns None if ep_size is 1.
- expert_mask (Optional[torch.Tensor]): A tensor of shape
(global_num_experts + num_fused_shared_experts + 1,)
containing 1 for experts assigned to the current rank
and 0 for sentinel.
Returns None if ep_size is 1.
Used only when AITER MOE is enabled.
"""
assert ep_size > 0
if ep_size == 1:
return (global_num_experts, None)
return (global_num_experts, None, None)
# Distribute experts as evenly as possible to each rank.
base_experts = global_num_experts // ep_size
@ -914,7 +925,26 @@ def determine_expert_map(
f"'{expert_placement_strategy}', expected one of "
f"{get_args(ExpertPlacementStrategy)}"
)
return (local_num_experts, expert_map)
expert_mask = None
if is_rocm_aiter_moe_enabled():
expert_mask = torch.ones(
(global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32
)
expert_mask[-1] = 0
expert_mask[:global_num_experts] = expert_map > -1
expert_map = torch.cat(
(
expert_map,
torch.tensor(
[local_num_experts + i for i in range(num_fused_shared_experts)],
dtype=torch.int32,
),
),
dim=0,
)
return (local_num_experts, expert_map, expert_mask)
def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
@ -1040,6 +1070,7 @@ class FusedMoE(CustomOp):
zero_expert_num: int | None = 0,
zero_expert_type: str | None = None,
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
):
super().__init__()
if params_dtype is None:
@ -1096,6 +1127,22 @@ class FusedMoE(CustomOp):
self.logical_to_physical_map: torch.Tensor | None = None
self.logical_replica_count: torch.Tensor | None = None
# ROCm aiter shared experts fusion
self.num_fused_shared_experts = (
n_shared_experts
if n_shared_experts is not None
and is_rocm_aiter_fusion_shared_expert_enabled()
else 0
)
if (
not is_rocm_aiter_fusion_shared_expert_enabled()
and self.num_fused_shared_experts != 0
):
raise ValueError(
"n_shared_experts is only supported on ROCm aiter when "
"VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled"
)
# Determine expert maps
if self.use_ep:
if self.enable_eplb:
@ -1129,14 +1176,16 @@ class FusedMoE(CustomOp):
expert_placement_strategy = "linear"
self.expert_map: torch.Tensor | None
local_num_experts, expert_map = determine_expert_map(
local_num_experts, expert_map, expert_mask = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts,
expert_placement_strategy=expert_placement_strategy,
num_fused_shared_experts=self.num_fused_shared_experts,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
logger.info_once(
"[EP Rank %s/%s] Expert parallelism is enabled. Expert "
"placement strategy: %s. Local/global"
@ -1150,10 +1199,18 @@ class FusedMoE(CustomOp):
get_compressed_expert_map(self.expert_map),
)
else:
self.local_num_experts, self.expert_map = (self.global_num_experts, None)
self.local_num_experts, self.expert_map, self.expert_mask = (
self.global_num_experts,
None,
None,
)
self.top_k = top_k
self._init_aiter_shared_experts_topK_buffer(
vllm_config=vllm_config, dp_size=dp_size_
)
assert intermediate_size % self.tp_size == 0
self.hidden_size = hidden_size
self.intermediate_size_per_partition = intermediate_size // self.tp_size
@ -1327,13 +1384,18 @@ class FusedMoE(CustomOp):
# ep_size and ep_rank should already be updated
assert self.expert_map is not None
with self.expert_map.device:
local_num_experts, expert_map = determine_expert_map(
local_num_experts, expert_map, expert_mask = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts,
num_fused_shared_experts=self.num_fused_shared_experts,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
self._init_aiter_shared_experts_topK_buffer(
vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size
)
def _load_per_tensor_weight_scale(
self,
@ -1504,6 +1566,24 @@ class FusedMoE(CustomOp):
return expert_id
return self.expert_map[expert_id].item()
def _init_aiter_shared_experts_topK_buffer(
self, vllm_config: VllmConfig, dp_size: int
):
if is_rocm_aiter_fusion_shared_expert_enabled():
if self.num_fused_shared_experts > 0:
init_aiter_topK_meta_data(
n_routed_experts=self.global_num_experts,
n_shared_experts=self.num_fused_shared_experts,
top_k=self.top_k,
tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
tp_size=self.ep_size if self.use_ep else self.tp_size,
shared_experts_score=1.0,
max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
* dp_size,
is_EP=self.use_ep,
)
self.local_num_experts += self.num_fused_shared_experts
@overload
def weight_loader(
self,
@ -1866,6 +1946,7 @@ class FusedMoE(CustomOp):
global_num_experts: int | None = None,
zero_expert_num: int | None = None,
zero_expert_type: str | None = None,
num_fused_shared_experts: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
@ -1900,7 +1981,16 @@ class FusedMoE(CustomOp):
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk(
if is_rocm_aiter_moe_enabled():
if not is_rocm_aiter_fusion_shared_expert_enabled():
assert num_fused_shared_experts == 0
grouped_topk_impl = partial(
grouped_topk_aiter,
num_fused_shared_experts=num_fused_shared_experts,
)
else:
grouped_topk_impl = grouped_topk
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
@ -2119,7 +2209,9 @@ class FusedMoE(CustomOp):
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
expert_map=self.expert_map
if not is_rocm_aiter_moe_enabled()
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
@ -2244,7 +2336,9 @@ class FusedMoE(CustomOp):
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
expert_map=self.expert_map
if not is_rocm_aiter_moe_enabled()
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import IntEnum
from functools import cache
from functools import cache, lru_cache
import torch
@ -46,6 +46,69 @@ def is_rocm_aiter_moe_enabled() -> bool:
)
@cache
def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
return (
envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled()
)
aiter_topK_meta_data = None
@lru_cache(maxsize=1)
def init_aiter_topK_meta_data(
n_routed_experts: int,
n_shared_experts: int,
top_k: int,
tp_rank: int,
tp_size: int,
shared_experts_score: float = 1.0,
max_num_tokens: int = 32768,
is_EP: bool = False,
):
global aiter_topK_meta_data
fake_expertid = n_routed_experts + n_shared_experts
# all layers reuse same buffer
# This extra element when EP is enabled is used as a sentinel
# to mask out shared expert processing for tokens not owned by
# the current EP rank. This is necessary to avoid double-processing
# of shared experts.
total_topk_ids = torch.empty(
(max_num_tokens, top_k + n_shared_experts + is_EP),
dtype=torch.int32,
device="cuda",
)
ns_topk_ids, s_topk_ids = total_topk_ids.split(
[top_k, n_shared_experts + is_EP], dim=1
)
shared_expert_ids = [n_routed_experts + i for i in range(n_shared_experts + is_EP)]
if is_EP:
s_topk_ids_list = [
[fake_expertid] * (n_shared_experts + is_EP)
] * max_num_tokens
for i in range(tp_rank, max_num_tokens, tp_size):
s_topk_ids_list[i] = shared_expert_ids
else:
s_topk_ids_list = [
list(range(n_routed_experts, fake_expertid))
] * max_num_tokens
s_topk_ids[:] = torch.tensor(s_topk_ids_list, dtype=torch.int32, device="cuda")
total_topk_weights = torch.empty(
(max_num_tokens, top_k + n_shared_experts + is_EP),
dtype=torch.float32,
device="cuda",
)
ns_topk_weights, s_topk_weights = total_topk_weights.split(
[top_k, n_shared_experts + is_EP], dim=1
)
s_topk_weights.fill_(shared_experts_score)
assert aiter_topK_meta_data is None, "AITER topK meta data is already initialized"
aiter_topK_meta_data = (total_topk_weights, total_topk_ids)
def rocm_aiter_asm_moe_tkw1_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@ -300,11 +363,33 @@ def rocm_aiter_grouped_topk(
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
num_fused_shared_experts: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
token = hidden_states.shape[0]
device = hidden_states.device
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
assert aiter_topK_meta_data is not None, (
"AITER topK meta data is not initialized. "
"Please ensure that init_aiter_topK_meta_data "
"is called before this function."
)
total_topk_weights, total_topk_ids = aiter_topK_meta_data
assert total_topk_weights.shape[0] >= token, (
f"AITER topK meta data support {total_topk_weights.shape[0]} "
f"tokens which is determined by max_num_batched_tokens, "
f"but got {token} tokens now."
)
total_topk_weights = total_topk_weights[:token]
total_topk_ids = total_topk_ids[:token]
topk_weights, _ = total_topk_weights.split(
[topk, total_topk_weights.shape[1] - topk], dim=1
)
topk_ids, _ = total_topk_ids.split(
[topk, total_topk_ids.shape[1] - topk], dim=1
)
else:
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
if e_score_correction_bias is not None:
torch.ops.vllm.rocm_aiter_biased_grouped_topk(
@ -315,6 +400,7 @@ def rocm_aiter_grouped_topk(
num_expert_group,
topk_group,
renormalize,
routed_scaling_factor=routed_scaling_factor,
)
else:
assert scoring_func == "softmax" or scoring_func == "sigmoid"
@ -326,10 +412,11 @@ def rocm_aiter_grouped_topk(
topk_group,
renormalize,
scoring_func,
routed_scaling_factor=routed_scaling_factor,
)
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
return total_topk_weights, total_topk_ids
return topk_weights, topk_ids
@ -354,7 +441,7 @@ def rocm_aiter_fused_experts(
topk_weights = topk_weights.to(torch.float32)
topk_ids = topk_ids.to(torch.int32)
expert_mask = (expert_map > -1).to(torch.int32) if expert_map is not None else None
expert_mask = expert_map if expert_map is not None else None
# w8a8 per-channel quantization
if (

View File

@ -1056,6 +1056,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
num_fused_shared_experts=layer.num_fused_shared_experts,
)
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN

View File

@ -1169,6 +1169,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
num_fused_shared_experts=layer.num_fused_shared_experts,
)
#

View File

@ -50,6 +50,10 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_fusion_shared_expert_enabled,
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@ -203,7 +207,10 @@ class DeepseekV2MoE(nn.Module):
self.physical_expert_start + self.n_local_physical_experts
)
if config.n_shared_experts is None:
if (
config.n_shared_experts is None
or is_rocm_aiter_fusion_shared_expert_enabled()
):
self.shared_experts = None
else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@ -233,11 +240,17 @@ class DeepseekV2MoE(nn.Module):
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
# aiter applies routed_scaling_factor internally
routed_scaling_factor=1.0
if not is_rocm_aiter_moe_enabled()
else self.routed_scaling_factor,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
n_shared_experts=config.n_shared_experts
if is_rocm_aiter_fusion_shared_expert_enabled()
else None,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -258,16 +271,15 @@ class DeepseekV2MoE(nn.Module):
hidden_states=hidden_states, router_logits=router_logits
)
if self.shared_experts is not None:
shared_output, final_hidden_states = fused_moe_out
else:
shared_output = None
final_hidden_states = fused_moe_out
shared_output, final_hidden_states = fused_moe_out
if self.shared_experts is None:
assert shared_output is None
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
final_hidden_states *= self.routed_scaling_factor
if not is_rocm_aiter_moe_enabled():
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
@ -1316,7 +1328,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_experts=self.config.n_routed_experts
+ (
self.config.n_shared_experts
if is_rocm_aiter_fusion_shared_expert_enabled()
else 0
),
num_redundant_experts=self.num_redundant_experts,
)
@ -1330,6 +1347,11 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
if spec_layer is not None:
continue # skip spec decode layers for main model
is_fuse_shared_experts_layer = (
is_rocm_aiter_fusion_shared_expert_enabled()
and ("mlp.shared_experts" in name)
)
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
@ -1342,6 +1364,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
if is_fuse_shared_experts_layer:
continue
name_mapped = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal
@ -1366,65 +1390,115 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
break
else:
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or not
# here since otherwise we may skip experts with other
# available replicas.
weight_loader = typing.cast(
Callable[..., bool], param.weight_loader
# Special handling: when AITER fusion_shared_experts is enabled,
# checkpoints may provide a single widened shared_experts tensor
# without explicit expert indices
# (e.g. ...mlp.shared_experts.gate_proj.weight).
# For models with multiple shared experts, split that tensor
# evenly into per-shared-expert slices and load them into
# appended expert slots mlp.experts.{n_routed_experts + j}.*
# accordingly.
num_chunks = 1
if is_fuse_shared_experts_layer:
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
# Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0
# down: RowParallel → split along dim 1
split_dim = 1 if "down_proj.weight" in name else 0
total = loaded_weight.shape[split_dim]
assert total % num_chunks == 0, (
f"Shared expert weight dim {total} "
f"not divisible by num_chunks {num_chunks}"
)
success = weight_loader(
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
name = name_mapped
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
chunk_size = total // num_chunks
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
for j in range(num_chunks):
chunk_name = name
weight_to_load = loaded_weight
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_fuse_shared_experts_layer:
if split_dim == 0:
weight_to_load = loaded_weight[
j * chunk_size : (j + 1) * chunk_size, :
]
else:
weight_to_load = loaded_weight[
:, j * chunk_size : (j + 1) * chunk_size
]
# Synthesize an expert-style name so expert mapping
# can route it
chunk_name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts + j}",
)
if is_pp_missing_parameter(name, self):
continue
# Use expert_params_mapping to locate the destination
# param and delegate to its expert-aware weight_loader
# with expert_id.
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in chunk_name:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = chunk_name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or
# not here since otherwise we may skip experts with
# other available replicas.
weight_loader = typing.cast(
Callable[..., bool], param.weight_loader
)
success = weight_loader(
param,
weight_to_load,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
if not is_fuse_shared_experts_layer:
name = name_mapped
else:
loaded_params.add(name_mapped)
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
if not is_fuse_shared_experts_layer:
loaded_params.add(name)
return loaded_params