Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2026-03-24 08:21:24 +08:00
Merge abeaa592f77f6287e63f4b611add0d0a25c9979e into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in: commit e49178f47a
@@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Helper class for ROCm AITER shared expert fusion in FusedMoE.

This module encapsulates the scattered logic for AITER shared expert fusion,
providing a cleaner interface for the FusedMoE layer. It handles:
- Capability checks for AITER fused MoE and shared expert fusion
- Computing and validating num_fused_shared_experts
- Initializing topK metadata buffers
- Providing expert map augmentation for determine_expert_map()
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig

if TYPE_CHECKING:
    from vllm.model_executor.layers.fused_moe.layer import FusedMoE


@dataclass
class AiterSharedExpertFusion:
    """
    Encapsulates ROCm AITER shared expert fusion logic for FusedMoE.

    This helper class manages the state and operations related to AITER's
    shared expert fusion feature, reducing scattered if-branches in the
    FusedMoE layer.

    Attributes:
        rocm_aiter_fmoe_enabled: Whether ROCm AITER fused MoE is enabled.
        aiter_shared_expert_enabled: Whether AITER shared expert fusion
            is enabled.
        num_fused_shared_experts: Number of shared experts to fuse (0 if
            fusion is disabled).
    """

    rocm_aiter_fmoe_enabled: bool
    aiter_shared_expert_enabled: bool
    num_fused_shared_experts: int

    @classmethod
    def create(cls, n_shared_experts: int | None) -> "AiterSharedExpertFusion":
        """
        Factory method to create an AiterSharedExpertFusion instance.

        Args:
            n_shared_experts: Number of shared experts from the model config,
                or None if not specified.

        Returns:
            An AiterSharedExpertFusion instance with properly initialized
            state.

        Raises:
            ValueError: If n_shared_experts is provided but AITER shared
                expert fusion is not enabled.
        """
        rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
        aiter_shared_expert_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )

        # Compute num_fused_shared_experts
        num_fused_shared_experts = (
            n_shared_experts
            if n_shared_experts is not None and aiter_shared_expert_enabled
            else 0
        )

        # Validate configuration
        if (
            not aiter_shared_expert_enabled
            and n_shared_experts is not None
            and n_shared_experts > 0
        ):
            raise ValueError(
                "n_shared_experts is only supported on ROCm aiter when "
                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled"
            )

        return cls(
            rocm_aiter_fmoe_enabled=rocm_aiter_fmoe_enabled,
            aiter_shared_expert_enabled=aiter_shared_expert_enabled,
            num_fused_shared_experts=num_fused_shared_experts,
        )
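
    # Illustrative behaviour of create() (annotation, not part of this file):
    #   create(None) -> num_fused_shared_experts == 0, whatever the AITER flags
    #   create(2)    -> num_fused_shared_experts == 2 if shared expert fusion
    #                   is enabled, otherwise raises ValueError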

    @property
    def is_enabled(self) -> bool:
        """Check if AITER fused MoE is enabled."""
        return self.rocm_aiter_fmoe_enabled

    @property
    def is_shared_expert_fusion_enabled(self) -> bool:
        """Check if shared expert fusion is enabled."""
        return self.aiter_shared_expert_enabled

    @property
    def has_fused_shared_experts(self) -> bool:
        """Check if there are shared experts to fuse."""
        return self.num_fused_shared_experts > 0

    def get_expert_map_kwargs(self) -> dict:
        """
        Get additional kwargs for determine_expert_map().

        Returns:
            Dictionary with num_fused_shared_experts and return_expert_mask
            arguments.
        """
        return {
            "num_fused_shared_experts": self.num_fused_shared_experts,
            "return_expert_mask": self.rocm_aiter_fmoe_enabled,
        }
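
    # Illustrative return values (annotation, not part of this file): with
    # fusion active this yields
    #   {"num_fused_shared_experts": N, "return_expert_mask": True}
    # and with both AITER switches off it degrades to
    #   {"num_fused_shared_experts": 0, "return_expert_mask": False},
    # so the determine_expert_map() call site needs no branching.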

    def validate_expert_mask(self, expert_mask: torch.Tensor | None) -> None:
        """
        Validate that expert_mask contains only 0s and 1s when AITER is enabled.

        Args:
            expert_mask: The expert mask tensor to validate.

        Raises:
            AssertionError: If expert_mask contains values other than 0 and 1.
        """
        if self.rocm_aiter_fmoe_enabled and expert_mask is not None:
            assert torch.all((expert_mask == 0) | (expert_mask == 1)), (
                "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
            )
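
    # Illustrative checks (annotation, not part of this file), with AITER
    # fused MoE enabled:
    #   validate_expert_mask(torch.tensor([1, 0, 1, 1]))  # passes
    #   validate_expert_mask(torch.tensor([0, 1, 2]))     # AssertionError
    #   validate_expert_mask(None)                        # no-op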

    def init_topk_buffers(
        self,
        layer: "FusedMoE",
        vllm_config: VllmConfig,
        dp_size: int,
    ) -> None:
        """
        Initialize AITER topK metadata buffers if shared expert fusion is enabled.

        This method also updates layer.local_num_experts to include fused
        shared experts.

        Args:
            layer: The FusedMoE layer to initialize buffers for.
            vllm_config: The vLLM configuration.
            dp_size: Data parallel size.
        """
        if self.has_fused_shared_experts:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
                init_aiter_topK_meta_data,
            )

            init_aiter_topK_meta_data(
                n_routed_experts=layer.global_num_experts,
                n_shared_experts=self.num_fused_shared_experts,
                top_k=layer.top_k,
                tp_rank=layer.ep_rank if layer.use_ep else layer.tp_rank,
                tp_size=layer.ep_size if layer.use_ep else layer.tp_size,
                shared_experts_score=1.0,
                max_num_tokens=(
                    vllm_config.scheduler_config.max_num_batched_tokens * dp_size
                ),
                is_EP=layer.use_ep,
            )
            layer.local_num_experts += self.num_fused_shared_experts
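
    # Illustrative call site (annotation, not part of this file), mirroring
    # the FusedMoE usage in the layer.py hunks further below:
    #   self._aiter_fusion.init_topk_buffers(
    #       layer=self,
    #       vllm_config=get_current_vllm_config(),
    #       dp_size=get_dp_group().world_size,
    #   )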

    def get_expert_map(
        self,
        expert_map: torch.Tensor | None,
        expert_mask: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """
        Get the appropriate expert map based on AITER state.

        When AITER fused MoE is enabled, returns expert_mask instead of
        expert_map.

        Args:
            expert_map: The standard expert map tensor.
            expert_mask: The AITER-specific expert mask tensor.

        Returns:
            expert_mask if AITER is enabled, otherwise expert_map.
        """
        return expert_mask if self.rocm_aiter_fmoe_enabled else expert_map
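The remaining hunks apply to the FusedMoE layer (vllm/model_executor/layers/fused_moe/layer.py, judging by the FusedMoE import above) and swap its inline AITER logic for this helper. As a quick self-contained reference, the sketch below is not part of the change: it constructs the dataclass directly, bypassing create() and its ROCm environment checks, to show how the accessors behave; only the import path is taken from the diff.

import torch

from vllm.model_executor.layers.fused_moe.aiter_shared_expert_fusion import (
    AiterSharedExpertFusion,
)

# Construct the state directly instead of via create(), so no ROCm/AITER
# environment flags are consulted.
fusion = AiterSharedExpertFusion(
    rocm_aiter_fmoe_enabled=True,
    aiter_shared_expert_enabled=True,
    num_fused_shared_experts=2,
)

assert fusion.is_enabled and fusion.has_fused_shared_experts
assert fusion.get_expert_map_kwargs() == {
    "num_fused_shared_experts": 2,
    "return_expert_mask": True,
}

expert_map = torch.tensor([0, 1, -1, -1])  # global-to-local id map form
expert_mask = torch.tensor([1, 1, 0, 0])   # AITER 0/1 mask form
fusion.validate_expert_mask(expert_mask)   # only 0s and 1s, so this passes
# With AITER fused MoE enabled, the layer's expert_map resolves to the mask:
assert fusion.get_expert_map(expert_map, expert_mask) is expert_mask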
@@ -12,8 +12,7 @@ import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter
 
 import vllm.envs as envs
-from vllm._aiter_ops import rocm_aiter_ops
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import get_current_vllm_config
 from vllm.config.parallel import ExpertPlacementStrategy
 from vllm.distributed import (
     get_dp_group,
@@ -26,15 +25,15 @@ from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.fused_moe.aiter_shared_expert_fusion import (
+    AiterSharedExpertFusion,
+)
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEParallelConfig,
     FusedMoEQuantConfig,
     RoutingMethodType,
 )
-from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-    init_aiter_topK_meta_data,
-)
 from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
@@ -434,25 +433,8 @@ class FusedMoE(CustomOp):
             vllm_config.parallel_config.expert_placement_strategy
         )
 
-        # ROCm aiter shared experts fusion
-        self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
-        self.aiter_fmoe_shared_expert_enabled = (
-            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
-        )
-
-        self.num_fused_shared_experts = (
-            n_shared_experts
-            if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled
-            else 0
-        )
-        if (
-            not self.aiter_fmoe_shared_expert_enabled
-            and self.num_fused_shared_experts != 0
-        ):
-            raise ValueError(
-                "n_shared_experts is only supported on ROCm aiter when "
-                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled"
-            )
+        # ROCm aiter shared experts fusion - encapsulated in helper class
+        self._aiter_fusion = AiterSharedExpertFusion.create(n_shared_experts)
 
         # Determine expert maps
         if self.use_ep:
@@ -480,8 +462,7 @@ class FusedMoE(CustomOp):
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts,
                 expert_placement_strategy=self.expert_placement_strategy,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                return_expert_mask=self.rocm_aiter_fmoe_enabled,
+                **self._aiter_fusion.get_expert_map_kwargs(),
             )
             self.local_num_experts = local_num_experts
             self.register_buffer("_expert_map", expert_map)
@@ -508,13 +489,11 @@ class FusedMoE(CustomOp):
 
         self.top_k = top_k
 
-        self._init_aiter_shared_experts_topK_buffer(
-            vllm_config=vllm_config, dp_size=dp_size_
+        self._aiter_fusion.init_topk_buffers(
+            layer=self, vllm_config=vllm_config, dp_size=dp_size_
         )
-        if self.use_ep and self.rocm_aiter_fmoe_enabled:
-            assert self.expert_mask is None or torch.all(
-                (expert_mask == 0) | (expert_mask == 1)
-            ), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
+        if self.use_ep:
+            self._aiter_fusion.validate_expert_mask(self.expert_mask)
 
         assert intermediate_size % self.tp_size == 0
         self.hidden_size = hidden_size
@@ -838,15 +817,15 @@ class FusedMoE(CustomOp):
             ep_rank=self.ep_rank,
             global_num_experts=self.global_num_experts,
             expert_placement_strategy=self.expert_placement_strategy,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            return_expert_mask=self.rocm_aiter_fmoe_enabled,
+            **self._aiter_fusion.get_expert_map_kwargs(),
         )
         self.local_num_experts = local_num_experts
         self.register_buffer("_expert_map", expert_map)
         self.register_buffer("expert_mask", expert_mask)
         self._maybe_init_expert_routing_tables()
-        if self.aiter_fmoe_shared_expert_enabled:
-            self._init_aiter_shared_experts_topK_buffer(
+        if self._aiter_fusion.is_shared_expert_fusion_enabled:
+            self._aiter_fusion.init_topk_buffers(
+                layer=self,
                 vllm_config=get_current_vllm_config(),
                 dp_size=get_dp_group().world_size,
             )
@@ -1063,22 +1042,21 @@ class FusedMoE(CustomOp):
             return expert_id
         return self._expert_map[expert_id].item()
 
-    def _init_aiter_shared_experts_topK_buffer(
-        self, vllm_config: VllmConfig, dp_size: int
-    ):
-        if self.num_fused_shared_experts > 0:
-            init_aiter_topK_meta_data(
-                n_routed_experts=self.global_num_experts,
-                n_shared_experts=self.num_fused_shared_experts,
-                top_k=self.top_k,
-                tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
-                tp_size=self.ep_size if self.use_ep else self.tp_size,
-                shared_experts_score=1.0,
-                max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
-                * dp_size,
-                is_EP=self.use_ep,
-            )
-            self.local_num_experts += self.num_fused_shared_experts
+    # Backward-compatible properties for AITER shared expert fusion
+    @property
+    def rocm_aiter_fmoe_enabled(self) -> bool:
+        """Whether ROCm AITER fused MoE is enabled."""
+        return self._aiter_fusion.rocm_aiter_fmoe_enabled
+
+    @property
+    def aiter_fmoe_shared_expert_enabled(self) -> bool:
+        """Whether AITER shared expert fusion is enabled."""
+        return self._aiter_fusion.aiter_shared_expert_enabled
+
+    @property
+    def num_fused_shared_experts(self) -> int:
+        """Number of shared experts to fuse (0 if fusion is disabled)."""
+        return self._aiter_fusion.num_fused_shared_experts
 
     @overload
     def weight_loader(
@@ -1583,12 +1561,12 @@ class FusedMoE(CustomOp):
         elif self.use_grouped_topk and valid_grouping():
             assert self.topk_group is not None
             assert self.num_expert_group is not None
-            if rocm_aiter_ops.is_fused_moe_enabled():
-                if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
-                    assert self.num_fused_shared_experts == 0
+            if self._aiter_fusion.is_enabled:
+                if not self._aiter_fusion.is_shared_expert_fusion_enabled:
+                    assert self._aiter_fusion.num_fused_shared_experts == 0
                 grouped_topk_impl = partial(
                     rocm_aiter_grouped_topk,
-                    num_fused_shared_experts=self.num_fused_shared_experts,
+                    num_fused_shared_experts=self._aiter_fusion.num_fused_shared_experts,
                     topk=self.top_k,
                     renormalize=self.renormalize,
                     num_expert_group=self.num_expert_group,
@@ -1735,9 +1713,7 @@ class FusedMoE(CustomOp):
 
     @property
    def expert_map(self) -> torch.Tensor | None:
-        return (
-            self._expert_map if not self.rocm_aiter_fmoe_enabled else self.expert_mask
-        )
+        return self._aiter_fusion.get_expert_map(self._expert_map, self.expert_mask)
 
     def forward_cuda(
         self,
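Since the removed attributes are re-exposed as read-only properties that delegate to the helper (the @@ -1063,22 +1042,21 @@ hunk above), code elsewhere that still reads them should continue to work. A minimal sketch of what that delegation implies, assuming an already constructed FusedMoE instance named moe (a hypothetical name, not taken from the diff):

# `moe` is a hypothetical, already-constructed FusedMoE instance.
# Old attribute reads now resolve through the delegating properties:
assert moe.rocm_aiter_fmoe_enabled == moe._aiter_fusion.rocm_aiter_fmoe_enabled
assert moe.num_fused_shared_experts == moe._aiter_fusion.num_fused_shared_experts
# They are read-only: the diff defines no property setters, so assigning
# `moe.num_fused_shared_experts = 1` would raise AttributeError.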