[ROCm][FEAT] Fuse DeepSeek shared experts into AITER fused_moe ops (#24097)

Signed-off-by: chenjun <junchen2@amd.com>
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Co-authored-by: valarLip <103567126+valarLip@users.noreply.github.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
kliuae committed on 2025-10-16 10:41:34 +08:00 (committed by GitHub)
parent 0ecc553ee6
commit 1317034379
8 changed files with 347 additions and 83 deletions

View File

@@ -85,7 +85,7 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size):
         else:
             expected_test_local = base_experts
-        test_local_experts, test_expert_map = determine_expert_map(
+        test_local_experts, test_expert_map, _ = determine_expert_map(
             ep_size=test_ep_size,
             ep_rank=ep_rank,
             global_num_experts=test_global_experts,
@@ -116,7 +116,7 @@ def test_expert_placement_edge_cases(expert_placement_strategy, world_size):
    """Test edge cases for round_robin expert placement."""

    # Test case 1: ep_size = 1 (should return None for expert_map)
-    local_num_experts, expert_map = determine_expert_map(
+    local_num_experts, expert_map, _ = determine_expert_map(
        ep_size=1,
        ep_rank=0,
        global_num_experts=8,
@@ -217,7 +217,7 @@ def test_determine_expert_map_comprehensive():
        expected_local,
        expected_map_pattern,
    ) in test_cases:
-        local_num_experts, expert_map = determine_expert_map(
+        local_num_experts, expert_map, _ = determine_expert_map(
            ep_size=ep_size,
            ep_rank=ep_rank,
            global_num_experts=global_num_experts,

View File

@@ -217,7 +217,7 @@ def test_moe_permute_unpermute(
    expert_map = None
    n_local_expert = n_expert
    if ep_size != 1:
-        n_local_expert, expert_map = determine_expert_map(ep_size, ep_rank, n_expert)
+        n_local_expert, expert_map, _ = determine_expert_map(ep_size, ep_rank, n_expert)
        expert_map = expert_map.cuda()
    start_expert = n_local_expert * ep_rank
    current_platform.seed_everything(0)

View File

@@ -113,6 +113,7 @@ if TYPE_CHECKING:
    VLLM_ROCM_USE_TRITON_ROPE: bool = False
    VLLM_ROCM_USE_AITER_FP8BMM: bool = True
    VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
+    VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True
    VLLM_ROCM_USE_SKINNY_GEMM: bool = True
    VLLM_ROCM_FP8_PADDING: bool = True
    VLLM_ROCM_MOE_PADDING: bool = True
@@ -914,6 +915,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
        os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower()
        in ("true", "1")
    ),
+    # Whether to use aiter fusion shared experts ops.
+    # Enabled by default.
+    "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS": lambda: (
+        os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower()
+        in ("true", "1")
+    ),
    # use rocm skinny gemms
    "VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
        os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")

View File

@@ -5,6 +5,7 @@ from abc import abstractmethod
 from collections.abc import Callable, Iterable
 from contextlib import nullcontext
 from enum import Enum
+from functools import partial
 from typing import Literal, get_args, overload

 import torch
@@ -12,7 +13,7 @@ import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter

 import vllm.envs as envs
-from vllm.config import get_current_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.config.parallel import ExpertPlacementStrategy
 from vllm.distributed import (
     get_dp_group,
@@ -39,6 +40,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEPrepareAndFinalize,
 )
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    init_aiter_topK_meta_data,
+    is_rocm_aiter_fusion_shared_expert_enabled,
     is_rocm_aiter_moe_enabled,
 )
 from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
@@ -87,7 +90,7 @@ else:
 if is_rocm_aiter_moe_enabled():
     from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
-        rocm_aiter_grouped_topk as grouped_topk,
+        rocm_aiter_grouped_topk as grouped_topk_aiter,
     )
 else:
     from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
@@ -634,6 +637,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             global_num_experts=global_num_experts,
             zero_expert_num=zero_expert_num,
             zero_expert_type=zero_expert_type,
+            num_fused_shared_experts=layer.num_fused_shared_experts,
         )

         if self.rocm_aiter_moe_enabled:
@@ -860,7 +864,8 @@ def determine_expert_map(
     ep_rank: int,
     global_num_experts: int,
     expert_placement_strategy: ExpertPlacementStrategy = "linear",
-) -> tuple[int, torch.Tensor | None]:
+    num_fused_shared_experts: int = 0,
+) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
    """
    Calculates how many experts should be assigned to each rank for EP and
    creates a mapping from global to local expert index. Experts are
@@ -882,10 +887,16 @@ def determine_expert_map(
            (global_num_experts,) mapping from global to local index.
            Contains -1 for experts not assigned to the current rank.
            Returns None if ep_size is 1.
+        - expert_mask (Optional[torch.Tensor]): A tensor of shape
+            (global_num_experts + num_fused_shared_experts + 1,)
+            containing 1 for experts assigned to the current rank
+            and 0 for the sentinel entry.
+            Returns None if ep_size is 1.
+            Used only when AITER MoE is enabled.
    """
    assert ep_size > 0
    if ep_size == 1:
-        return (global_num_experts, None)
+        return (global_num_experts, None, None)

    # Distribute experts as evenly as possible to each rank.
    base_experts = global_num_experts // ep_size
@@ -914,7 +925,26 @@ def determine_expert_map(
            f"'{expert_placement_strategy}', expected one of "
            f"{get_args(ExpertPlacementStrategy)}"
        )
-    return (local_num_experts, expert_map)
+
+    expert_mask = None
+    if is_rocm_aiter_moe_enabled():
+        expert_mask = torch.ones(
+            (global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32
+        )
+        expert_mask[-1] = 0
+        expert_mask[:global_num_experts] = expert_map > -1
+        expert_map = torch.cat(
+            (
+                expert_map,
+                torch.tensor(
+                    [local_num_experts + i for i in range(num_fused_shared_experts)],
+                    dtype=torch.int32,
+                ),
+            ),
+            dim=0,
+        )
+
+    return (local_num_experts, expert_map, expert_mask)


 def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
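To make the new return values concrete, a small worked trace of the code above (hypothetical sizes, shown as plain lists rather than tensors): ep_size=2, ep_rank=0, global_num_experts=8, num_fused_shared_experts=1, linear placement, AITER MoE enabled.

    # local_num_experts: rank 0 owns routed experts 0-3
    local_num_experts = 4
    # expert_map gains one appended entry that maps the fused shared expert
    # (global id 8) onto local slot 4 of this rank
    expert_map = [0, 1, 2, 3, -1, -1, -1, -1, 4]
    # expert_mask has global_num_experts + num_fused_shared_experts + 1 entries;
    # the shared-expert slot (index 8) stays 1 on every rank, and the trailing 0
    # is the sentinel slot used to drop tokens routed to the fake expert id
    expert_mask = [1, 1, 1, 1, 0, 0, 0, 0, 1, 0]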
@@ -1040,6 +1070,7 @@ class FusedMoE(CustomOp):
         zero_expert_num: int | None = 0,
         zero_expert_type: str | None = None,
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
+        n_shared_experts: int | None = None,
     ):
         super().__init__()
         if params_dtype is None:
@@ -1096,6 +1127,22 @@ class FusedMoE(CustomOp):
         self.logical_to_physical_map: torch.Tensor | None = None
         self.logical_replica_count: torch.Tensor | None = None

+        # ROCm aiter shared experts fusion
+        self.num_fused_shared_experts = (
+            n_shared_experts
+            if n_shared_experts is not None
+            and is_rocm_aiter_fusion_shared_expert_enabled()
+            else 0
+        )
+        if (
+            not is_rocm_aiter_fusion_shared_expert_enabled()
+            and self.num_fused_shared_experts != 0
+        ):
+            raise ValueError(
+                "n_shared_experts is only supported on ROCm aiter when "
+                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled"
+            )
+
         # Determine expert maps
         if self.use_ep:
             if self.enable_eplb:
@@ -1129,14 +1176,16 @@ class FusedMoE(CustomOp):
                 expert_placement_strategy = "linear"

             self.expert_map: torch.Tensor | None
-            local_num_experts, expert_map = determine_expert_map(
+            local_num_experts, expert_map, expert_mask = determine_expert_map(
                 ep_size=self.ep_size,
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts,
                 expert_placement_strategy=expert_placement_strategy,
+                num_fused_shared_experts=self.num_fused_shared_experts,
             )
             self.local_num_experts = local_num_experts
             self.register_buffer("expert_map", expert_map)
+            self.register_buffer("expert_mask", expert_mask)
             logger.info_once(
                 "[EP Rank %s/%s] Expert parallelism is enabled. Expert "
                 "placement strategy: %s. Local/global"
@@ -1150,10 +1199,18 @@ class FusedMoE(CustomOp):
                 get_compressed_expert_map(self.expert_map),
             )
         else:
-            self.local_num_experts, self.expert_map = (self.global_num_experts, None)
+            self.local_num_experts, self.expert_map, self.expert_mask = (
+                self.global_num_experts,
+                None,
+                None,
+            )

         self.top_k = top_k

+        self._init_aiter_shared_experts_topK_buffer(
+            vllm_config=vllm_config, dp_size=dp_size_
+        )
+
         assert intermediate_size % self.tp_size == 0
         self.hidden_size = hidden_size
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
@@ -1327,13 +1384,18 @@ class FusedMoE(CustomOp):
         # ep_size and ep_rank should already be updated
         assert self.expert_map is not None
         with self.expert_map.device:
-            local_num_experts, expert_map = determine_expert_map(
+            local_num_experts, expert_map, expert_mask = determine_expert_map(
                 ep_size=self.ep_size,
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts,
+                num_fused_shared_experts=self.num_fused_shared_experts,
             )
             self.local_num_experts = local_num_experts
             self.register_buffer("expert_map", expert_map)
+            self.register_buffer("expert_mask", expert_mask)
+            self._init_aiter_shared_experts_topK_buffer(
+                vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size
+            )

     def _load_per_tensor_weight_scale(
         self,
@@ -1504,6 +1566,24 @@ class FusedMoE(CustomOp):
             return expert_id
         return self.expert_map[expert_id].item()

+    def _init_aiter_shared_experts_topK_buffer(
+        self, vllm_config: VllmConfig, dp_size: int
+    ):
+        if is_rocm_aiter_fusion_shared_expert_enabled():
+            if self.num_fused_shared_experts > 0:
+                init_aiter_topK_meta_data(
+                    n_routed_experts=self.global_num_experts,
+                    n_shared_experts=self.num_fused_shared_experts,
+                    top_k=self.top_k,
+                    tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
+                    tp_size=self.ep_size if self.use_ep else self.tp_size,
+                    shared_experts_score=1.0,
+                    max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
+                    * dp_size,
+                    is_EP=self.use_ep,
+                )
+            self.local_num_experts += self.num_fused_shared_experts
+
     @overload
     def weight_loader(
         self,
@@ -1866,6 +1946,7 @@ class FusedMoE(CustomOp):
         global_num_experts: int | None = None,
         zero_expert_num: int | None = None,
         zero_expert_type: str | None = None,
+        num_fused_shared_experts: int = 0,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Route the input hidden states to the top-k experts based on the
@@ -1900,7 +1981,16 @@ class FusedMoE(CustomOp):
         if use_grouped_topk:
             assert topk_group is not None
             assert num_expert_group is not None
-            topk_weights, topk_ids = grouped_topk(
+            if is_rocm_aiter_moe_enabled():
+                if not is_rocm_aiter_fusion_shared_expert_enabled():
+                    assert num_fused_shared_experts == 0
+                grouped_topk_impl = partial(
+                    grouped_topk_aiter,
+                    num_fused_shared_experts=num_fused_shared_experts,
+                )
+            else:
+                grouped_topk_impl = grouped_topk
+            topk_weights, topk_ids = grouped_topk_impl(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
                 topk=top_k,
@@ -2119,7 +2209,9 @@ class FusedMoE(CustomOp):
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
             global_num_experts=self.global_num_experts,
-            expert_map=self.expert_map,
+            expert_map=self.expert_map
+            if not is_rocm_aiter_moe_enabled()
+            else self.expert_mask,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
@@ -2244,7 +2336,9 @@ class FusedMoE(CustomOp):
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
             global_num_experts=self.global_num_experts,
-            expert_map=self.expert_map,
+            expert_map=self.expert_map
+            if not is_rocm_aiter_moe_enabled()
+            else self.expert_mask,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from enum import IntEnum
-from functools import cache
+from functools import cache, lru_cache

 import torch
@@ -46,6 +46,69 @@ def is_rocm_aiter_moe_enabled() -> bool:
     )


+@cache
+def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
+    return (
+        envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled()
+    )
+
+
+aiter_topK_meta_data = None
+
+
+@lru_cache(maxsize=1)
+def init_aiter_topK_meta_data(
+    n_routed_experts: int,
+    n_shared_experts: int,
+    top_k: int,
+    tp_rank: int,
+    tp_size: int,
+    shared_experts_score: float = 1.0,
+    max_num_tokens: int = 32768,
+    is_EP: bool = False,
+):
+    global aiter_topK_meta_data
+    fake_expertid = n_routed_experts + n_shared_experts
+
+    # All layers reuse the same buffer.
+    # This extra element when EP is enabled is used as a sentinel
+    # to mask out shared expert processing for tokens not owned by
+    # the current EP rank. This is necessary to avoid double-processing
+    # of shared experts.
+    total_topk_ids = torch.empty(
+        (max_num_tokens, top_k + n_shared_experts + is_EP),
+        dtype=torch.int32,
+        device="cuda",
+    )
+    ns_topk_ids, s_topk_ids = total_topk_ids.split(
+        [top_k, n_shared_experts + is_EP], dim=1
+    )
+    shared_expert_ids = [n_routed_experts + i for i in range(n_shared_experts + is_EP)]
+    if is_EP:
+        s_topk_ids_list = [
+            [fake_expertid] * (n_shared_experts + is_EP)
+        ] * max_num_tokens
+        for i in range(tp_rank, max_num_tokens, tp_size):
+            s_topk_ids_list[i] = shared_expert_ids
+    else:
+        s_topk_ids_list = [
+            list(range(n_routed_experts, fake_expertid))
+        ] * max_num_tokens
+    s_topk_ids[:] = torch.tensor(s_topk_ids_list, dtype=torch.int32, device="cuda")
+
+    total_topk_weights = torch.empty(
+        (max_num_tokens, top_k + n_shared_experts + is_EP),
+        dtype=torch.float32,
+        device="cuda",
+    )
+    ns_topk_weights, s_topk_weights = total_topk_weights.split(
+        [top_k, n_shared_experts + is_EP], dim=1
+    )
+    s_topk_weights.fill_(shared_experts_score)
+    assert aiter_topK_meta_data is None, "AITER topK meta data is already initialized"
+    aiter_topK_meta_data = (total_topk_weights, total_topk_ids)
+
+
 def rocm_aiter_asm_moe_tkw1_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
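A runnable, self-contained re-trace (not the vLLM API) of the shared-expert columns that the buffer above prefills, assuming n_routed_experts=8, n_shared_experts=1, top_k=2, tp_rank=0, tp_size=2, max_num_tokens=4 and EP enabled:

    import torch

    n_routed, n_shared, top_k = 8, 1, 2
    tp_rank, tp_size, max_tokens, is_ep = 0, 2, 4, True

    width = top_k + n_shared + int(is_ep)          # 4 columns per token row
    ids = torch.empty((max_tokens, width), dtype=torch.int32)
    shared_cols = ids[:, top_k:]                   # columns aiter never overwrites
    # Default every row to the sentinel/fake expert id (9), which expert_mask
    # marks inactive, then give this rank's strided tokens the real shared id 8.
    shared_cols[:] = n_routed + n_shared
    shared_cols[tp_rank::tp_size] = torch.arange(
        n_routed, n_routed + n_shared + int(is_ep), dtype=torch.int32
    )
    print(shared_cols.tolist())  # [[8, 9], [9, 9], [8, 9], [9, 9]]

The routed columns (0 and 1 here) and the matching weight rows are what the grouped topk fills per forward step; the weight columns for the shared slots are preloaded once with shared_experts_score.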
@@ -300,9 +363,31 @@ def rocm_aiter_grouped_topk(
     scoring_func: str = "softmax",
     routed_scaling_factor: float = 1.0,
     e_score_correction_bias: torch.Tensor | None = None,
+    num_fused_shared_experts: int = 0,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     token = hidden_states.shape[0]
     device = hidden_states.device

+    if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
+        assert aiter_topK_meta_data is not None, (
+            "AITER topK meta data is not initialized. "
+            "Please ensure that init_aiter_topK_meta_data "
+            "is called before this function."
+        )
+        total_topk_weights, total_topk_ids = aiter_topK_meta_data
+        assert total_topk_weights.shape[0] >= token, (
+            f"AITER topK meta data supports {total_topk_weights.shape[0]} "
+            f"tokens, which is determined by max_num_batched_tokens, "
+            f"but got {token} tokens now."
+        )
+        total_topk_weights = total_topk_weights[:token]
+        total_topk_ids = total_topk_ids[:token]
+        topk_weights, _ = total_topk_weights.split(
+            [topk, total_topk_weights.shape[1] - topk], dim=1
+        )
+        topk_ids, _ = total_topk_ids.split(
+            [topk, total_topk_ids.shape[1] - topk], dim=1
+        )
+    else:
-    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
-    topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
+        topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+        topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
@@ -315,6 +400,7 @@ def rocm_aiter_grouped_topk(
            num_expert_group,
            topk_group,
            renormalize,
+            routed_scaling_factor=routed_scaling_factor,
        )
    else:
        assert scoring_func == "softmax" or scoring_func == "sigmoid"
@@ -326,10 +412,11 @@ def rocm_aiter_grouped_topk(
            topk_group,
            renormalize,
            scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
        )

-    if routed_scaling_factor != 1.0:
-        topk_weights = topk_weights * routed_scaling_factor
+    if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
+        return total_topk_weights, total_topk_ids

    return topk_weights, topk_ids
@@ -354,7 +441,7 @@ def rocm_aiter_fused_experts(
    topk_weights = topk_weights.to(torch.float32)
    topk_ids = topk_ids.to(torch.int32)

-    expert_mask = (expert_map > -1).to(torch.int32) if expert_map is not None else None
+    expert_mask = expert_map if expert_map is not None else None

    # w8a8 per-channel quantization
    if (
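Taken together with the layer changes above, a brief hedged sketch of the shapes seen on the fusion path (illustrative numbers only, not output of this code):

    # Hypothetical decode step with topk=2, one fused shared expert, EP enabled.
    num_tokens, top_k, n_shared, is_ep = 16, 2, 1, True

    # rocm_aiter_grouped_topk(..., num_fused_shared_experts=n_shared) returns
    # views into the preallocated buffer, so both routing tensors are widened:
    routing_shape = (num_tokens, top_k + n_shared + int(is_ep))   # (16, 4), not (16, 2)

    # rocm_aiter_fused_experts now receives the 0/1 expert_mask produced by
    # determine_expert_map directly instead of recomputing (expert_map > -1).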

View File

@@ -1056,6 +1056,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
+            num_fused_shared_experts=layer.num_fused_shared_experts,
        )

        per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN

View File

@@ -1169,6 +1169,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
+            num_fused_shared_experts=layer.num_fused_shared_experts,
        )

        #

View File

@@ -50,6 +50,10 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_fusion_shared_expert_enabled,
+    is_rocm_aiter_moe_enabled,
+)
 from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -203,7 +207,10 @@ class DeepseekV2MoE(nn.Module):
            self.physical_expert_start + self.n_local_physical_experts
        )

-        if config.n_shared_experts is None:
+        if (
+            config.n_shared_experts is None
+            or is_rocm_aiter_fusion_shared_expert_enabled()
+        ):
            self.shared_experts = None
        else:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -233,11 +240,17 @@ class DeepseekV2MoE(nn.Module):
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            # we do scaling outside, set factor to 1.0 to avoid double mul
-            routed_scaling_factor=1.0,
+            # aiter applies routed_scaling_factor internally
+            routed_scaling_factor=1.0
+            if not is_rocm_aiter_moe_enabled()
+            else self.routed_scaling_factor,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
+            n_shared_experts=config.n_shared_experts
+            if is_rocm_aiter_fusion_shared_expert_enabled()
+            else None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -258,15 +271,14 @@ class DeepseekV2MoE(nn.Module):
            hidden_states=hidden_states, router_logits=router_logits
        )

-        if self.shared_experts is not None:
-            shared_output, final_hidden_states = fused_moe_out
-        else:
-            shared_output = None
-            final_hidden_states = fused_moe_out
+        shared_output, final_hidden_states = fused_moe_out
+        if self.shared_experts is None:
+            assert shared_output is None

        # Fix FP16 overflow
        # See DeepseekV2DecoderLayer for more details.
        if hidden_states.dtype != torch.float16:
-            final_hidden_states *= self.routed_scaling_factor
+            if not is_rocm_aiter_moe_enabled():
+                final_hidden_states *= self.routed_scaling_factor
        elif self.shared_experts is not None:
            assert shared_output is not None
@@ -1316,7 +1328,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA):
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts
+            + (
+                self.config.n_shared_experts
+                if is_rocm_aiter_fusion_shared_expert_enabled()
+                else 0
+            ),
            num_redundant_experts=self.num_redundant_experts,
        )
@@ -1330,6 +1347,11 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA):
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

+            is_fuse_shared_experts_layer = (
+                is_rocm_aiter_fusion_shared_expert_enabled()
+                and ("mlp.shared_experts" in name)
+            )
+
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
@@ -1342,6 +1364,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA):
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
+                if is_fuse_shared_experts_layer:
+                    continue
                name_mapped = name.replace(weight_name, param_name)

                # QKV fusion is optional, fall back to normal
@@ -1366,9 +1390,55 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA):
                    break
            else:
                is_expert_weight = False

+                # Special handling: when AITER fusion_shared_experts is enabled,
+                # checkpoints may provide a single widened shared_experts tensor
+                # without explicit expert indices
+                # (e.g. ...mlp.shared_experts.gate_proj.weight).
+                # For models with multiple shared experts, split that tensor
+                # evenly into per-shared-expert slices and load them into
+                # appended expert slots mlp.experts.{n_routed_experts + j}.*
+                # accordingly.
+                num_chunks = 1
+                if is_fuse_shared_experts_layer:
+                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+                    # Determine split axis based on op type
+                    # gate/up: ColumnParallel → split along dim 0
+                    # down: RowParallel → split along dim 1
+                    split_dim = 1 if "down_proj.weight" in name else 0
+                    total = loaded_weight.shape[split_dim]
+                    assert total % num_chunks == 0, (
+                        f"Shared expert weight dim {total} "
+                        f"not divisible by num_chunks {num_chunks}"
+                    )
+                    chunk_size = total // num_chunks
+
+                for j in range(num_chunks):
+                    chunk_name = name
+                    weight_to_load = loaded_weight
+
+                    if is_fuse_shared_experts_layer:
+                        if split_dim == 0:
+                            weight_to_load = loaded_weight[
+                                j * chunk_size : (j + 1) * chunk_size, :
+                            ]
+                        else:
+                            weight_to_load = loaded_weight[
+                                :, j * chunk_size : (j + 1) * chunk_size
+                            ]
+                        # Synthesize an expert-style name so expert mapping
+                        # can route it
+                        chunk_name = name.replace(
+                            "mlp.shared_experts",
+                            f"mlp.experts.{self.config.n_routed_experts + j}",
+                        )
+
+                    # Use expert_params_mapping to locate the destination
+                    # param and delegate to its expert-aware weight_loader
+                    # with expert_id.
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if weight_name not in chunk_name:
+                            continue

                        # Anyway, this is an expert weight and should not be
@@ -1377,28 +1447,31 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA):
                        # Do not modify `name` since the loop may continue here
                        # Instead, create a new variable
-                        name_mapped = name.replace(weight_name, param_name)
+                        name_mapped = chunk_name.replace(weight_name, param_name)

                        if is_pp_missing_parameter(name_mapped, self):
                            continue

                        param = params_dict[name_mapped]
-                        # We should ask the weight loader to return success or not
-                        # here since otherwise we may skip experts with other
-                        # available replicas.
+                        # We should ask the weight loader to return success or
+                        # not here since otherwise we may skip experts with
+                        # other available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
-                            loaded_weight,
+                            weight_to_load,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
-                            name = name_mapped
+                            if not is_fuse_shared_experts_layer:
+                                name = name_mapped
+                            else:
+                                loaded_params.add(name_mapped)
                            break
                    else:
                        if is_expert_weight:
@@ -1424,6 +1497,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA):
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
+                if not is_fuse_shared_experts_layer:
+                    loaded_params.add(name)

        return loaded_params