[Performance] Support FP8 FlashInfer TRTLLM MoE on Qwen3 and Qwen3-Next (#27492)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
jiahanc 2025-11-10 09:34:57 -08:00 committed by GitHub
parent b039bfda8f
commit 34553b9d27
7 changed files with 78 additions and 30 deletions

File: vllm/model_executor/layers/fused_moe/config.py

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional, Union
import torch
@@ -91,6 +92,26 @@ def _quant_flags_to_group_shape(
return a_shape, w_shape
# The type of method in top-K routing
# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
class RoutingMethodType(IntEnum):
# Default: Softmax -> TopK
Default = 0
# Renormalize: TopK -> Softmax
Renormalize = 1
# DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
# -> Top8 experts from the Top4 groups
DeepSeekV3 = 2
# Llama4: Top1 -> Sigmoid
Llama4 = 3
# RenormalizeNaive: Softmax -> TopK -> Renormalize
RenormalizeNaive = 4
# TopK: TopK (no softmax)
TopK = 5
# Unspecified
Unspecified = 6
@dataclass
class FusedMoEQuantDesc:
"""

File: vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py

@@ -3,6 +3,7 @@
import torch
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim,
@@ -23,26 +24,24 @@ def flashinfer_fused_moe_blockscale_fp8(
w2_weight_scale_inv: torch.Tensor,
global_num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
num_expert_group: int | None,
topk_group: int | None,
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routed_scaling: float = 1.0,
routing_method_type: int = RoutingMethodType.DeepSeekV3,
routed_scaling: float | None = 1.0,
) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
topk_group = topk_group if topk_group is not None else 0
assert top_k <= global_num_experts
assert top_k <= 8
assert topk_group <= 4
assert global_num_experts > num_expert_group
assert global_num_experts % num_expert_group == 0
assert top_k <= 10
assert global_num_experts % 4 == 0
assert top_k < (topk_group * global_num_experts / num_expert_group)
assert block_shape == [128, 128]
# Routing kernel expects #experts <= #threads 256
assert global_num_experts <= 256
# Routing kernel expects #experts <= #threads 512
assert global_num_experts <= 512
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
# NOTE: scales of hidden states have to be transposed!
@@ -64,10 +63,8 @@ def flashinfer_fused_moe_blockscale_fp8(
local_expert_offset=expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling,
tile_tokens_dim=calculate_tile_tokens_dim(
x.shape[0], top_k, global_num_experts
),
routing_method_type=2, # DeepSeek-styled routing method
tile_tokens_dim=None,
routing_method_type=routing_method_type,
use_shuffled_weight=False,
)
@@ -88,6 +85,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routing_method_type: int,
routed_scaling: float = 1.0,
) -> torch.Tensor:
return torch.empty_like(x)
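Context for the per_token_group_quant_fp8 call and the block_shape == [128, 128] assertion: activations are quantized per token in contiguous groups of 128 elements, each group carrying one FP8 scale. A rough reference sketch (not the vLLM kernel; assumes hidden_size is a multiple of the group size and ignores padding):

import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128):
    # x: [num_tokens, hidden_size]; every contiguous group of `group_size`
    # values in a token shares a single scale.
    num_tokens, hidden = x.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    groups = x.view(num_tokens, hidden // group_size, group_size)
    scales = groups.abs().amax(dim=-1).clamp(min=1e-6) / fp8_max  # [tokens, n_groups]
    q = (groups / scales.unsqueeze(-1)).to(torch.float8_e4m3fn)
    return q.view(num_tokens, hidden), scales

The NOTE in the diff about transposing the hidden-state scales refers to handing this [tokens, n_groups] scale tensor to the TRTLLM kernel in transposed layout.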

File: vllm/model_executor/layers/fused_moe/layer.py

@@ -31,6 +31,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
biased_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
@@ -1213,6 +1214,7 @@ class FusedMoE(CustomOp):
zero_expert_type: str | None = None,
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
routing_method_type: int | None = None,
):
super().__init__()
@@ -1397,6 +1399,24 @@
"Only softmax scoring function is supported for non-grouped topk."
)
# ToDo: Better logic to determine the routing method type
if routing_method_type is not None:
self.routing_method_type = routing_method_type
else:
if scoring_func == "sigmoid":
if self.use_grouped_topk:
self.routing_method_type = RoutingMethodType.DeepSeekV3
elif self.top_k == 1:
self.routing_method_type = RoutingMethodType.Llama4
elif self.scoring_func == "softmax":
self.routing_method_type = (
RoutingMethodType.Renormalize
if not self.renormalize
else RoutingMethodType.RenormalizeNaive
)
else:
self.routing_method_type = RoutingMethodType.TopK
self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
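The routing-method heuristic above, restated as a standalone function for readability (illustrative only; it mirrors the branch structure in the diff):

from vllm.model_executor.layers.fused_moe.config import RoutingMethodType

def infer_routing_method(scoring_func: str, use_grouped_topk: bool,
                         top_k: int, renormalize: bool) -> RoutingMethodType | None:
    if scoring_func == "sigmoid":
        if use_grouped_topk:
            return RoutingMethodType.DeepSeekV3  # DeepSeek-style grouped routing
        if top_k == 1:
            return RoutingMethodType.Llama4      # single-expert sigmoid gate
        return None                              # case not covered by the heuristic
    if scoring_func == "softmax":
        return (RoutingMethodType.Renormalize if not renormalize
                else RoutingMethodType.RenormalizeNaive)
    return RoutingMethodType.TopK

Note that a softmax gate with renormalize=True (the Qwen3 configuration) maps to RenormalizeNaive here, which is presumably why the Qwen3 and Qwen3-Next blocks below pass routing_method_type=RoutingMethodType.Renormalize explicitly rather than relying on the heuristic.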

File: vllm/model_executor/layers/quantization/fp8.py

@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe import (
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
@@ -1222,22 +1223,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
)
assert scoring_func == "sigmoid", (
f"Expected 'sigmoid' scoring func but got {scoring_func}"
)
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
assert (
renormalize and use_grouped_topk and custom_routing_function is None
)
e_score_correction_bias = (
e_score_correction_bias.to(x.dtype)
if e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32),
routing_logits=router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits,
routing_bias=e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
@@ -1252,6 +1251,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=routing_method_type,
routed_scaling=routed_scaling_factor,
)
else:

File: vllm/model_executor/layers/quantization/utils/flashinfer_utils.py

@@ -27,20 +27,25 @@ class FlashinferMoeBackend(Enum):
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
from flashinfer import next_positive_power_of_2
# FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
# TODO: Revert this to dynamic calculation once a new version of FlashInfer
# with the necessary kernels is released.
tile_tokens_dim = 8
# from flashinfer import next_positive_power_of_2
# # Guess tokens per expert assuming perfect expert distribution first.
# num_tokens_per_expert = (num_tokens * top_k) // num_experts
# # And pad the number to the next power of 2.
# tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# # Cap to 8-64 tokens per CTA tile as it's the range supported by the
# # kernel.
# tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
# A factor considering tokens are not perfectly balanced among experts.
imbalance_factor = 1.3
# Calculate the number of tokens per expert
# assuming perfect distribution.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# Apply the imbalance factor.
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-max_tile_tokens_dim tokens per CTA tile
# as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
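A quick worked example of the restored heuristic (numbers are illustrative; next_positive_power_of_2 is re-implemented locally so the snippet stands alone):

def next_positive_power_of_2(x: int) -> int:
    # Local stand-in for flashinfer.next_positive_power_of_2.
    return 1 if x < 1 else 1 << (x - 1).bit_length()

num_tokens, top_k, num_experts = 256, 8, 128
num_tokens_per_expert = (num_tokens * top_k) // num_experts   # 16
num_tokens_per_expert = int(num_tokens_per_expert * 1.3)      # 20 after imbalance factor
tile_tokens_dim = min(max(next_positive_power_of_2(num_tokens_per_expert), 8), 64)
print(tile_tokens_dim)  # 32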

File: vllm/model_executor/models/qwen3_moe.py

@@ -43,6 +43,7 @@ from vllm.distributed import (
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -171,6 +172,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
routing_method_type=RoutingMethodType.Renormalize,
)
self.gate = ReplicatedLinear(

File: vllm/model_executor/models/qwen3_next.py

@@ -34,6 +34,7 @@ from vllm.model_executor.layers.fla.ops import (
fused_recurrent_gated_delta_rule,
)
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3NextRMSNorm,
)
@@ -173,6 +174,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
routing_method_type=RoutingMethodType.Renormalize,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
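A usage sketch (not from the commit) of exercising this path end to end: serve a Qwen3 FP8 MoE checkpoint with the FlashInfer MoE backend enabled. The environment-variable names are the usual vLLM switches for the FlashInfer FP8 MoE path and the checkpoint name is illustrative; treat both as assumptions and check vllm/envs.py and the model hub for your version.

import os
# Assumed env switches: enable FlashInfer FP8 MoE and pick the TRTLLM ("latency") kernels.
os.environ["VLLM_USE_FLASHINFER_MOE_FP8"] = "1"
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "latency"

from vllm import LLM
llm = LLM(model="Qwen/Qwen3-30B-A3B-FP8")  # illustrative FP8 MoE checkpoint
print(llm.generate("Hello")[0].outputs[0].text)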