[Performance] Support FP8 flashinfer TRTLLM MOE on Qwen3 and Qwen3-Next (#27492)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>

parent b039bfda8f
commit 34553b9d27
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
+from enum import IntEnum
 from typing import Optional, Union

 import torch
@@ -91,6 +92,26 @@ def _quant_flags_to_group_shape(
     return a_shape, w_shape


+# The type of method in top-K routing
+# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
+class RoutingMethodType(IntEnum):
+    # Default: Softmax -> TopK
+    Default = (0,)
+    # Renormalize: TopK -> Softmax
+    Renormalize = (1,)
+    # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
+    # -> Top8 experts from the Top4 groups
+    DeepSeekV3 = (2,)
+    # Llama4: Top1 -> Sigmoid
+    Llama4 = (3,)
+    # RenormalizeNaive: Softmax -> TopK -> Renormalize
+    RenormalizeNaive = (4,)
+    # TopK: TopK (no softmax)
+    TopK = (5,)
+    # Unspecified
+    Unspecified = (6,)
+
+
 @dataclass
 class FusedMoEQuantDesc:
     """
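For readers new to this taxonomy: `Renormalize` (TopK -> Softmax) and `RenormalizeNaive` (Softmax -> TopK -> Renormalize) select the same experts and produce the same routing weights; they differ only in where the kernel does the normalization work. A standalone sketch in plain PyTorch (not vLLM code) showing the equivalence:

```python
# Standalone sketch: the two renormalizing routing orders are numerically
# equivalent, so choosing between them is a kernel-scheduling decision.
import torch

logits = torch.randn(4, 8)  # [num_tokens, num_experts]
top_k = 2

# Renormalize: TopK first, then softmax over only the selected logits.
vals, idx = logits.topk(top_k, dim=-1)
w_topk_softmax = torch.softmax(vals, dim=-1)

# RenormalizeNaive: softmax over all experts, TopK, then renormalize.
probs = torch.softmax(logits, dim=-1)
pvals, pidx = probs.topk(top_k, dim=-1)
w_softmax_topk = pvals / pvals.sum(dim=-1, keepdim=True)

assert torch.equal(idx, pidx)  # softmax is monotonic, same experts win
assert torch.allclose(w_topk_softmax, w_softmax_topk, atol=1e-6)
```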
@@ -3,6 +3,7 @@

 import torch

+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     calculate_tile_tokens_dim,
@@ -23,26 +24,24 @@ def flashinfer_fused_moe_blockscale_fp8(
     w2_weight_scale_inv: torch.Tensor,
     global_num_experts: int,
     top_k: int,
-    num_expert_group: int,
-    topk_group: int,
+    num_expert_group: int | None,
+    topk_group: int | None,
     intermediate_size: int,
     expert_offset: int,
     local_num_experts: int,
     block_shape: list[int],
-    routed_scaling: float = 1.0,
+    routing_method_type: int = RoutingMethodType.DeepSeekV3,
+    routed_scaling: float | None = 1.0,
 ) -> torch.Tensor:
     from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe

+    topk_group = topk_group if topk_group is not None else 0
     assert top_k <= global_num_experts
-    assert top_k <= 8
-    assert topk_group <= 4
-    assert global_num_experts > num_expert_group
-    assert global_num_experts % num_expert_group == 0
+    assert top_k <= 10
     assert global_num_experts % 4 == 0
-    assert top_k < (topk_group * global_num_experts / num_expert_group)
     assert block_shape == [128, 128]
-    # Routing kernel expects #experts <= #threads 256
-    assert global_num_experts <= 256
+    # Routing kernel expects #experts <= #threads 512
+    assert global_num_experts <= 512

     a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
     # NOTE: scales of hidden states have to be transposed!
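The group parameters become optional because Renormalize-style routing (Qwen3, Qwen3-Next) has no expert groups; only the DeepSeekV3 path uses them, and the group-based asserts are dropped accordingly. A minimal sketch of the None-handling this hunk introduces, using a hypothetical helper name and the surviving bounds checks:

```python
# Hypothetical helper mirroring the None-handling above: group-free routing
# methods fall back to topk_group = 0 instead of failing group-based asserts.
def normalize_group_params(
    top_k: int,
    global_num_experts: int,
    topk_group: int | None,
) -> int:
    topk_group = topk_group if topk_group is not None else 0
    assert top_k <= global_num_experts
    assert top_k <= 10               # TRTLLM routing kernel limit (see above)
    assert global_num_experts % 4 == 0
    return topk_group

print(normalize_group_params(8, 128, None))  # 0 -> no expert groups (Qwen3)
print(normalize_group_params(8, 256, 4))     # 4 -> DeepSeekV3-style groups
```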
@@ -64,10 +63,8 @@ def flashinfer_fused_moe_blockscale_fp8(
         local_expert_offset=expert_offset,
         local_num_experts=local_num_experts,
         routed_scaling_factor=routed_scaling,
-        tile_tokens_dim=calculate_tile_tokens_dim(
-            x.shape[0], top_k, global_num_experts
-        ),
-        routing_method_type=2,  # DeepSeek-styled routing method
+        tile_tokens_dim=None,
+        routing_method_type=routing_method_type,
         use_shuffled_weight=False,
     )

@@ -88,6 +85,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
     expert_offset: int,
     local_num_experts: int,
     block_shape: list[int],
+    routing_method_type: int,
     routed_scaling: float = 1.0,
 ) -> torch.Tensor:
     return torch.empty_like(x)

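The `_fake` variant gives torch.compile a shape-only implementation to trace instead of launching the kernel, which is why the new `routing_method_type` parameter must be added to its signature as well. A generic sketch of the pattern using plain `torch.library` (vLLM registers its ops through its own helper, so the op name and functions below are illustrative):

```python
# Generic sketch of the fake-op pattern: the fake impl only propagates
# shape/dtype so torch.compile can trace through an opaque kernel call.
import torch

@torch.library.custom_op("demo::noisy_identity", mutates_args=())
def noisy_identity(x: torch.Tensor) -> torch.Tensor:
    return x + torch.randn_like(x)  # stand-in for an opaque CUDA kernel

@noisy_identity.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Mirrors flashinfer_fused_moe_blockscale_fp8_fake: same shape, no compute.
    return torch.empty_like(x)

compiled = torch.compile(lambda t: noisy_identity(t) * 2)
print(compiled(torch.ones(4)).shape)  # torch.Size([4])
```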
@@ -31,6 +31,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEParallelConfig,
     FusedMoEQuantConfig,
+    RoutingMethodType,
     biased_moe_quant_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
@@ -1213,6 +1214,7 @@ class FusedMoE(CustomOp):
         zero_expert_type: str | None = None,
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
         n_shared_experts: int | None = None,
+        routing_method_type: int | None = None,
     ):
         super().__init__()

@@ -1397,6 +1399,24 @@ class FusedMoE(CustomOp):
                 "Only softmax scoring function is supported for non-grouped topk."
             )

+        # ToDo: Better logic to determine the routing method type
+        if routing_method_type is not None:
+            self.routing_method_type = routing_method_type
+        else:
+            if scoring_func == "sigmoid":
+                if self.use_grouped_topk:
+                    self.routing_method_type = RoutingMethodType.DeepSeekV3
+                elif self.top_k == 1:
+                    self.routing_method_type = RoutingMethodType.Llama4
+            elif self.scoring_func == "softmax":
+                self.routing_method_type = (
+                    RoutingMethodType.Renormalize
+                    if not self.renormalize
+                    else RoutingMethodType.RenormalizeNaive
+                )
+            else:
+                self.routing_method_type = RoutingMethodType.TopK
+
         self.moe_config: FusedMoEConfig = FusedMoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
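As a quick reference, the fallback above maps configurations as follows. A standalone restatement (note that the sigmoid branch leaves the ungrouped `top_k > 1` case unmapped; this sketch returns `Unspecified` there instead):

```python
from enum import IntEnum

class RoutingMethodType(IntEnum):
    Default = 0
    Renormalize = 1
    DeepSeekV3 = 2
    Llama4 = 3
    RenormalizeNaive = 4
    TopK = 5
    Unspecified = 6

def pick_routing_method(
    scoring_func: str, use_grouped_topk: bool, top_k: int, renormalize: bool
) -> RoutingMethodType:
    if scoring_func == "sigmoid":
        if use_grouped_topk:
            return RoutingMethodType.DeepSeekV3
        if top_k == 1:
            return RoutingMethodType.Llama4
        return RoutingMethodType.Unspecified  # left unmapped in the hunk above
    if scoring_func == "softmax":
        return (
            RoutingMethodType.Renormalize
            if not renormalize
            else RoutingMethodType.RenormalizeNaive
        )
    return RoutingMethodType.TopK

print(pick_routing_method("sigmoid", True, 8, True))   # DeepSeekV3
print(pick_routing_method("softmax", False, 8, True))  # RenormalizeNaive
```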
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe import (
 )
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
+    RoutingMethodType,
     fp8_w8a8_moe_quant_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
@@ -1222,22 +1223,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             assert activation == "silu", (
                 f"Expected 'silu' activation but got {activation}"
             )
-            assert scoring_func == "sigmoid", (
-                f"Expected 'sigmoid' scoring func but got {scoring_func}"
-            )

             if self.block_quant:
                 import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

-                assert (
-                    renormalize and use_grouped_topk and custom_routing_function is None
-                )
                 e_score_correction_bias = (
                     e_score_correction_bias.to(x.dtype)
                     if e_score_correction_bias is not None
                     else None
                 )
+                routing_method_type = layer.routing_method_type
                 return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
-                    routing_logits=router_logits.to(torch.float32),
+                    routing_logits=router_logits.to(torch.float32)
+                    if routing_method_type == RoutingMethodType.DeepSeekV3
+                    else router_logits,
                     routing_bias=e_score_correction_bias,
                     x=x,
                     w13_weight=layer.w13_weight,
@@ -1252,6 +1251,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     expert_offset=layer.ep_rank * layer.local_num_experts,
                     local_num_experts=layer.local_num_experts,
                     block_shape=self.weight_block_size,
+                    routing_method_type=routing_method_type,
                     routed_scaling=routed_scaling_factor,
                 )
             else:

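Besides plumbing the method through, the one behavioral change at this call site is the conditional cast: routing logits go to float32 only for DeepSeekV3 routing, while Renormalize-style routing hands the gate output to the kernel in its original dtype. A standalone restatement (the dtype requirement is inferred from the hunk above, not separately documented here):

```python
import torch

DEEPSEEKV3 = 2  # RoutingMethodType.DeepSeekV3

def prepare_routing_logits(router_logits: torch.Tensor, method: int) -> torch.Tensor:
    # Assumption: only the sigmoid+bias (DeepSeekV3) routing path wants fp32
    # logits; other methods consume the gate output as-is.
    if method == DEEPSEEKV3:
        return router_logits.to(torch.float32)
    return router_logits

logits = torch.randn(4, 128, dtype=torch.bfloat16)
print(prepare_routing_logits(logits, 1).dtype)  # torch.bfloat16 (Renormalize)
print(prepare_routing_logits(logits, 2).dtype)  # torch.float32 (DeepSeekV3)
```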
@@ -27,20 +27,25 @@ class FlashinferMoeBackend(Enum):


 def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
+    from flashinfer import next_positive_power_of_2

-    # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
-    # TODO: Revert this to dynamic calculation once a new version of FlashInfer
-    # with the necessary kernels is released.
-    tile_tokens_dim = 8
-
-    # from flashinfer import next_positive_power_of_2
-
-    # # Guess tokens per expert assuming perfect expert distribution first.
-    # num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # # And pad the number to the next power of 2.
-    # tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-    # # Cap to 8-64 tokens per CTA tile as it's the range supported by the
-    # # kernel.
-    # tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    # A factor considering tokens are not perfectly balanced among experts.
+    imbalance_factor = 1.3
+    # Calculate the number of tokens per expert
+    # assuming perfect distribution.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # Apply the imbalance factor.
+    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
+    # Cap to 8-max_tile_tokens_dim tokens per CTA tile
+    # as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

     return tile_tokens_dim

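A worked example of the restored heuristic (the power-of-2 helper below is a pure-Python stand-in for `flashinfer.next_positive_power_of_2`, assumed here to round up to the nearest power of two):

```python
def next_positive_power_of_2(x: int) -> int:
    # Stand-in for flashinfer's helper: smallest power of two >= max(x, 1).
    return 1 if x < 1 else 1 << (x - 1).bit_length()

num_tokens, top_k, num_experts = 256, 8, 128
per_expert = (num_tokens * top_k) // num_experts   # 16 tokens per expert
per_expert = int(per_expert * 1.3)                 # 20 after imbalance factor
tile_tokens_dim = min(max(next_positive_power_of_2(per_expert), 8), 64)
print(tile_tokens_dim)  # 32
```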
@@ -43,6 +43,7 @@ from vllm.distributed import (
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -171,6 +172,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
             is_sequence_parallel=self.is_sequence_parallel,
+            routing_method_type=RoutingMethodType.Renormalize,
         )

         self.gate = ReplicatedLinear(
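This hunk pins the routing method for Qwen3; the two hunks below do the same for Qwen3-Next. An explicit `routing_method_type` always wins over the softmax/sigmoid fallback in `FusedMoE.__init__`. A toy restatement of that precedence (hypothetical mini-class, not vLLM's `FusedMoE`):

```python
from enum import IntEnum

class RoutingMethodType(IntEnum):
    Renormalize = 1
    RenormalizeNaive = 4

class MiniMoE:
    def __init__(self, renormalize: bool, routing_method_type: int | None = None):
        if routing_method_type is not None:
            # Explicit override, as passed by the Qwen3/Qwen3-Next blocks.
            self.routing_method_type = routing_method_type
        else:
            # Softmax fallback from FusedMoE.__init__ (see earlier hunk).
            self.routing_method_type = (
                RoutingMethodType.RenormalizeNaive
                if renormalize
                else RoutingMethodType.Renormalize
            )

# Qwen3 uses softmax scoring with renormalization, which the fallback alone
# would map to RenormalizeNaive; the explicit argument forces the
# TopK->Softmax (Renormalize) kernel path instead.
print(MiniMoE(True, RoutingMethodType.Renormalize).routing_method_type)
```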
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.fla.ops import (
     fused_recurrent_gated_delta_rule,
 )
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
 from vllm.model_executor.layers.layernorm import (
     GemmaRMSNorm as Qwen3NextRMSNorm,
 )
@@ -173,6 +174,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
             is_sequence_parallel=self.is_sequence_parallel,
+            routing_method_type=RoutingMethodType.Renormalize,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: