diff --git a/vllm/model_executor/layers/fused_moe/aiter_shared_expert_fusion.py b/vllm/model_executor/layers/fused_moe/aiter_shared_expert_fusion.py
new file mode 100644
index 0000000000000..c6e1a69127c02
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/aiter_shared_expert_fusion.py
@@ -0,0 +1,190 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Helper class for ROCm AITER shared expert fusion in FusedMoE.
+
+This module encapsulates the scattered logic for AITER shared expert fusion,
+providing a cleaner interface for the FusedMoE layer. It handles:
+- Capability checks for AITER fused MoE and shared expert fusion
+- Computing and validating num_fused_shared_experts
+- Initializing topK metadata buffers
+- Providing expert map augmentation for determine_expert_map()
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+import torch
+
+from vllm._aiter_ops import rocm_aiter_ops
+from vllm.config import VllmConfig
+
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+
+
+@dataclass
+class AiterSharedExpertFusion:
+    """
+    Encapsulates ROCm AITER shared expert fusion logic for FusedMoE.
+
+    This helper class manages the state and operations related to AITER's
+    shared expert fusion feature, reducing scattered if-branches in the
+    FusedMoE layer.
+
+    Attributes:
+        rocm_aiter_fmoe_enabled: Whether ROCm AITER fused MoE is enabled.
+        aiter_shared_expert_enabled: Whether AITER shared expert fusion
+            is enabled.
+        num_fused_shared_experts: Number of shared experts to fuse (0 if
+            fusion is disabled).
+    """
+
+    rocm_aiter_fmoe_enabled: bool
+    aiter_shared_expert_enabled: bool
+    num_fused_shared_experts: int
+
+    @classmethod
+    def create(cls, n_shared_experts: int | None) -> "AiterSharedExpertFusion":
+        """
+        Factory method to create an AiterSharedExpertFusion instance.
+
+        Args:
+            n_shared_experts: Number of shared experts from the model config,
+                or None if not specified.
+
+        Returns:
+            An AiterSharedExpertFusion instance with properly initialized
+            state.
+
+        Raises:
+            ValueError: If n_shared_experts is positive but AITER shared
+                expert fusion is not enabled.
+ """ + rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + aiter_shared_expert_enabled = ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + ) + + # Compute num_fused_shared_experts + num_fused_shared_experts = ( + n_shared_experts + if n_shared_experts is not None and aiter_shared_expert_enabled + else 0 + ) + + # Validate configuration + if ( + not aiter_shared_expert_enabled + and n_shared_experts is not None + and n_shared_experts > 0 + ): + raise ValueError( + "n_shared_experts is only supported on ROCm aiter when " + "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled" + ) + + return cls( + rocm_aiter_fmoe_enabled=rocm_aiter_fmoe_enabled, + aiter_shared_expert_enabled=aiter_shared_expert_enabled, + num_fused_shared_experts=num_fused_shared_experts, + ) + + @property + def is_enabled(self) -> bool: + """Check if AITER fused MoE is enabled.""" + return self.rocm_aiter_fmoe_enabled + + @property + def is_shared_expert_fusion_enabled(self) -> bool: + """Check if shared expert fusion is enabled.""" + return self.aiter_shared_expert_enabled + + @property + def has_fused_shared_experts(self) -> bool: + """Check if there are shared experts to fuse.""" + return self.num_fused_shared_experts > 0 + + def get_expert_map_kwargs(self) -> dict: + """ + Get additional kwargs for determine_expert_map(). + + Returns: + Dictionary with num_fused_shared_experts and return_expert_mask + arguments. + """ + return { + "num_fused_shared_experts": self.num_fused_shared_experts, + "return_expert_mask": self.rocm_aiter_fmoe_enabled, + } + + def validate_expert_mask(self, expert_mask: torch.Tensor | None) -> None: + """ + Validate that expert_mask contains only 0s and 1s when AITER is enabled. + + Args: + expert_mask: The expert mask tensor to validate. + + Raises: + AssertionError: If expert_mask contains values other than 0 and 1. + """ + if self.rocm_aiter_fmoe_enabled and expert_mask is not None: + assert torch.all((expert_mask == 0) | (expert_mask == 1)), ( + "Aiter Fused MoE kernel only supports expert_map with 0 and 1s." + ) + + def init_topk_buffers( + self, + layer: "FusedMoE", + vllm_config: VllmConfig, + dp_size: int, + ) -> None: + """ + Initialize AITER topK metadata buffers if shared expert fusion is enabled. + + This method also updates layer.local_num_experts to include fused + shared experts. + + Args: + layer: The FusedMoE layer to initialize buffers for. + vllm_config: The vLLM configuration. + dp_size: Data parallel size. + """ + if self.has_fused_shared_experts: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + init_aiter_topK_meta_data, + ) + + init_aiter_topK_meta_data( + n_routed_experts=layer.global_num_experts, + n_shared_experts=self.num_fused_shared_experts, + top_k=layer.top_k, + tp_rank=layer.ep_rank if layer.use_ep else layer.tp_rank, + tp_size=layer.ep_size if layer.use_ep else layer.tp_size, + shared_experts_score=1.0, + max_num_tokens=( + vllm_config.scheduler_config.max_num_batched_tokens * dp_size + ), + is_EP=layer.use_ep, + ) + layer.local_num_experts += self.num_fused_shared_experts + + def get_expert_map( + self, + expert_map: torch.Tensor | None, + expert_mask: torch.Tensor | None, + ) -> torch.Tensor | None: + """ + Get the appropriate expert map based on AITER state. + + When AITER fused MoE is enabled, returns expert_mask instead of + expert_map. + + Args: + expert_map: The standard expert map tensor. + expert_mask: The AITER-specific expert mask tensor. 
+
+        Returns:
+            expert_mask if AITER is enabled, otherwise expert_map.
+        """
+        return expert_mask if self.rocm_aiter_fmoe_enabled else expert_map
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 2e7267d56d838..bc4588879869b 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -12,8 +12,7 @@ import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter
 
 import vllm.envs as envs
-from vllm._aiter_ops import rocm_aiter_ops
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import get_current_vllm_config
 from vllm.config.parallel import ExpertPlacementStrategy
 from vllm.distributed import (
     get_dp_group,
@@ -26,15 +25,15 @@ from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.fused_moe.aiter_shared_expert_fusion import (
+    AiterSharedExpertFusion,
+)
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEParallelConfig,
     FusedMoEQuantConfig,
     RoutingMethodType,
 )
-from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-    init_aiter_topK_meta_data,
-)
 from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
@@ -434,25 +433,8 @@ class FusedMoE(CustomOp):
             vllm_config.parallel_config.expert_placement_strategy
         )
 
-        # ROCm aiter shared experts fusion
-        self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
-        self.aiter_fmoe_shared_expert_enabled = (
-            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
-        )
-
-        self.num_fused_shared_experts = (
-            n_shared_experts
-            if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled
-            else 0
-        )
-        if (
-            not self.aiter_fmoe_shared_expert_enabled
-            and self.num_fused_shared_experts != 0
-        ):
-            raise ValueError(
-                "n_shared_experts is only supported on ROCm aiter when "
-                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled"
-            )
+        # ROCm aiter shared experts fusion - encapsulated in helper class
+        self._aiter_fusion = AiterSharedExpertFusion.create(n_shared_experts)
 
         # Determine expert maps
         if self.use_ep:
@@ -480,8 +462,7 @@ class FusedMoE(CustomOp):
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts,
                 expert_placement_strategy=self.expert_placement_strategy,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                return_expert_mask=self.rocm_aiter_fmoe_enabled,
+                **self._aiter_fusion.get_expert_map_kwargs(),
             )
             self.local_num_experts = local_num_experts
             self.register_buffer("_expert_map", expert_map)
@@ -508,13 +489,11 @@ class FusedMoE(CustomOp):
 
         self.top_k = top_k
 
-        self._init_aiter_shared_experts_topK_buffer(
-            vllm_config=vllm_config, dp_size=dp_size_
+        self._aiter_fusion.init_topk_buffers(
+            layer=self, vllm_config=vllm_config, dp_size=dp_size_
         )
-        if self.use_ep and self.rocm_aiter_fmoe_enabled:
-            assert self.expert_mask is None or torch.all(
-                (expert_mask == 0) | (expert_mask == 1)
-            ), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
+        if self.use_ep:
+            self._aiter_fusion.validate_expert_mask(self.expert_mask)
 
         assert intermediate_size % self.tp_size == 0
         self.hidden_size = hidden_size
@@ -838,15 +817,15 @@ class FusedMoE(CustomOp):
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts,
                 expert_placement_strategy=self.expert_placement_strategy,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                return_expert_mask=self.rocm_aiter_fmoe_enabled,
+                **self._aiter_fusion.get_expert_map_kwargs(),
            )
             self.local_num_experts = local_num_experts
             self.register_buffer("_expert_map", expert_map)
             self.register_buffer("expert_mask", expert_mask)
             self._maybe_init_expert_routing_tables()
-        if self.aiter_fmoe_shared_expert_enabled:
-            self._init_aiter_shared_experts_topK_buffer(
+        if self._aiter_fusion.is_shared_expert_fusion_enabled:
+            self._aiter_fusion.init_topk_buffers(
+                layer=self,
                 vllm_config=get_current_vllm_config(),
                 dp_size=get_dp_group().world_size,
             )
@@ -1063,22 +1042,21 @@ class FusedMoE(CustomOp):
             return expert_id
         return self._expert_map[expert_id].item()
 
-    def _init_aiter_shared_experts_topK_buffer(
-        self, vllm_config: VllmConfig, dp_size: int
-    ):
-        if self.num_fused_shared_experts > 0:
-            init_aiter_topK_meta_data(
-                n_routed_experts=self.global_num_experts,
-                n_shared_experts=self.num_fused_shared_experts,
-                top_k=self.top_k,
-                tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
-                tp_size=self.ep_size if self.use_ep else self.tp_size,
-                shared_experts_score=1.0,
-                max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
-                * dp_size,
-                is_EP=self.use_ep,
-            )
-            self.local_num_experts += self.num_fused_shared_experts
+    # Backward-compatible properties for AITER shared expert fusion
+    @property
+    def rocm_aiter_fmoe_enabled(self) -> bool:
+        """Whether ROCm AITER fused MoE is enabled."""
+        return self._aiter_fusion.rocm_aiter_fmoe_enabled
+
+    @property
+    def aiter_fmoe_shared_expert_enabled(self) -> bool:
+        """Whether AITER shared expert fusion is enabled."""
+        return self._aiter_fusion.aiter_shared_expert_enabled
+
+    @property
+    def num_fused_shared_experts(self) -> int:
+        """Number of shared experts to fuse (0 if fusion is disabled)."""
+        return self._aiter_fusion.num_fused_shared_experts
 
     @overload
     def weight_loader(
@@ -1583,12 +1561,12 @@ class FusedMoE(CustomOp):
         elif self.use_grouped_topk and valid_grouping():
             assert self.topk_group is not None
             assert self.num_expert_group is not None
-            if rocm_aiter_ops.is_fused_moe_enabled():
-                if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
-                    assert self.num_fused_shared_experts == 0
+            if self._aiter_fusion.is_enabled:
+                if not self._aiter_fusion.is_shared_expert_fusion_enabled:
+                    assert self._aiter_fusion.num_fused_shared_experts == 0
                 grouped_topk_impl = partial(
                     rocm_aiter_grouped_topk,
-                    num_fused_shared_experts=self.num_fused_shared_experts,
+                    num_fused_shared_experts=self._aiter_fusion.num_fused_shared_experts,
                     topk=self.top_k,
                     renormalize=self.renormalize,
                     num_expert_group=self.num_expert_group,
@@ -1735,9 +1713,7 @@ class FusedMoE(CustomOp):
 
     @property
     def expert_map(self) -> torch.Tensor | None:
-        return (
-            self._expert_map if not self.rocm_aiter_fmoe_enabled else self.expert_mask
-        )
+        return self._aiter_fusion.get_expert_map(self._expert_map, self.expert_mask)
 
     def forward_cuda(
         self,
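For reviewers who want to see the intended control flow without a ROCm/AITER build, here is a minimal standalone sketch of the gating logic that `AiterSharedExpertFusion.create()` and `get_expert_map_kwargs()` implement in this patch. `_FusionSketch` and its boolean arguments are illustrative stand-ins for the real `rocm_aiter_ops` capability checks and are not part of the change.

```python
# Standalone sketch of the logic FusedMoE now delegates to the helper class.
# The AITER capability flags are plain booleans here instead of rocm_aiter_ops
# queries, so this runs anywhere; it mirrors create()/get_expert_map_kwargs().
from dataclasses import dataclass


@dataclass
class _FusionSketch:
    fmoe_enabled: bool
    shared_expert_enabled: bool
    num_fused_shared_experts: int

    @classmethod
    def create(cls, n_shared_experts, fmoe_enabled, shared_expert_enabled):
        # Shared experts are only folded in when the fusion env flag is on.
        num_fused = (
            n_shared_experts
            if n_shared_experts is not None and shared_expert_enabled
            else 0
        )
        # Reject a positive n_shared_experts when fusion is disabled.
        if not shared_expert_enabled and (n_shared_experts or 0) > 0:
            raise ValueError(
                "n_shared_experts requires "
                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS"
            )
        return cls(fmoe_enabled, shared_expert_enabled, num_fused)

    def get_expert_map_kwargs(self) -> dict:
        # Forwarded as **kwargs into determine_expert_map(), as in layer.py.
        return {
            "num_fused_shared_experts": self.num_fused_shared_experts,
            "return_expert_mask": self.fmoe_enabled,
        }


if __name__ == "__main__":
    # Fusion on: shared experts are added to the routed-expert bookkeeping.
    fusion = _FusionSketch.create(
        n_shared_experts=2, fmoe_enabled=True, shared_expert_enabled=True
    )
    print(fusion.get_expert_map_kwargs())
    # -> {'num_fused_shared_experts': 2, 'return_expert_mask': True}

    # Fusion off: asking for shared-expert fusion is rejected up front.
    try:
        _FusionSketch.create(
            n_shared_experts=2, fmoe_enabled=True, shared_expert_enabled=False
        )
    except ValueError as err:
        print(f"rejected: {err}")
```

Because the helper exposes these values through `get_expert_map_kwargs()` and forwards them via `**kwargs`, the `determine_expert_map()` call sites in `layer.py` collapse to a single line while its signature stays unchanged.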