diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index cbc3caafcf2f0..a7bd64b1c65e9 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
+from enum import IntEnum
 from typing import Optional, Union

 import torch
@@ -91,6 +92,26 @@ def _quant_flags_to_group_shape(
     return a_shape, w_shape


+# The type of method in top-K routing
+# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
+class RoutingMethodType(IntEnum):
+    # Default: Softmax -> TopK
+    Default = (0,)
+    # Renormalize: TopK -> Softmax
+    Renormalize = (1,)
+    # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
+    # -> Top8 experts from the Top4 groups
+    DeepSeekV3 = (2,)
+    # Llama4: Top1 -> Sigmoid
+    Llama4 = (3,)
+    # RenormalizeNaive: Softmax -> TopK -> Renormalize
+    RenormalizeNaive = (4,)
+    # TopK: TopK (no softmax)
+    TopK = (5,)
+    # Unspecified
+    Unspecified = 6.0
+
+
 @dataclass
 class FusedMoEQuantDesc:
     """
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
index f21fe16c5108e..51e06ac54f497 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
@@ -3,6 +3,7 @@

 import torch

+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     calculate_tile_tokens_dim,
@@ -23,26 +24,24 @@ def flashinfer_fused_moe_blockscale_fp8(
     w2_weight_scale_inv: torch.Tensor,
     global_num_experts: int,
     top_k: int,
-    num_expert_group: int,
-    topk_group: int,
+    num_expert_group: int | None,
+    topk_group: int | None,
     intermediate_size: int,
     expert_offset: int,
     local_num_experts: int,
     block_shape: list[int],
-    routed_scaling: float = 1.0,
+    routing_method_type: int = RoutingMethodType.DeepSeekV3,
+    routed_scaling: float | None = 1.0,
 ) -> torch.Tensor:
     from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe

+    topk_group = topk_group if topk_group is not None else 0
     assert top_k <= global_num_experts
-    assert top_k <= 8
-    assert topk_group <= 4
-    assert global_num_experts > num_expert_group
-    assert global_num_experts % num_expert_group == 0
+    assert top_k <= 10
     assert global_num_experts % 4 == 0
-    assert top_k < (topk_group * global_num_experts / num_expert_group)
     assert block_shape == [128, 128]
-    # Routing kernel expects #experts <= #threads 256
-    assert global_num_experts <= 256
+    # Routing kernel expects #experts <= #threads 512
+    assert global_num_experts <= 512

     a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
     # NOTE: scales of hidden states have to be transposed!
@@ -64,10 +63,8 @@ def flashinfer_fused_moe_blockscale_fp8(
         local_expert_offset=expert_offset,
         local_num_experts=local_num_experts,
         routed_scaling_factor=routed_scaling,
-        tile_tokens_dim=calculate_tile_tokens_dim(
-            x.shape[0], top_k, global_num_experts
-        ),
-        routing_method_type=2,  # DeepSeek-styled routing method
+        tile_tokens_dim=None,
+        routing_method_type=routing_method_type,
         use_shuffled_weight=False,
     )

@@ -88,6 +85,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
     expert_offset: int,
     local_num_experts: int,
     block_shape: list[int],
+    routing_method_type: int,
     routed_scaling: float = 1.0,
 ) -> torch.Tensor:
     return torch.empty_like(x)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 45b0f50a79973..f86a93e300033 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -31,6 +31,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEParallelConfig,
     FusedMoEQuantConfig,
+    RoutingMethodType,
     biased_moe_quant_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
@@ -1213,6 +1214,7 @@ class FusedMoE(CustomOp):
         zero_expert_type: str | None = None,
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
         n_shared_experts: int | None = None,
+        routing_method_type: int | None = None,
     ):
         super().__init__()

@@ -1397,6 +1399,24 @@ class FusedMoE(CustomOp):
                 "Only softmax scoring function is supported for non-grouped topk."
             )

+        # ToDo: Better logic to determine the routing method type
+        if routing_method_type is not None:
+            self.routing_method_type = routing_method_type
+        else:
+            if scoring_func == "sigmoid":
+                if self.use_grouped_topk:
+                    self.routing_method_type = RoutingMethodType.DeepSeekV3
+                elif self.top_k == 1:
+                    self.routing_method_type = RoutingMethodType.Llama4
+            elif self.scoring_func == "softmax":
+                self.routing_method_type = (
+                    RoutingMethodType.Renormalize
+                    if not self.renormalize
+                    else RoutingMethodType.RenormalizeNaive
+                )
+            else:
+                self.routing_method_type = RoutingMethodType.TopK
+
         self.moe_config: FusedMoEConfig = FusedMoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index e4e1cbff712f5..f5fc750baaea7 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe import (
 )
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
+    RoutingMethodType,
     fp8_w8a8_moe_quant_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
@@ -1222,22 +1223,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             assert activation == "silu", (
                 f"Expected 'silu' activation but got {activation}"
             )
-            assert scoring_func == "sigmoid", (
-                f"Expected 'sigmoid' scoring func but got {scoring_func}"
-            )
+
             if self.block_quant:
                 import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

-                assert (
-                    renormalize and use_grouped_topk and custom_routing_function is None
-                )
                 e_score_correction_bias = (
                     e_score_correction_bias.to(x.dtype)
                     if e_score_correction_bias is not None
                     else None
                 )
+                routing_method_type = layer.routing_method_type
                 return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
-                    routing_logits=router_logits.to(torch.float32),
+                    routing_logits=router_logits.to(torch.float32)
+                    if routing_method_type == RoutingMethodType.DeepSeekV3
+                    else router_logits,
                     routing_bias=e_score_correction_bias,
                     x=x,
                     w13_weight=layer.w13_weight,
@@ -1252,6 +1251,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     expert_offset=layer.ep_rank * layer.local_num_experts,
                     local_num_experts=layer.local_num_experts,
                     block_shape=self.weight_block_size,
+                    routing_method_type=routing_method_type,
                     routed_scaling=routed_scaling_factor,
                 )
             else:
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index 50ea049c3d5a1..e49d374f154d8 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -27,20 +27,25 @@ class FlashinferMoeBackend(Enum):


 def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
+    from flashinfer import next_positive_power_of_2
+
     # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
     # TODO: Revert this to dynamic calculation once a new version of FlashInfer
     # with the necessary kernels is released.
     tile_tokens_dim = 8
-    # from flashinfer import next_positive_power_of_2
-
-    # # Guess tokens per expert assuming perfect expert distribution first.
-    # num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # # And pad the number to the next power of 2.
-    # tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-    # # Cap to 8-64 tokens per CTA tile as it's the range supported by the
-    # # kernel.
-    # tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    # A factor considering tokens are not perfectly balanced among experts.
+    imbalance_factor = 1.3
+    # Calculate the number of tokens per expert
+    # assuming perfect distribution.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # Apply the imbalance factor.
+    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
+    # Cap to 8-max_tile_tokens_dim tokens per CTA tile
+    # as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

     return tile_tokens_dim

diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index a7e6772bb7082..d57b82cb02273 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -43,6 +43,7 @@ from vllm.distributed import (
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -171,6 +172,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
             is_sequence_parallel=self.is_sequence_parallel,
+            routing_method_type=RoutingMethodType.Renormalize,
         )

         self.gate = ReplicatedLinear(
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index 55bbad7a8b275..aa7de5aa5f29c 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.fla.ops import (
     fused_recurrent_gated_delta_rule,
 )
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
 from vllm.model_executor.layers.layernorm import (
     GemmaRMSNorm as Qwen3NextRMSNorm,
 )
@@ -173,6 +174,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
             is_sequence_parallel=self.is_sequence_parallel,
+            routing_method_type=RoutingMethodType.Renormalize,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
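The sketches below are illustrative notes on the patch, not part of the diff. First, a self-contained look at why an IntEnum is a convenient carrier for the new routing_method_type plumbing: the custom op only declares a plain int, and IntEnum members are ints, so layers can pass the enum member directly. The local RoutingMethodType and fake_kernel here are stand-ins written for this sketch; the member values mirror the ones added to config.py above.

from enum import IntEnum


class RoutingMethodType(IntEnum):  # local stand-in mirroring config.py
    Default = 0
    Renormalize = 1
    DeepSeekV3 = 2
    Llama4 = 3
    RenormalizeNaive = 4
    TopK = 5
    Unspecified = 6


def fake_kernel(routing_method_type: int) -> int:
    # A kernel wrapper like flashinfer_fused_moe_blockscale_fp8 only ever
    # sees a plain integer.
    return int(routing_method_type)


assert fake_kernel(RoutingMethodType.DeepSeekV3) == 2
assert RoutingMethodType.Renormalize == 1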
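Next, the routing-method fallback added to FusedMoE.__init__, restated as a standalone function so the scoring_func / grouped top-k / top_k / renormalize mapping is easy to eyeball. It reuses the RoutingMethodType stand-in from the previous sketch. One deliberate difference from the diff: when scoring_func is "sigmoid" but neither sub-case matches, this sketch falls through to TopK, whereas the patch leaves the attribute unset. Models can always bypass the fallback by passing routing_method_type explicitly, as qwen3_moe.py and qwen3_next.py do above.

def infer_routing_method(
    scoring_func: str, use_grouped_topk: bool, top_k: int, renormalize: bool
) -> RoutingMethodType:
    # Mirrors the else-branch added to FusedMoE.__init__.
    if scoring_func == "sigmoid":
        if use_grouped_topk:
            return RoutingMethodType.DeepSeekV3  # DeepSeek-style grouped top-k
        if top_k == 1:
            return RoutingMethodType.Llama4  # single-expert sigmoid routing
    elif scoring_func == "softmax":
        return (
            RoutingMethodType.Renormalize
            if not renormalize
            else RoutingMethodType.RenormalizeNaive
        )
    return RoutingMethodType.TopK


assert infer_routing_method("sigmoid", True, 8, True) == RoutingMethodType.DeepSeekV3
assert infer_routing_method("softmax", False, 8, False) == RoutingMethodType.Renormalize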
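The fp8.py change above casts the router logits to float32 only for the DeepSeekV3 routing method and forwards them unchanged otherwise. The helper below is hypothetical (written for this sketch, not a vLLM function) but applies the same expression used for the routing_logits argument in Fp8MoEMethod.apply.

import torch


def prepare_routing_logits(
    router_logits: torch.Tensor, routing_method_type: int
) -> torch.Tensor:
    # Only DeepSeek-style routing needs float32 logits; other methods keep
    # the original dtype.
    if routing_method_type == RoutingMethodType.DeepSeekV3:
        return router_logits.to(torch.float32)
    return router_logits


logits = torch.randn(4, 16, dtype=torch.bfloat16)
assert prepare_routing_logits(logits, RoutingMethodType.DeepSeekV3).dtype == torch.float32
assert prepare_routing_logits(logits, RoutingMethodType.Renormalize).dtype == torch.bfloat16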
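Finally, the tile-size arithmetic that calculate_tile_tokens_dim reintroduces, with a pure-Python stand-in for flashinfer.next_positive_power_of_2 so the numbers can be checked without FlashInfer installed. The 1.3 imbalance factor and the 8 to 64 clamp come from the diff; treating next_positive_power_of_2 as "smallest power of two >= x" is an assumption of this sketch.

def next_positive_power_of_2(x: int) -> int:
    # Assumed behaviour of flashinfer.next_positive_power_of_2.
    return 1 if x < 1 else 1 << (x - 1).bit_length()


def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    imbalance_factor = 1.3  # assume ~30% load imbalance across experts
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
    tile = next_positive_power_of_2(num_tokens_per_expert)
    return min(max(tile, 8), 64)  # the kernel supports 8..64 tokens per CTA tile


# 256 tokens routed to 8 of 128 experts: 16 tokens per expert, 20 after the
# imbalance factor, padded up to the next power of two, 32.
assert tile_tokens_dim(256, 8, 128) == 32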