mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 01:27:04 +08:00
[Performance] Support FP8 flashinfer TRTLLM MOE on Qwen3 and Qwen-3next (#27492)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
This commit is contained in:
parent
b039bfda8f
commit
34553b9d27
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import IntEnum
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -91,6 +92,26 @@ def _quant_flags_to_group_shape(
|
|||||||
return a_shape, w_shape
|
return a_shape, w_shape
|
||||||
|
|
||||||
|
|
||||||
|
# The type of method in top-K routing
|
||||||
|
# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
|
||||||
|
class RoutingMethodType(IntEnum):
|
||||||
|
# Default: Softmax -> TopK
|
||||||
|
Default = (0,)
|
||||||
|
# Renormalize: TopK -> Softmax
|
||||||
|
Renormalize = (1,)
|
||||||
|
# DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
|
||||||
|
# -> Top8 experts from the Top4 groups
|
||||||
|
DeepSeekV3 = (2,)
|
||||||
|
# Llama4: Top1 -> Sigmoid
|
||||||
|
Llama4 = (3,)
|
||||||
|
# RenormalizeNaive: Softmax -> TopK -> Renormalize
|
||||||
|
RenormalizeNaive = (4,)
|
||||||
|
# TopK: TopK (no softmax)
|
||||||
|
TopK = (5,)
|
||||||
|
# Unspecified
|
||||||
|
Unspecified = 6.0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FusedMoEQuantDesc:
|
class FusedMoEQuantDesc:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
|
||||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||||
calculate_tile_tokens_dim,
|
calculate_tile_tokens_dim,
|
||||||
@ -23,26 +24,24 @@ def flashinfer_fused_moe_blockscale_fp8(
|
|||||||
w2_weight_scale_inv: torch.Tensor,
|
w2_weight_scale_inv: torch.Tensor,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
num_expert_group: int,
|
num_expert_group: int | None,
|
||||||
topk_group: int,
|
topk_group: int | None,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
expert_offset: int,
|
expert_offset: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
block_shape: list[int],
|
block_shape: list[int],
|
||||||
routed_scaling: float = 1.0,
|
routing_method_type: int = RoutingMethodType.DeepSeekV3,
|
||||||
|
routed_scaling: float | None = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
|
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
|
||||||
|
|
||||||
|
topk_group = topk_group if topk_group is not None else 0
|
||||||
assert top_k <= global_num_experts
|
assert top_k <= global_num_experts
|
||||||
assert top_k <= 8
|
assert top_k <= 10
|
||||||
assert topk_group <= 4
|
|
||||||
assert global_num_experts > num_expert_group
|
|
||||||
assert global_num_experts % num_expert_group == 0
|
|
||||||
assert global_num_experts % 4 == 0
|
assert global_num_experts % 4 == 0
|
||||||
assert top_k < (topk_group * global_num_experts / num_expert_group)
|
|
||||||
assert block_shape == [128, 128]
|
assert block_shape == [128, 128]
|
||||||
# Routing kernel expects #experts <= #threads 256
|
# Routing kernel expects #experts <= #threads 512
|
||||||
assert global_num_experts <= 256
|
assert global_num_experts <= 512
|
||||||
|
|
||||||
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
|
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
|
||||||
# NOTE: scales of hidden states have to be transposed!
|
# NOTE: scales of hidden states have to be transposed!
|
||||||
@ -64,10 +63,8 @@ def flashinfer_fused_moe_blockscale_fp8(
|
|||||||
local_expert_offset=expert_offset,
|
local_expert_offset=expert_offset,
|
||||||
local_num_experts=local_num_experts,
|
local_num_experts=local_num_experts,
|
||||||
routed_scaling_factor=routed_scaling,
|
routed_scaling_factor=routed_scaling,
|
||||||
tile_tokens_dim=calculate_tile_tokens_dim(
|
tile_tokens_dim=None,
|
||||||
x.shape[0], top_k, global_num_experts
|
routing_method_type=routing_method_type,
|
||||||
),
|
|
||||||
routing_method_type=2, # DeepSeek-styled routing method
|
|
||||||
use_shuffled_weight=False,
|
use_shuffled_weight=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -88,6 +85,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
|
|||||||
expert_offset: int,
|
expert_offset: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
block_shape: list[int],
|
block_shape: list[int],
|
||||||
|
routing_method_type: int,
|
||||||
routed_scaling: float = 1.0,
|
routed_scaling: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.empty_like(x)
|
return torch.empty_like(x)
|
||||||
|
|||||||
@ -31,6 +31,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
|||||||
FusedMoEConfig,
|
FusedMoEConfig,
|
||||||
FusedMoEParallelConfig,
|
FusedMoEParallelConfig,
|
||||||
FusedMoEQuantConfig,
|
FusedMoEQuantConfig,
|
||||||
|
RoutingMethodType,
|
||||||
biased_moe_quant_config,
|
biased_moe_quant_config,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
|
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
|
||||||
@ -1213,6 +1214,7 @@ class FusedMoE(CustomOp):
|
|||||||
zero_expert_type: str | None = None,
|
zero_expert_type: str | None = None,
|
||||||
expert_mapping: list[tuple[str, str, int, str]] | None = None,
|
expert_mapping: list[tuple[str, str, int, str]] | None = None,
|
||||||
n_shared_experts: int | None = None,
|
n_shared_experts: int | None = None,
|
||||||
|
routing_method_type: int | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -1397,6 +1399,24 @@ class FusedMoE(CustomOp):
|
|||||||
"Only softmax scoring function is supported for non-grouped topk."
|
"Only softmax scoring function is supported for non-grouped topk."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ToDo: Better logic to determine the routing method type
|
||||||
|
if routing_method_type is not None:
|
||||||
|
self.routing_method_type = routing_method_type
|
||||||
|
else:
|
||||||
|
if scoring_func == "sigmoid":
|
||||||
|
if self.use_grouped_topk:
|
||||||
|
self.routing_method_type = RoutingMethodType.DeepSeekV3
|
||||||
|
elif self.top_k == 1:
|
||||||
|
self.routing_method_type = RoutingMethodType.Llama4
|
||||||
|
elif self.scoring_func == "softmax":
|
||||||
|
self.routing_method_type = (
|
||||||
|
RoutingMethodType.Renormalize
|
||||||
|
if not self.renormalize
|
||||||
|
else RoutingMethodType.RenormalizeNaive
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.routing_method_type = RoutingMethodType.TopK
|
||||||
|
|
||||||
self.moe_config: FusedMoEConfig = FusedMoEConfig(
|
self.moe_config: FusedMoEConfig = FusedMoEConfig(
|
||||||
num_experts=self.global_num_experts,
|
num_experts=self.global_num_experts,
|
||||||
experts_per_token=top_k,
|
experts_per_token=top_k,
|
||||||
|
|||||||
@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
FusedMoEQuantConfig,
|
FusedMoEQuantConfig,
|
||||||
|
RoutingMethodType,
|
||||||
fp8_w8a8_moe_quant_config,
|
fp8_w8a8_moe_quant_config,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||||
@ -1222,22 +1223,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
assert activation == "silu", (
|
assert activation == "silu", (
|
||||||
f"Expected 'silu' activation but got {activation}"
|
f"Expected 'silu' activation but got {activation}"
|
||||||
)
|
)
|
||||||
assert scoring_func == "sigmoid", (
|
|
||||||
f"Expected 'sigmoid' scoring func but got {scoring_func}"
|
|
||||||
)
|
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||||
|
|
||||||
assert (
|
|
||||||
renormalize and use_grouped_topk and custom_routing_function is None
|
|
||||||
)
|
|
||||||
e_score_correction_bias = (
|
e_score_correction_bias = (
|
||||||
e_score_correction_bias.to(x.dtype)
|
e_score_correction_bias.to(x.dtype)
|
||||||
if e_score_correction_bias is not None
|
if e_score_correction_bias is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
routing_method_type = layer.routing_method_type
|
||||||
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
||||||
routing_logits=router_logits.to(torch.float32),
|
routing_logits=router_logits.to(torch.float32)
|
||||||
|
if routing_method_type == RoutingMethodType.DeepSeekV3
|
||||||
|
else router_logits,
|
||||||
routing_bias=e_score_correction_bias,
|
routing_bias=e_score_correction_bias,
|
||||||
x=x,
|
x=x,
|
||||||
w13_weight=layer.w13_weight,
|
w13_weight=layer.w13_weight,
|
||||||
@ -1252,6 +1251,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
expert_offset=layer.ep_rank * layer.local_num_experts,
|
expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||||
local_num_experts=layer.local_num_experts,
|
local_num_experts=layer.local_num_experts,
|
||||||
block_shape=self.weight_block_size,
|
block_shape=self.weight_block_size,
|
||||||
|
routing_method_type=routing_method_type,
|
||||||
routed_scaling=routed_scaling_factor,
|
routed_scaling=routed_scaling_factor,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -27,20 +27,25 @@ class FlashinferMoeBackend(Enum):
|
|||||||
|
|
||||||
|
|
||||||
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
|
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
|
||||||
|
from flashinfer import next_positive_power_of_2
|
||||||
|
|
||||||
# FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
|
# FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
|
||||||
# TODO: Revert this to dynamic calculation once a new version of FlashInfer
|
# TODO: Revert this to dynamic calculation once a new version of FlashInfer
|
||||||
# with the necessary kernels is released.
|
# with the necessary kernels is released.
|
||||||
tile_tokens_dim = 8
|
tile_tokens_dim = 8
|
||||||
|
|
||||||
# from flashinfer import next_positive_power_of_2
|
# A factor considering tokens are not perfectly balanced among experts.
|
||||||
|
imbalance_factor = 1.3
|
||||||
# # Guess tokens per expert assuming perfect expert distribution first.
|
# Calculate the number of tokens per expert
|
||||||
# num_tokens_per_expert = (num_tokens * top_k) // num_experts
|
# assuming perfect distribution.
|
||||||
# # And pad the number to the next power of 2.
|
num_tokens_per_expert = (num_tokens * top_k) // num_experts
|
||||||
# tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
|
# Apply the imbalance factor.
|
||||||
# # Cap to 8-64 tokens per CTA tile as it's the range supported by the
|
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
|
||||||
# # kernel.
|
# And pad the number to the next power of 2.
|
||||||
# tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
|
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
|
||||||
|
# Cap to 8-max_tile_tokens_dim tokens per CTA tile
|
||||||
|
# as it's the range supported by the kernel.
|
||||||
|
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
|
||||||
|
|
||||||
return tile_tokens_dim
|
return tile_tokens_dim
|
||||||
|
|
||||||
|
|||||||
@ -43,6 +43,7 @@ from vllm.distributed import (
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
@ -171,6 +172,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
enable_eplb=self.enable_eplb,
|
enable_eplb=self.enable_eplb,
|
||||||
num_redundant_experts=self.n_redundant_experts,
|
num_redundant_experts=self.n_redundant_experts,
|
||||||
is_sequence_parallel=self.is_sequence_parallel,
|
is_sequence_parallel=self.is_sequence_parallel,
|
||||||
|
routing_method_type=RoutingMethodType.Renormalize,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gate = ReplicatedLinear(
|
self.gate = ReplicatedLinear(
|
||||||
|
|||||||
@ -34,6 +34,7 @@ from vllm.model_executor.layers.fla.ops import (
|
|||||||
fused_recurrent_gated_delta_rule,
|
fused_recurrent_gated_delta_rule,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
|
||||||
from vllm.model_executor.layers.layernorm import (
|
from vllm.model_executor.layers.layernorm import (
|
||||||
GemmaRMSNorm as Qwen3NextRMSNorm,
|
GemmaRMSNorm as Qwen3NextRMSNorm,
|
||||||
)
|
)
|
||||||
@ -173,6 +174,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
|
|||||||
enable_eplb=self.enable_eplb,
|
enable_eplb=self.enable_eplb,
|
||||||
num_redundant_experts=self.n_redundant_experts,
|
num_redundant_experts=self.n_redundant_experts,
|
||||||
is_sequence_parallel=self.is_sequence_parallel,
|
is_sequence_parallel=self.is_sequence_parallel,
|
||||||
|
routing_method_type=RoutingMethodType.Renormalize,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user