[NVIDIA] Add SM100 Flashinfer MoE per tensor scale fp8 backend (#21458)

Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
amirkl94 2025-07-31 16:00:01 +03:00 committed by GitHub
parent 5daffe7cf6
commit 207b750e19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 269 additions and 49 deletions

View File

@ -30,6 +30,8 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
dequant_mxfp4)
from vllm.platforms import current_platform
@ -1065,22 +1067,6 @@ direct_register_custom_op(
)
def next_positive_power_of_2(x: int) -> int:
if x < 1:
return 1
return 1 << (x - 1).bit_length()
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
def flashinfer_fused_moe_blockscale_fp8(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
@ -1128,7 +1114,7 @@ def flashinfer_fused_moe_blockscale_fp8(
local_expert_offset=expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling,
tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
global_num_experts),
routing_method_type=2, # DeepSeek-styled routing method
use_shuffled_weight=False,
@ -1164,6 +1150,97 @@ direct_register_custom_op(
)
def flashinfer_fused_moe_per_tensor_scale_fp8(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
activation_scale: torch.Tensor,
gemm2_weights: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: Optional[int],
topk_group: Optional[int],
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
use_routing_scales_on_input: bool,
routing_method_type: int,
routed_scaling_factor: float = 1.0) -> torch.Tensor:
num_expert_group = num_expert_group if num_expert_group is not None else 0
topk_group = topk_group if topk_group is not None else 0
quant_hidden_states, input_scale = moe_kernel_quantize_input(
hidden_states,
input_scale,
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False)
output1_scales_scalar = gemm1_weights_scale * input_scale * (
1.0 / activation_scale)
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
output2_scales_scalar = activation_scale * gemm2_weights_scale
from vllm.utils.flashinfer import (
flashinfer_trtllm_fp8_per_tensor_scale_moe)
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=quant_hidden_states,
gemm1_weights=gemm1_weights,
output1_scales_scalar=output1_scales_scalar,
output1_scales_gate_scalar=output1_scales_gate_scalar,
gemm2_weights=gemm2_weights,
output2_scales_scalar=output2_scales_scalar,
num_experts=num_experts,
top_k=top_k,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=use_routing_scales_on_input,
tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
top_k, num_experts),
routing_method_type=routing_method_type)
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
hidden_states: torch.Tensor,
gemm1_weights: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
gemm2_weights: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: float = 1.0,
use_routing_scales_on_input: bool = False,
tile_tokens_dim: int = 8,
routing_method_type: int = 0) -> torch.Tensor:
pass
direct_register_custom_op(
op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
mutates_args=["hidden_states"],
fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
def outplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,

View File

@ -23,6 +23,9 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@ -53,11 +56,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)
def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
return x.reshape(-1, 2, x.shape[-2] // 2,
x.shape[-1]).flip(dims=[1]).reshape(x.shape)
def _is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
b, m, n = x.shape
@ -695,11 +693,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
elif self.flashinfer_moe_enabled:
# NOTE: weights have to be swapped since the activation is
# applied on different half for flashinfer vs vllm
w13_weight = _swap_w13_to_w31(layer.w13_weight.data)
w13_weight_scale_inv = _swap_w13_to_w31(
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
w13_weight_scale_inv = swap_w13_to_w31(
layer.w13_weight_scale_inv.data)
w2_weight = layer.w2_weight.data
w2_weight_scale_inv = layer.w2_weight_scale_inv.data
if not self.block_quant:
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
else:
w13_weight = layer.w13_weight.data
w13_weight_scale_inv = layer.w13_weight_scale_inv.data
@ -998,12 +998,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
global_num_experts=global_num_experts,
expert_map=expert_map)
elif self.flashinfer_moe_enabled:
# Currently only work with DS models
assert self.block_quant
assert activation == 'silu'
assert scoring_func == 'sigmoid'
if self.block_quant:
assert (renormalize and use_grouped_topk
and scoring_func == 'sigmoid'
and custom_routing_function is None)
assert activation == "silu"
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32),
routing_bias=e_score_correction_bias,
@ -1022,6 +1022,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
block_shape=self.quant_config.weight_block_size,
routed_scaling=1.0,
)
else:
assert (not renormalize
and custom_routing_function is not None)
return apply_flashinfer_per_tensor_scale_fp8(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
return self.fused_experts(
hidden_states=x,

View File

@ -23,6 +23,9 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear, is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
@ -34,6 +37,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import has_flashinfer_moe
logger = init_logger(__name__)
@ -267,6 +271,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.flashinfer_moe_enabled = False
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
logger.info_once(
"Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.")
self.flashinfer_moe_enabled = True
def create_weights(
self,
@ -410,6 +419,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
requires_grad=False)
if self.flashinfer_moe_enabled:
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
layer.w2_weight)
def apply(
self,
layer: torch.nn.Module,
@ -436,6 +450,20 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
if self.flashinfer_moe_enabled:
assert activation == 'silu'
assert not renormalize
return apply_flashinfer_per_tensor_scale_fp8(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
apply_router_weight_on_input=apply_router_weight_on_input)
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,

View File

@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
from flashinfer import next_positive_power_of_2
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
return x.reshape(-1, 2, x.shape[-2] // 2,
x.shape[-1]).flip(dims=[1]).reshape(x.shape)
def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor):
from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
epilogue_tile_m = 128
num_experts = gemm1_weights.shape[0]
hidden_size = gemm1_weights.shape[-1]
intermediate_size = gemm1_weights.shape[1] // 2
# Reorder rows of W1 for fused gated activation
gemm1_weights_fp8_interleaved = []
for i in range(num_experts):
gemm1_weights_fp8_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_weights[i]))
# Stack weights and scales for all experts
gemm1_weights_fp8_interleaved = torch.stack(
gemm1_weights_fp8_interleaved).reshape(num_experts,
2 * intermediate_size,
hidden_size)
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_fp8_shuffled = []
gemm2_weights_fp8_shuffled = []
for i in range(num_experts):
gemm1_weights_fp8_shuffled.append(
shuffle_matrix_a(
gemm1_weights_fp8_interleaved[i].view(torch.uint8),
epilogue_tile_m))
gemm2_weights_fp8_shuffled.append(
shuffle_matrix_a(gemm2_weights[i].view(torch.uint8),
epilogue_tile_m))
# Stack weights for all experts
gemm1_weights.data = torch.stack(gemm1_weights_fp8_shuffled).view(
torch.float8_e4m3fn)
gemm2_weights.data = torch.stack(gemm2_weights_fp8_shuffled).view(
torch.float8_e4m3fn)
def apply_flashinfer_per_tensor_scale_fp8(
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
top_k: int,
num_expert_group: Optional[int],
topk_group: Optional[int],
global_num_experts: int,
apply_router_weight_on_input: bool,
) -> torch.Tensor:
from flashinfer.fused_moe import RoutingMethodType
from vllm.model_executor.models.llama4 import Llama4MoE
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
"FusedMoE flashinfer kernels are only supported for Llama4"
return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states,
input_scale=layer.w13_input_scale,
gemm1_weights=layer.w13_weight,
gemm1_weights_scale=layer.w13_weight_scale,
gemm2_weights=layer.w2_weight,
gemm2_weights_scale=layer.w2_weight_scale,
activation_scale=layer.w2_input_scale,
num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
use_routing_scales_on_input=apply_router_weight_on_input,
routing_method_type=RoutingMethodType.Llama4,
)

View File

@ -66,6 +66,8 @@ def _lazy_import_wrapper(module_name: str,
# Create lazy wrappers for each function
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
"flashinfer.fused_moe", "trtllm_fp8_block_scale_moe")
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
"flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe")
flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
"cutlass_fused_moe")
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")