[Performance] Support FP8 flashinfer TRTLLM MOE on Qwen3 and Qwen3-Next (#27492)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
This commit is contained in:
jiahanc 2025-11-10 09:34:57 -08:00 committed by GitHub
parent b039bfda8f
commit 34553b9d27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 78 additions and 30 deletions

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum
from typing import Optional, Union from typing import Optional, Union
import torch import torch
@ -91,6 +92,26 @@ def _quant_flags_to_group_shape(
return a_shape, w_shape return a_shape, w_shape
# The type of method in top-K routing
# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
class RoutingMethodType(IntEnum):
    """Integer codes identifying the top-K routing scheme.

    Members are declared with plain integer literals so the numeric value of
    each member is exactly what is written here, with no reliance on the
    enum machinery coercing tuples or floats to ``int``.
    """

    # Default: Softmax -> TopK
    Default = 0
    # Renormalize: TopK -> Softmax
    Renormalize = 1
    # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
    # -> Top8 experts from the Top4 groups
    DeepSeekV3 = 2
    # Llama4: Top1 -> Sigmoid
    Llama4 = 3
    # RenormalizeNaive: Softmax -> TopK -> Renormalize
    RenormalizeNaive = 4
    # TopK: TopK (no softmax)
    TopK = 5
    # Unspecified
    Unspecified = 6
@dataclass @dataclass
class FusedMoEQuantDesc: class FusedMoEQuantDesc:
""" """

View File

@ -3,6 +3,7 @@
import torch import torch
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim, calculate_tile_tokens_dim,
@ -23,26 +24,24 @@ def flashinfer_fused_moe_blockscale_fp8(
w2_weight_scale_inv: torch.Tensor, w2_weight_scale_inv: torch.Tensor,
global_num_experts: int, global_num_experts: int,
top_k: int, top_k: int,
num_expert_group: int, num_expert_group: int | None,
topk_group: int, topk_group: int | None,
intermediate_size: int, intermediate_size: int,
expert_offset: int, expert_offset: int,
local_num_experts: int, local_num_experts: int,
block_shape: list[int], block_shape: list[int],
routed_scaling: float = 1.0, routing_method_type: int = RoutingMethodType.DeepSeekV3,
routed_scaling: float | None = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
topk_group = topk_group if topk_group is not None else 0
assert top_k <= global_num_experts assert top_k <= global_num_experts
assert top_k <= 8 assert top_k <= 10
assert topk_group <= 4
assert global_num_experts > num_expert_group
assert global_num_experts % num_expert_group == 0
assert global_num_experts % 4 == 0 assert global_num_experts % 4 == 0
assert top_k < (topk_group * global_num_experts / num_expert_group)
assert block_shape == [128, 128] assert block_shape == [128, 128]
# Routing kernel expects #experts <= #threads 256 # Routing kernel expects #experts <= #threads 512
assert global_num_experts <= 256 assert global_num_experts <= 512
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
# NOTE: scales of hidden states have to be transposed! # NOTE: scales of hidden states have to be transposed!
@ -64,10 +63,8 @@ def flashinfer_fused_moe_blockscale_fp8(
local_expert_offset=expert_offset, local_expert_offset=expert_offset,
local_num_experts=local_num_experts, local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling, routed_scaling_factor=routed_scaling,
tile_tokens_dim=calculate_tile_tokens_dim( tile_tokens_dim=None,
x.shape[0], top_k, global_num_experts routing_method_type=routing_method_type,
),
routing_method_type=2, # DeepSeek-styled routing method
use_shuffled_weight=False, use_shuffled_weight=False,
) )
@ -88,6 +85,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
expert_offset: int, expert_offset: int,
local_num_experts: int, local_num_experts: int,
block_shape: list[int], block_shape: list[int],
routing_method_type: int,
routed_scaling: float = 1.0, routed_scaling: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(x) return torch.empty_like(x)

View File

@ -31,6 +31,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType,
biased_moe_quant_config, biased_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
@ -1213,6 +1214,7 @@ class FusedMoE(CustomOp):
zero_expert_type: str | None = None, zero_expert_type: str | None = None,
expert_mapping: list[tuple[str, str, int, str]] | None = None, expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None, n_shared_experts: int | None = None,
routing_method_type: int | None = None,
): ):
super().__init__() super().__init__()
@ -1397,6 +1399,24 @@ class FusedMoE(CustomOp):
"Only softmax scoring function is supported for non-grouped topk." "Only softmax scoring function is supported for non-grouped topk."
) )
# ToDo: Better logic to determine the routing method type
if routing_method_type is not None:
self.routing_method_type = routing_method_type
else:
if scoring_func == "sigmoid":
if self.use_grouped_topk:
self.routing_method_type = RoutingMethodType.DeepSeekV3
elif self.top_k == 1:
self.routing_method_type = RoutingMethodType.Llama4
elif self.scoring_func == "softmax":
self.routing_method_type = (
RoutingMethodType.Renormalize
if not self.renormalize
else RoutingMethodType.RenormalizeNaive
)
else:
self.routing_method_type = RoutingMethodType.TopK
self.moe_config: FusedMoEConfig = FusedMoEConfig( self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts, num_experts=self.global_num_experts,
experts_per_token=top_k, experts_per_token=top_k,

View File

@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe import (
) )
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
@ -1222,22 +1223,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert activation == "silu", ( assert activation == "silu", (
f"Expected 'silu' activation but got {activation}" f"Expected 'silu' activation but got {activation}"
) )
assert scoring_func == "sigmoid", (
f"Expected 'sigmoid' scoring func but got {scoring_func}"
)
if self.block_quant: if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
assert (
renormalize and use_grouped_topk and custom_routing_function is None
)
e_score_correction_bias = ( e_score_correction_bias = (
e_score_correction_bias.to(x.dtype) e_score_correction_bias.to(x.dtype)
if e_score_correction_bias is not None if e_score_correction_bias is not None
else None else None
) )
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32), routing_logits=router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits,
routing_bias=e_score_correction_bias, routing_bias=e_score_correction_bias,
x=x, x=x,
w13_weight=layer.w13_weight, w13_weight=layer.w13_weight,
@ -1252,6 +1251,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_offset=layer.ep_rank * layer.local_num_experts, expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts, local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size, block_shape=self.weight_block_size,
routing_method_type=routing_method_type,
routed_scaling=routed_scaling_factor, routed_scaling=routed_scaling_factor,
) )
else: else:

View File

@ -27,20 +27,25 @@ class FlashinferMoeBackend(Enum):
def calculate_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    """Choose the number of tokens per CTA tile for the MoE kernel.

    The estimate starts from the token count each expert would receive under
    a perfectly even distribution, inflates it by an imbalance factor, pads
    it with ``next_positive_power_of_2`` and clamps the result to ``[8, 64]``.

    Args:
        num_tokens: Number of tokens being routed.
        top_k: Number of experts selected per token.
        num_experts: Number of experts the tokens are spread over.

    Returns:
        The tile size in tokens, between 8 and 64 inclusive.
    """
    from flashinfer import next_positive_power_of_2

    # A factor considering tokens are not perfectly balanced among experts.
    imbalance_factor = 1.3
    # Tokens per expert assuming a perfectly even distribution.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Apply the imbalance factor.
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
    # Pad the number to the next power of 2.
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    # Cap to 8-64 tokens per CTA tile as it's the range supported by the
    # kernel.
    return min(max(tile_tokens_dim, 8), 64)

View File

@ -43,6 +43,7 @@ from vllm.distributed import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
@ -171,6 +172,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel, is_sequence_parallel=self.is_sequence_parallel,
routing_method_type=RoutingMethodType.Renormalize,
) )
self.gate = ReplicatedLinear( self.gate = ReplicatedLinear(

View File

@ -34,6 +34,7 @@ from vllm.model_executor.layers.fla.ops import (
fused_recurrent_gated_delta_rule, fused_recurrent_gated_delta_rule,
) )
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3NextRMSNorm, GemmaRMSNorm as Qwen3NextRMSNorm,
) )
@ -173,6 +174,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel, is_sequence_parallel=self.is_sequence_parallel,
routing_method_type=RoutingMethodType.Renormalize,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: