diff --git a/tests/distributed/test_expert_placement.py b/tests/distributed/test_expert_placement.py index cb9c8f507404..8b3a64b9c134 100644 --- a/tests/distributed/test_expert_placement.py +++ b/tests/distributed/test_expert_placement.py @@ -85,7 +85,7 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size): else: expected_test_local = base_experts - test_local_experts, test_expert_map = determine_expert_map( + test_local_experts, test_expert_map, _ = determine_expert_map( ep_size=test_ep_size, ep_rank=ep_rank, global_num_experts=test_global_experts, @@ -116,7 +116,7 @@ def test_expert_placement_edge_cases(expert_placement_strategy, world_size): """Test edge cases for round_robin expert placement.""" # Test case 1: ep_size = 1 (should return None for expert_map) - local_num_experts, expert_map = determine_expert_map( + local_num_experts, expert_map, _ = determine_expert_map( ep_size=1, ep_rank=0, global_num_experts=8, @@ -217,7 +217,7 @@ def test_determine_expert_map_comprehensive(): expected_local, expected_map_pattern, ) in test_cases: - local_num_experts, expert_map = determine_expert_map( + local_num_experts, expert_map, _ = determine_expert_map( ep_size=ep_size, ep_rank=ep_rank, global_num_experts=global_num_experts, diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index da9fe33a1c62..ba1f657b3ecd 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -217,7 +217,7 @@ def test_moe_permute_unpermute( expert_map = None n_local_expert = n_expert if ep_size != 1: - n_local_expert, expert_map = determine_expert_map(ep_size, ep_rank, n_expert) + n_local_expert, expert_map, _ = determine_expert_map(ep_size, ep_rank, n_expert) expert_map = expert_map.cuda() start_expert = n_local_expert * ep_rank current_platform.seed_everything(0) diff --git a/vllm/envs.py b/vllm/envs.py index 6f40209dd000..7dcfabe3e044 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -113,6 +113,7 @@ if TYPE_CHECKING: VLLM_ROCM_USE_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False + VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True @@ -914,6 +915,12 @@ environment_variables: dict[str, Callable[[], Any]] = { os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in ("true", "1") ), + # Whether to use aiter fusion shared experts ops. + # By default is enabled. 
+ "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower() + in ("true", "1") + ), # use rocm skinny gemms "VLLM_ROCM_USE_SKINNY_GEMM": lambda: ( os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1") diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9b117f3b5d41..de4ed58e0cf4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -5,6 +5,7 @@ from abc import abstractmethod from collections.abc import Callable, Iterable from contextlib import nullcontext from enum import Enum +from functools import partial from typing import Literal, get_args, overload import torch @@ -12,7 +13,7 @@ import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter import vllm.envs as envs -from vllm.config import get_current_vllm_config +from vllm.config import VllmConfig, get_current_vllm_config from vllm.config.parallel import ExpertPlacementStrategy from vllm.distributed import ( get_dp_group, @@ -39,6 +40,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPrepareAndFinalize, ) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + init_aiter_topK_meta_data, + is_rocm_aiter_fusion_shared_expert_enabled, is_rocm_aiter_moe_enabled, ) from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator @@ -87,7 +90,7 @@ else: if is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_grouped_topk as grouped_topk, + rocm_aiter_grouped_topk as grouped_topk_aiter, ) else: from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk @@ -634,6 +637,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): global_num_experts=global_num_experts, zero_expert_num=zero_expert_num, zero_expert_type=zero_expert_type, + num_fused_shared_experts=layer.num_fused_shared_experts, ) if self.rocm_aiter_moe_enabled: @@ -860,7 +864,8 @@ def determine_expert_map( ep_rank: int, global_num_experts: int, expert_placement_strategy: ExpertPlacementStrategy = "linear", -) -> tuple[int, torch.Tensor | None]: + num_fused_shared_experts: int = 0, +) -> tuple[int, torch.Tensor | None, torch.Tensor | None]: """ Calculates how many experts should be assigned to each rank for EP and creates a mapping from global to local expert index. Experts are @@ -882,10 +887,16 @@ def determine_expert_map( (global_num_experts,) mapping from global to local index. Contains -1 for experts not assigned to the current rank. Returns None if ep_size is 1. + - expert_mask (Optional[torch.Tensor]): A tensor of shape + (global_num_experts + num_fused_shared_experts + 1,) + containing 1 for experts assigned to the current rank + and 0 for sentinel. + Returns None if ep_size is 1. + Used only when AITER MOE is enabled. """ assert ep_size > 0 if ep_size == 1: - return (global_num_experts, None) + return (global_num_experts, None, None) # Distribute experts as evenly as possible to each rank. 
base_experts = global_num_experts // ep_size @@ -914,7 +925,26 @@ def determine_expert_map( f"'{expert_placement_strategy}', expected one of " f"{get_args(ExpertPlacementStrategy)}" ) - return (local_num_experts, expert_map) + + expert_mask = None + if is_rocm_aiter_moe_enabled(): + expert_mask = torch.ones( + (global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32 + ) + expert_mask[-1] = 0 + expert_mask[:global_num_experts] = expert_map > -1 + expert_map = torch.cat( + ( + expert_map, + torch.tensor( + [local_num_experts + i for i in range(num_fused_shared_experts)], + dtype=torch.int32, + ), + ), + dim=0, + ) + + return (local_num_experts, expert_map, expert_mask) def get_compressed_expert_map(expert_map: torch.Tensor) -> str: @@ -1040,6 +1070,7 @@ class FusedMoE(CustomOp): zero_expert_num: int | None = 0, zero_expert_type: str | None = None, expert_mapping: list[tuple[str, str, int, str]] | None = None, + n_shared_experts: int | None = None, ): super().__init__() if params_dtype is None: @@ -1096,6 +1127,22 @@ class FusedMoE(CustomOp): self.logical_to_physical_map: torch.Tensor | None = None self.logical_replica_count: torch.Tensor | None = None + # ROCm aiter shared experts fusion + self.num_fused_shared_experts = ( + n_shared_experts + if n_shared_experts is not None + and is_rocm_aiter_fusion_shared_expert_enabled() + else 0 + ) + if ( + not is_rocm_aiter_fusion_shared_expert_enabled() + and self.num_fused_shared_experts != 0 + ): + raise ValueError( + "n_shared_experts is only supported on ROCm aiter when " + "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled" + ) + # Determine expert maps if self.use_ep: if self.enable_eplb: @@ -1129,14 +1176,16 @@ class FusedMoE(CustomOp): expert_placement_strategy = "linear" self.expert_map: torch.Tensor | None - local_num_experts, expert_map = determine_expert_map( + local_num_experts, expert_map, expert_mask = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts, expert_placement_strategy=expert_placement_strategy, + num_fused_shared_experts=self.num_fused_shared_experts, ) self.local_num_experts = local_num_experts self.register_buffer("expert_map", expert_map) + self.register_buffer("expert_mask", expert_mask) logger.info_once( "[EP Rank %s/%s] Expert parallelism is enabled. Expert " "placement strategy: %s. 
Local/global" @@ -1150,10 +1199,18 @@ class FusedMoE(CustomOp): get_compressed_expert_map(self.expert_map), ) else: - self.local_num_experts, self.expert_map = (self.global_num_experts, None) + self.local_num_experts, self.expert_map, self.expert_mask = ( + self.global_num_experts, + None, + None, + ) self.top_k = top_k + self._init_aiter_shared_experts_topK_buffer( + vllm_config=vllm_config, dp_size=dp_size_ + ) + assert intermediate_size % self.tp_size == 0 self.hidden_size = hidden_size self.intermediate_size_per_partition = intermediate_size // self.tp_size @@ -1327,13 +1384,18 @@ class FusedMoE(CustomOp): # ep_size and ep_rank should already be updated assert self.expert_map is not None with self.expert_map.device: - local_num_experts, expert_map = determine_expert_map( + local_num_experts, expert_map, expert_mask = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts, + num_fused_shared_experts=self.num_fused_shared_experts, ) self.local_num_experts = local_num_experts self.register_buffer("expert_map", expert_map) + self.register_buffer("expert_mask", expert_mask) + self._init_aiter_shared_experts_topK_buffer( + vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size + ) def _load_per_tensor_weight_scale( self, @@ -1504,6 +1566,24 @@ class FusedMoE(CustomOp): return expert_id return self.expert_map[expert_id].item() + def _init_aiter_shared_experts_topK_buffer( + self, vllm_config: VllmConfig, dp_size: int + ): + if is_rocm_aiter_fusion_shared_expert_enabled(): + if self.num_fused_shared_experts > 0: + init_aiter_topK_meta_data( + n_routed_experts=self.global_num_experts, + n_shared_experts=self.num_fused_shared_experts, + top_k=self.top_k, + tp_rank=self.ep_rank if self.use_ep else self.tp_rank, + tp_size=self.ep_size if self.use_ep else self.tp_size, + shared_experts_score=1.0, + max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens + * dp_size, + is_EP=self.use_ep, + ) + self.local_num_experts += self.num_fused_shared_experts + @overload def weight_loader( self, @@ -1866,6 +1946,7 @@ class FusedMoE(CustomOp): global_num_experts: int | None = None, zero_expert_num: int | None = None, zero_expert_type: str | None = None, + num_fused_shared_experts: int = 0, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Route the input hidden states to the top-k experts based on the @@ -1900,7 +1981,16 @@ class FusedMoE(CustomOp): if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None - topk_weights, topk_ids = grouped_topk( + if is_rocm_aiter_moe_enabled(): + if not is_rocm_aiter_fusion_shared_expert_enabled(): + assert num_fused_shared_experts == 0 + grouped_topk_impl = partial( + grouped_topk_aiter, + num_fused_shared_experts=num_fused_shared_experts, + ) + else: + grouped_topk_impl = grouped_topk + topk_weights, topk_ids = grouped_topk_impl( hidden_states=hidden_states, gating_output=router_logits, topk=top_k, @@ -2119,7 +2209,9 @@ class FusedMoE(CustomOp): renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, global_num_experts=self.global_num_experts, - expert_map=self.expert_map, + expert_map=self.expert_map + if not is_rocm_aiter_moe_enabled() + else self.expert_mask, topk_group=self.topk_group, num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, @@ -2244,7 +2336,9 @@ class FusedMoE(CustomOp): renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, global_num_experts=self.global_num_experts, - 
expert_map=self.expert_map, + expert_map=self.expert_map + if not is_rocm_aiter_moe_enabled() + else self.expert_mask, topk_group=self.topk_group, num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 921e0b24b9ef..b572baecd753 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import IntEnum -from functools import cache +from functools import cache, lru_cache import torch @@ -46,6 +46,69 @@ def is_rocm_aiter_moe_enabled() -> bool: ) +@cache +def is_rocm_aiter_fusion_shared_expert_enabled() -> bool: + return ( + envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled() + ) + + +aiter_topK_meta_data = None + + +@lru_cache(maxsize=1) +def init_aiter_topK_meta_data( + n_routed_experts: int, + n_shared_experts: int, + top_k: int, + tp_rank: int, + tp_size: int, + shared_experts_score: float = 1.0, + max_num_tokens: int = 32768, + is_EP: bool = False, +): + global aiter_topK_meta_data + fake_expertid = n_routed_experts + n_shared_experts + + # all layers reuse same buffer + # This extra element when EP is enabled is used as a sentinel + # to mask out shared expert processing for tokens not owned by + # the current EP rank. This is necessary to avoid double-processing + # of shared experts. + total_topk_ids = torch.empty( + (max_num_tokens, top_k + n_shared_experts + is_EP), + dtype=torch.int32, + device="cuda", + ) + ns_topk_ids, s_topk_ids = total_topk_ids.split( + [top_k, n_shared_experts + is_EP], dim=1 + ) + shared_expert_ids = [n_routed_experts + i for i in range(n_shared_experts + is_EP)] + if is_EP: + s_topk_ids_list = [ + [fake_expertid] * (n_shared_experts + is_EP) + ] * max_num_tokens + for i in range(tp_rank, max_num_tokens, tp_size): + s_topk_ids_list[i] = shared_expert_ids + else: + s_topk_ids_list = [ + list(range(n_routed_experts, fake_expertid)) + ] * max_num_tokens + s_topk_ids[:] = torch.tensor(s_topk_ids_list, dtype=torch.int32, device="cuda") + + total_topk_weights = torch.empty( + (max_num_tokens, top_k + n_shared_experts + is_EP), + dtype=torch.float32, + device="cuda", + ) + ns_topk_weights, s_topk_weights = total_topk_weights.split( + [top_k, n_shared_experts + is_EP], dim=1 + ) + s_topk_weights.fill_(shared_experts_score) + assert aiter_topK_meta_data is None, "AITER topK meta data is already initialized" + aiter_topK_meta_data = (total_topk_weights, total_topk_ids) + + def rocm_aiter_asm_moe_tkw1_impl( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -300,11 +363,33 @@ def rocm_aiter_grouped_topk( scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, + num_fused_shared_experts: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: token = hidden_states.shape[0] device = hidden_states.device - topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) - topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) + if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0: + assert aiter_topK_meta_data is not None, ( + "AITER topK meta data is not initialized. 
" + "Please ensure that init_aiter_topK_meta_data " + "is called before this function." + ) + total_topk_weights, total_topk_ids = aiter_topK_meta_data + assert total_topk_weights.shape[0] >= token, ( + f"AITER topK meta data support {total_topk_weights.shape[0]} " + f"tokens which is determined by max_num_batched_tokens, " + f"but got {token} tokens now." + ) + total_topk_weights = total_topk_weights[:token] + total_topk_ids = total_topk_ids[:token] + topk_weights, _ = total_topk_weights.split( + [topk, total_topk_weights.shape[1] - topk], dim=1 + ) + topk_ids, _ = total_topk_ids.split( + [topk, total_topk_ids.shape[1] - topk], dim=1 + ) + else: + topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) + topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) if e_score_correction_bias is not None: torch.ops.vllm.rocm_aiter_biased_grouped_topk( @@ -315,6 +400,7 @@ def rocm_aiter_grouped_topk( num_expert_group, topk_group, renormalize, + routed_scaling_factor=routed_scaling_factor, ) else: assert scoring_func == "softmax" or scoring_func == "sigmoid" @@ -326,10 +412,11 @@ def rocm_aiter_grouped_topk( topk_group, renormalize, scoring_func, + routed_scaling_factor=routed_scaling_factor, ) - if routed_scaling_factor != 1.0: - topk_weights = topk_weights * routed_scaling_factor + if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0: + return total_topk_weights, total_topk_ids return topk_weights, topk_ids @@ -354,7 +441,7 @@ def rocm_aiter_fused_experts( topk_weights = topk_weights.to(torch.float32) topk_ids = topk_ids.to(torch.int32) - expert_mask = (expert_map > -1).to(torch.int32) if expert_map is not None else None + expert_mask = expert_map if expert_map is not None else None # w8a8 per-channel quantization if ( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 3cc726aafd29..1f4a76452f96 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1056,6 +1056,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, + num_fused_shared_experts=layer.num_fused_shared_experts, ) per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 02b1896a8996..5967ee9b6e3f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1169,6 +1169,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): global_num_experts=global_num_experts, zero_expert_num=zero_expert_num, zero_expert_type=zero_expert_type, + num_fused_shared_experts=layer.num_fused_shared_experts, ) # diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index f33ed735f429..5b55b685dacf 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -50,6 +50,10 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import SharedFusedMoE 
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_fusion_shared_expert_enabled, + is_rocm_aiter_moe_enabled, +) from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -203,7 +207,10 @@ class DeepseekV2MoE(nn.Module): self.physical_expert_start + self.n_local_physical_experts ) - if config.n_shared_experts is None: + if ( + config.n_shared_experts is None + or is_rocm_aiter_fusion_shared_expert_enabled() + ): self.shared_experts = None else: intermediate_size = config.moe_intermediate_size * config.n_shared_experts @@ -233,11 +240,17 @@ class DeepseekV2MoE(nn.Module): prefix=f"{prefix}.experts", scoring_func=config.scoring_func, # we do scaling outside, set factor to 1.0 to avoid double mul - routed_scaling_factor=1.0, + # aiter applies routed_scaling_factor internally + routed_scaling_factor=1.0 + if not is_rocm_aiter_moe_enabled() + else self.routed_scaling_factor, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, + n_shared_experts=config.n_shared_experts + if is_rocm_aiter_fusion_shared_expert_enabled() + else None, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -258,16 +271,15 @@ class DeepseekV2MoE(nn.Module): hidden_states=hidden_states, router_logits=router_logits ) - if self.shared_experts is not None: - shared_output, final_hidden_states = fused_moe_out - else: - shared_output = None - final_hidden_states = fused_moe_out + shared_output, final_hidden_states = fused_moe_out + if self.shared_experts is None: + assert shared_output is None # Fix FP16 overflow # See DeepseekV2DecoderLayer for more details. if hidden_states.dtype != torch.float16: - final_hidden_states *= self.routed_scaling_factor + if not is_rocm_aiter_moe_enabled(): + final_hidden_states *= self.routed_scaling_factor elif self.shared_experts is not None: assert shared_output is not None shared_output *= 1.0 / self.routed_scaling_factor @@ -1316,7 +1328,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts, + num_experts=self.config.n_routed_experts + + ( + self.config.n_shared_experts + if is_rocm_aiter_fusion_shared_expert_enabled() + else 0 + ), num_redundant_experts=self.num_redundant_experts, ) @@ -1330,6 +1347,11 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR if spec_layer is not None: continue # skip spec decode layers for main model + is_fuse_shared_experts_layer = ( + is_rocm_aiter_fusion_shared_expert_enabled() + and ("mlp.shared_experts" in name) + ) + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: @@ -1342,6 +1364,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR # for mlp.experts[0].gate_gate_up_proj, which breaks load. if ("mlp.experts." 
in name) and name not in params_dict: continue + if is_fuse_shared_experts_layer: + continue name_mapped = name.replace(weight_name, param_name) # QKV fusion is optional, fall back to normal @@ -1366,65 +1390,115 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR break else: is_expert_weight = False - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - # Anyway, this is an expert weight and should not be - # attempted to load as other weights later - is_expert_weight = True - - # Do not modify `name` since the loop may continue here - # Instead, create a new variable - name_mapped = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name_mapped, self): - continue - - param = params_dict[name_mapped] - # We should ask the weight loader to return success or not - # here since otherwise we may skip experts with other - # available replicas. - weight_loader = typing.cast( - Callable[..., bool], param.weight_loader + # Special handling: when AITER fusion_shared_experts is enabled, + # checkpoints may provide a single widened shared_experts tensor + # without explicit expert indices + # (e.g. ...mlp.shared_experts.gate_proj.weight). + # For models with multiple shared experts, split that tensor + # evenly into per-shared-expert slices and load them into + # appended expert slots mlp.experts.{n_routed_experts + j}.* + # accordingly. + num_chunks = 1 + if is_fuse_shared_experts_layer: + num_chunks = getattr(self.config, "n_shared_experts", 1) or 1 + # Determine split axis based on op type + # gate/up: ColumnParallel → split along dim 0 + # down: RowParallel → split along dim 1 + split_dim = 1 if "down_proj.weight" in name else 0 + total = loaded_weight.shape[split_dim] + assert total % num_chunks == 0, ( + f"Shared expert weight dim {total} " + f"not divisible by num_chunks {num_chunks}" ) - success = weight_loader( - param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True, - ) - if success: - name = name_mapped - break - else: - if is_expert_weight: - # We've checked that this is an expert weight - # However it's not mapped locally to this rank - # So we simply skip it - continue + chunk_size = total // num_chunks - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue + for j in range(num_chunks): + chunk_name = name + weight_to_load = loaded_weight - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue + if is_fuse_shared_experts_layer: + if split_dim == 0: + weight_to_load = loaded_weight[ + j * chunk_size : (j + 1) * chunk_size, : + ] + else: + weight_to_load = loaded_weight[ + :, j * chunk_size : (j + 1) * chunk_size + ] + # Synthesize an expert-style name so expert mapping + # can route it + chunk_name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts + j}", + ) - if is_pp_missing_parameter(name, self): - continue + # Use expert_params_mapping to locate the destination + # param and delegate to its expert-aware weight_loader + # with expert_id. 
+ for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in chunk_name: + continue - param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) - loaded_params.add(name) + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = chunk_name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or + # not here since otherwise we may skip experts with + # other available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + weight_to_load, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + if not is_fuse_shared_experts_layer: + name = name_mapped + else: + loaded_params.add(name_mapped) + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + if not is_fuse_shared_experts_layer: + loaded_params.add(name) return loaded_params