Merge abeaa592f77f6287e63f4b611add0d0a25c9979e into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

ゆり 2025-12-25 00:06:39 +00:00 committed by GitHub
commit e49178f47a
2 changed files with 225 additions and 59 deletions

View File

@@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Helper class for ROCm AITER shared expert fusion in FusedMoE.

This module encapsulates the scattered logic for AITER shared expert fusion,
providing a cleaner interface for the FusedMoE layer. It handles:

- Capability checks for AITER fused MoE and shared expert fusion
- Computing and validating num_fused_shared_experts
- Initializing topK metadata buffers
- Providing expert map augmentation for determine_expert_map()
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig

if TYPE_CHECKING:
    from vllm.model_executor.layers.fused_moe.layer import FusedMoE


@dataclass
class AiterSharedExpertFusion:
    """
    Encapsulates ROCm AITER shared expert fusion logic for FusedMoE.

    This helper class manages the state and operations related to AITER's
    shared expert fusion feature, reducing scattered if-branches in the
    FusedMoE layer.

    Attributes:
        rocm_aiter_fmoe_enabled: Whether ROCm AITER fused MoE is enabled.
        aiter_shared_expert_enabled: Whether AITER shared expert fusion
            is enabled.
        num_fused_shared_experts: Number of shared experts to fuse (0 if
            fusion is disabled).
    """

    rocm_aiter_fmoe_enabled: bool
    aiter_shared_expert_enabled: bool
    num_fused_shared_experts: int

    @classmethod
    def create(cls, n_shared_experts: int | None) -> "AiterSharedExpertFusion":
        """
        Factory method to create an AiterSharedExpertFusion instance.

        Args:
            n_shared_experts: Number of shared experts from the model config,
                or None if not specified.

        Returns:
            An AiterSharedExpertFusion instance with properly initialized
            state.

        Raises:
            ValueError: If n_shared_experts is provided but AITER shared
                expert fusion is not enabled.
        """
        rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
        aiter_shared_expert_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )

        # Compute num_fused_shared_experts
        num_fused_shared_experts = (
            n_shared_experts
            if n_shared_experts is not None and aiter_shared_expert_enabled
            else 0
        )

        # Validate configuration
        if (
            not aiter_shared_expert_enabled
            and n_shared_experts is not None
            and n_shared_experts > 0
        ):
            raise ValueError(
                "n_shared_experts is only supported on ROCm aiter when "
                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled"
            )

        return cls(
            rocm_aiter_fmoe_enabled=rocm_aiter_fmoe_enabled,
            aiter_shared_expert_enabled=aiter_shared_expert_enabled,
            num_fused_shared_experts=num_fused_shared_experts,
        )

    @property
    def is_enabled(self) -> bool:
        """Check if AITER fused MoE is enabled."""
        return self.rocm_aiter_fmoe_enabled

    @property
    def is_shared_expert_fusion_enabled(self) -> bool:
        """Check if shared expert fusion is enabled."""
        return self.aiter_shared_expert_enabled

    @property
    def has_fused_shared_experts(self) -> bool:
        """Check if there are shared experts to fuse."""
        return self.num_fused_shared_experts > 0

    def get_expert_map_kwargs(self) -> dict:
        """
        Get additional kwargs for determine_expert_map().

        Returns:
            Dictionary with num_fused_shared_experts and return_expert_mask
            arguments.
        """
        return {
            "num_fused_shared_experts": self.num_fused_shared_experts,
            "return_expert_mask": self.rocm_aiter_fmoe_enabled,
        }

    def validate_expert_mask(self, expert_mask: torch.Tensor | None) -> None:
        """
        Validate that expert_mask contains only 0s and 1s when AITER is enabled.

        Args:
            expert_mask: The expert mask tensor to validate.

        Raises:
            AssertionError: If expert_mask contains values other than 0 and 1.
        """
        if self.rocm_aiter_fmoe_enabled and expert_mask is not None:
            assert torch.all((expert_mask == 0) | (expert_mask == 1)), (
                "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
            )

    def init_topk_buffers(
        self,
        layer: "FusedMoE",
        vllm_config: VllmConfig,
        dp_size: int,
    ) -> None:
        """
        Initialize AITER topK metadata buffers if shared expert fusion is enabled.

        This method also updates layer.local_num_experts to include fused
        shared experts.

        Args:
            layer: The FusedMoE layer to initialize buffers for.
            vllm_config: The vLLM configuration.
            dp_size: Data parallel size.
        """
        if self.has_fused_shared_experts:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
                init_aiter_topK_meta_data,
            )

            init_aiter_topK_meta_data(
                n_routed_experts=layer.global_num_experts,
                n_shared_experts=self.num_fused_shared_experts,
                top_k=layer.top_k,
                tp_rank=layer.ep_rank if layer.use_ep else layer.tp_rank,
                tp_size=layer.ep_size if layer.use_ep else layer.tp_size,
                shared_experts_score=1.0,
                max_num_tokens=(
                    vllm_config.scheduler_config.max_num_batched_tokens * dp_size
                ),
                is_EP=layer.use_ep,
            )
            layer.local_num_experts += self.num_fused_shared_experts

    def get_expert_map(
        self,
        expert_map: torch.Tensor | None,
        expert_mask: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """
        Get the appropriate expert map based on AITER state.

        When AITER fused MoE is enabled, returns expert_mask instead of
        expert_map.

        Args:
            expert_map: The standard expert map tensor.
            expert_mask: The AITER-specific expert mask tensor.

        Returns:
            expert_mask if AITER is enabled, otherwise expert_map.
        """
        return expert_mask if self.rocm_aiter_fmoe_enabled else expert_map
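
For orientation before the FusedMoE layer diff below, here is a minimal usage sketch of this helper. It is not part of the commit: it only mirrors the call pattern the layer changes adopt, the tensors are made-up examples, and it assumes a vLLM installation where the capability checks simply report the feature as disabled on non-ROCm builds.

# Minimal sketch, not part of the change: exercises AiterSharedExpertFusion
# the same way the FusedMoE layer diff below does.
import torch

from vllm.model_executor.layers.fused_moe.aiter_shared_expert_fusion import (
    AiterSharedExpertFusion,
)

# __init__-time construction. Passing a positive n_shared_experts while
# VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is disabled raises ValueError,
# so None keeps the sketch portable.
fusion = AiterSharedExpertFusion.create(n_shared_experts=None)

# Extra kwargs the layer splats into determine_expert_map().
print(fusion.get_expert_map_kwargs())
# {'num_fused_shared_experts': 0, 'return_expert_mask': False} when AITER fused MoE is off

# Mask validation is a no-op unless AITER fused MoE is enabled.
expert_mask = torch.tensor([0, 1, 1, 0])
fusion.validate_expert_mask(expert_mask)

# init_topk_buffers(layer=..., vllm_config=..., dp_size=...) would be called
# next by the layer; with no fused shared experts it does nothing.

# The layer's expert_map property delegates here: expert_mask is returned only
# when AITER fused MoE is enabled, otherwise the ordinary expert_map is used.
expert_map = torch.tensor([-1, 0, 1, -1])
print(fusion.get_expert_map(expert_map, expert_mask))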

View File

@@ -12,8 +12,7 @@ import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter
 import vllm.envs as envs
 from vllm._aiter_ops import rocm_aiter_ops
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import get_current_vllm_config
 from vllm.config.parallel import ExpertPlacementStrategy
 from vllm.distributed import (
     get_dp_group,
@@ -26,15 +25,15 @@ from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.fused_moe.aiter_shared_expert_fusion import (
+    AiterSharedExpertFusion,
+)
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEParallelConfig,
     FusedMoEQuantConfig,
     RoutingMethodType,
 )
-from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-    init_aiter_topK_meta_data,
-)
 from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
@@ -434,25 +433,8 @@ class FusedMoE(CustomOp):
             vllm_config.parallel_config.expert_placement_strategy
         )
-        # ROCm aiter shared experts fusion
-        self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
-        self.aiter_fmoe_shared_expert_enabled = (
-            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
-        )
-        self.num_fused_shared_experts = (
-            n_shared_experts
-            if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled
-            else 0
-        )
-        if (
-            not self.aiter_fmoe_shared_expert_enabled
-            and self.num_fused_shared_experts != 0
-        ):
-            raise ValueError(
-                "n_shared_experts is only supported on ROCm aiter when "
-                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled"
-            )
+        # ROCm aiter shared experts fusion - encapsulated in helper class
+        self._aiter_fusion = AiterSharedExpertFusion.create(n_shared_experts)

         # Determine expert maps
         if self.use_ep:
@@ -480,8 +462,7 @@
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts,
                 expert_placement_strategy=self.expert_placement_strategy,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                return_expert_mask=self.rocm_aiter_fmoe_enabled,
+                **self._aiter_fusion.get_expert_map_kwargs(),
             )
             self.local_num_experts = local_num_experts
             self.register_buffer("_expert_map", expert_map)
@@ -508,13 +489,11 @@
         self.top_k = top_k

-        self._init_aiter_shared_experts_topK_buffer(
-            vllm_config=vllm_config, dp_size=dp_size_
+        self._aiter_fusion.init_topk_buffers(
+            layer=self, vllm_config=vllm_config, dp_size=dp_size_
         )

-        if self.use_ep and self.rocm_aiter_fmoe_enabled:
-            assert self.expert_mask is None or torch.all(
-                (expert_mask == 0) | (expert_mask == 1)
-            ), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
+        if self.use_ep:
+            self._aiter_fusion.validate_expert_mask(self.expert_mask)

         assert intermediate_size % self.tp_size == 0
         self.hidden_size = hidden_size
@@ -838,15 +817,15 @@
             ep_rank=self.ep_rank,
             global_num_experts=self.global_num_experts,
             expert_placement_strategy=self.expert_placement_strategy,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            return_expert_mask=self.rocm_aiter_fmoe_enabled,
+            **self._aiter_fusion.get_expert_map_kwargs(),
         )
         self.local_num_experts = local_num_experts
         self.register_buffer("_expert_map", expert_map)
         self.register_buffer("expert_mask", expert_mask)
         self._maybe_init_expert_routing_tables()
-        if self.aiter_fmoe_shared_expert_enabled:
-            self._init_aiter_shared_experts_topK_buffer(
+        if self._aiter_fusion.is_shared_expert_fusion_enabled:
+            self._aiter_fusion.init_topk_buffers(
+                layer=self,
                 vllm_config=get_current_vllm_config(),
                 dp_size=get_dp_group().world_size,
             )
@@ -1063,22 +1042,21 @@
             return expert_id
         return self._expert_map[expert_id].item()

-    def _init_aiter_shared_experts_topK_buffer(
-        self, vllm_config: VllmConfig, dp_size: int
-    ):
-        if self.num_fused_shared_experts > 0:
-            init_aiter_topK_meta_data(
-                n_routed_experts=self.global_num_experts,
-                n_shared_experts=self.num_fused_shared_experts,
-                top_k=self.top_k,
-                tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
-                tp_size=self.ep_size if self.use_ep else self.tp_size,
-                shared_experts_score=1.0,
-                max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
-                * dp_size,
-                is_EP=self.use_ep,
-            )
-            self.local_num_experts += self.num_fused_shared_experts
+    # Backward-compatible properties for AITER shared expert fusion
+    @property
+    def rocm_aiter_fmoe_enabled(self) -> bool:
+        """Whether ROCm AITER fused MoE is enabled."""
+        return self._aiter_fusion.rocm_aiter_fmoe_enabled
+
+    @property
+    def aiter_fmoe_shared_expert_enabled(self) -> bool:
+        """Whether AITER shared expert fusion is enabled."""
+        return self._aiter_fusion.aiter_shared_expert_enabled
+
+    @property
+    def num_fused_shared_experts(self) -> int:
+        """Number of shared experts to fuse (0 if fusion is disabled)."""
+        return self._aiter_fusion.num_fused_shared_experts

     @overload
     def weight_loader(
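
The properties in the hunk above keep the old attribute names readable from outside the helper. A small sketch of what that preserves, assuming a vLLM installation; `_StubLayer` is a hypothetical stand-in used only for illustration, not part of the change.

# Sketch only: code that previously read layer.num_fused_shared_experts keeps
# working, because the value now lives on the private _aiter_fusion helper and
# is re-exposed through a property, as in the hunk above.
from vllm.model_executor.layers.fused_moe.aiter_shared_expert_fusion import (
    AiterSharedExpertFusion,
)


class _StubLayer:
    def __init__(self) -> None:
        self._aiter_fusion = AiterSharedExpertFusion.create(n_shared_experts=None)

    @property
    def num_fused_shared_experts(self) -> int:
        # Mirrors the backward-compatible property added to FusedMoE.
        return self._aiter_fusion.num_fused_shared_experts


layer = _StubLayer()
assert layer.num_fused_shared_experts == layer._aiter_fusion.num_fused_shared_experts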
@ -1583,12 +1561,12 @@ class FusedMoE(CustomOp):
elif self.use_grouped_topk and valid_grouping():
assert self.topk_group is not None
assert self.num_expert_group is not None
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert self.num_fused_shared_experts == 0
if self._aiter_fusion.is_enabled:
if not self._aiter_fusion.is_shared_expert_fusion_enabled:
assert self._aiter_fusion.num_fused_shared_experts == 0
grouped_topk_impl = partial(
rocm_aiter_grouped_topk,
num_fused_shared_experts=self.num_fused_shared_experts,
num_fused_shared_experts=self._aiter_fusion.num_fused_shared_experts,
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
@@ -1735,9 +1713,7 @@
     @property
     def expert_map(self) -> torch.Tensor | None:
-        return (
-            self._expert_map if not self.rocm_aiter_fmoe_enabled else self.expert_mask
-        )
+        return self._aiter_fusion.get_expert_map(self._expert_map, self.expert_mask)

     def forward_cuda(
         self,