[MoE] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked (#25990)

Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Shu Wang 2025-11-19 15:29:06 -06:00 committed by GitHub
parent cdeec2e606
commit 613abb50d5
10 changed files with 1064 additions and 35 deletions

View File

@@ -921,6 +921,7 @@ steps:
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py
- pytest -v -s tests/kernels/moe/test_cutedsl_moe.py
- label: Blackwell Fusion and Compile Tests # 30 min
timeout_in_minutes: 40

View File

@@ -0,0 +1,582 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.platforms import current_platform
if not current_platform.has_device_capability(100):
pytest.skip(
reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True,
)
import torch
from flashinfer import fp4_quantize
from torch.nn import functional as F
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
flashinfer_cutedsl_moe_masked,
)
from vllm.utils.flashinfer import (
flashinfer_cutedsl_grouped_gemm_nt_masked as cutedsl_gmm_masked,
)
from vllm.utils.flashinfer import (
scaled_fp4_grouped_quantize,
)
kE2M1ToFloat = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)
FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
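"""Undo the NVFP4 block-scale swizzle.
a_sf_swizzled stores one e4m3 scale per block_size elements, tiled as
(1, m_tiles, k_tiles, 32, 4, 4); the reshape/permute below recovers a
row-major layout with one scale per block along k for each of the m rows.
"""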
m_tiles = (m + 128 - 1) // 128
f = block_size * 4
k_tiles = (k + f - 1) // f
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
return out[0:m, 0:k]
def dequantize_nvfp4_to_dtype(
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
m, packed_k = tensor_fp4.shape
k = packed_k * 2
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
return out.to(dtype=dtype)
def break_fp4_bytes(a, dtype):
assert a.dtype == torch.uint8
m, n = a.shape
# Vectorized nibble processing
a_flat = a.flatten()
high = (a_flat & 0xF0) >> 4 # Upper nibbles
low = a_flat & 0x0F # Lower nibbles
# Combine nibbles for batch processing
combined = torch.stack((low, high), dim=1).flatten()
# Vectorized sign and magnitude extraction
signs = (combined & 0x08).to(torch.bool) # Sign bits
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices
# Device-aware lookup and sign application
kE2M1 = kE2M1ToFloat.to(device=a.device)
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
# Reshape to final form
return values.reshape(m, n * 2).to(dtype=dtype)
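# Worked example for break_fp4_bytes: a packed byte 0x2D splits into low
# nibble 0xD (sign bit set, magnitude index 5 -> -3.0) and high nibble 0x2
# (sign clear, magnitude index 2 -> 1.0), so it decodes to the pair
# (-3.0, 1.0) in low-then-high order.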
def generate_balanced_routing(
hidden_states: torch.Tensor, num_experts: int, top_k: int
):
"""
Generate routing weights and topk indices such that every expert is active.
Returns routing_weights, topk_idx
"""
num_tokens, hidden_dim = hidden_states.shape
# num_tokens = batch_size * seq_len
# First, assign at least one token per expert
tokens_per_expert = torch.arange(num_tokens) % num_experts
tokens_per_expert = tokens_per_expert[torch.randperm(num_tokens)] # shuffle
# Each token has top_k experts — start with one guaranteed expert
topk_idx = torch.full((num_tokens, top_k), -1, dtype=torch.long)
topk_idx[:, 0] = tokens_per_expert
# For remaining top_k - 1 experts, pick randomly (allowing repeats)
if top_k > 1:
random_choices = torch.randint(0, num_experts, (num_tokens, top_k - 1))
topk_idx[:, 1:] = random_choices
# Normalize routing weights so each token's weights sum to 1
routing_weights = torch.rand(num_tokens, top_k)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# Reshape back if needed
routing_weights = routing_weights.view(num_tokens, top_k)
topk_idx = topk_idx.view(num_tokens, top_k)
return routing_weights, topk_idx
def prepare_inputs(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
num_experts: int,
topk: int,
):
routing_weights, topk_idx = generate_balanced_routing(
router_logits, num_experts, topk
)
masked_m = []
for i in range(num_experts):
mask = topk_idx.view(-1) == i
masked_m.append(mask.sum())
masked_m = torch.tensor(masked_m, dtype=torch.int32)
# Initialize hidden_states_3d with ones instead of empty to avoid NaN
# issues.
hidden_states_3d = torch.ones(
(num_experts, max(masked_m), hidden_states.shape[1]), dtype=hidden_states.dtype
)
for i in range(num_experts):
hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i]
return hidden_states_3d, masked_m, topk_idx, routing_weights
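# Example of the packing above: with num_experts = 2 and a flattened
# topk_idx of [0, 1, 1, 0], masked_m becomes [2, 2] and hidden_states_3d has
# shape [2, 2, hidden_dim], with each expert's rows filled in routing order
# and any padding rows left as ones.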
MNK_FACTORS = [
(2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024),
(2, 3072, 1536),
(64, 1024, 1024),
(64, 1024, 1536),
(64, 3072, 1024),
(64, 2048, 1024),
(224, 1024, 1024),
(224, 1024, 1536),
]
# Reference implementation of torch_moe
def torch_moe(a, w1, w2, score, topk, expert_map):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
0, 1
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
m = w1[i].shape[0]
assert m % 2 == 0
# Note: w1 and w3 are swapped!
w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
inter_gs = torch.tensor(1.0).cuda()
inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
inter = dequantize_nvfp4_to_dtype(
inter_q,
inter_blockscale,
inter_gs,
dtype=inter.dtype,
device=inter.device,
block_size=16,
).cuda()
out[mask] = inter @ w2[i].transpose(0, 1)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
def grouped_gemm_ref(
hidden_states_expanded: torch.Tensor,
hidden_states_3d: torch.Tensor,
weights: torch.Tensor,
topk_idx: torch.Tensor,
masked_m: torch.Tensor,
B: int,
topk: int,
num_experts: int,
*,
block_size: int = 16,
) -> torch.Tensor:
"""
Computes the reference grouped GEMM via a per-expert fp4
quantize/dequantize loop and returns the repacked reference
output out_ref.
Returns:
out_ref: Tensor [num_experts, max_m, n_out]
"""
device_hs = hidden_states_expanded.device
device_w = weights.device
out_dtype = weights.dtype
n_out = weights.shape[1]
# Flattened reference output (B*topk, n_out)
out = torch.zeros((B * topk, n_out), dtype=out_dtype, device=device_w)
# Per-expert reference compute loop
for i in range(num_experts):
mask = topk_idx.view(-1) == i
if mask.any():
lhs = hidden_states_expanded[mask]
rhs = weights[i]
a_amax = lhs.abs().max().to(torch.float32).to(device_hs)
b_amax = rhs.abs().max().to(torch.float32).to(device_w)
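# The global scale maps each tensor's amax onto the largest value
# representable by an e4m3 block scale (448) times an e2m1 element (6).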
a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
lhsq, lhsq_sf = fp4_quantize(lhs, a_gs)
rhsq, rhsq_sf = fp4_quantize(rhs, b_gs)
lhs_in_dtype = dequantize_nvfp4_to_dtype(
lhsq,
lhsq_sf,
a_gs,
dtype=lhs.dtype,
device=device_hs,
block_size=block_size,
)
rhs_in_dtype = dequantize_nvfp4_to_dtype(
rhsq,
rhsq_sf,
b_gs,
dtype=rhs.dtype,
device=device_w,
block_size=block_size,
)
out[mask] = lhs_in_dtype @ rhs_in_dtype.t()
# Determine per-expert max_m
max_m_val = int(masked_m.max().item())
# Repack into [num_experts, max_m, n_out]
out_ref = torch.zeros(
(num_experts, max_m_val, n_out),
dtype=out.dtype,
device=out.device,
)
expert_slot = [0] * num_experts
for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
slot = expert_slot[expert_id]
if slot < max_m_val:
out_ref[expert_id, slot, :] = out[i]
expert_slot[expert_id] += 1
else:
raise IndexError(
f"Expert {expert_id} exceeded max slots ({max_m_val}). "
"Increase max_m or check masked_m."
)
return out_ref
def flashinfer_cutedsl_grouped_gemm_nt_masked(
hidden_states: torch.Tensor, # 3d
input_global_scale: torch.Tensor, # (l,)
weights: torch.Tensor,
w_global_scale: torch.Tensor, # (l,)
masked_m: torch.Tensor,
):
# hidden_states: [l, m, k]
# weights: [l, n, k]
aq, aq_sf = scaled_fp4_grouped_quantize(
hidden_states,
masked_m.to(hidden_states.device),
input_global_scale,
)
num_experts, n, k = weights.shape
bq, bq_sf = scaled_fp4_grouped_quantize(
weights,
torch.full((num_experts,), n, device=weights.device, dtype=torch.int32),
w_global_scale,
)
out = torch.zeros(
(num_experts, max(masked_m), n), dtype=weights.dtype, device=aq.device
)
out = out.permute(1, 2, 0) # requirement of kernel
sf_vec_size = 16
ab_dtype = "float4_e2m1fn"
sf_dtype = "float8_e4m3fn"
c_dtype = "bfloat16"
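# alpha undoes the two global quantization scales: aq is scaled by
# input_global_scale and bq by w_global_scale, so the true product needs a
# per-expert factor of 1 / (input_global_scale * w_global_scale),
# broadcast to [1, 1, num_experts] to match the view below.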
alpha = 1.0 / (input_global_scale * w_global_scale).to(out.dtype).view(
1, 1, num_experts
)
def get_cute_dtype(input: torch.Tensor) -> str:
if input.dtype == torch.bfloat16:
return "bfloat16"
elif input.dtype == torch.float16:
return "float16"
elif input.dtype == torch.float32:
return "float32"
else:
raise ValueError(f"Unsupported cute dtype {input.dtype}")
cutedsl_gmm_masked(
(aq, aq_sf),
(bq, bq_sf),
out,
masked_m.to(aq.device),
ab_dtype=ab_dtype,
sf_dtype=sf_dtype,
c_dtype=c_dtype,
sf_vec_size=sf_vec_size,
alpha=alpha,
alpha_dtype=get_cute_dtype(alpha),
)
return out
@pytest.mark.parametrize("bs, hidden_dim, inter_dim", [(2, 128, 256), (16, 128, 512)])
@pytest.mark.parametrize("topk", [1, 2, 4])
@torch.inference_mode()
def test_flashinfer_cutedsl_moe_masked(
bs: int, hidden_dim: int, inter_dim: int, topk: int
):
torch.manual_seed(42)
device = "cuda"
num_experts = 8
hidden_states = (
torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0
)
w1 = (
torch.randn(
num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device
)
/ 10.0
)
w2 = (
torch.randn(
num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device
)
/ 10.0
)
router_logits = torch.randn(bs, num_experts, dtype=torch.float32)
hidden_states_expanded = (
hidden_states.view(bs, -1, hidden_dim)
.repeat(1, topk, 1)
.reshape(-1, hidden_dim)
)
hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs(
hidden_states_expanded, router_logits, num_experts, topk
)
w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
input_global_scale = torch.ones(
(num_experts,), dtype=torch.float32, device=hidden_states.device
)
w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
a2_global_scale = torch.ones(
(num_experts,), dtype=torch.float32, device=hidden_states.device
) # assume intermediate scale is 1.0
w1_fp4, w1_blockscale = scaled_fp4_grouped_quantize(
w1,
torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
w1_global_scale,
)
w2_fp4, w2_blockscale = scaled_fp4_grouped_quantize(
w2,
torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
w2_global_scale,
)
w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)
out = torch.empty_like(hidden_states_3d)
# Note: the first dim must be num_experts, not bs (batch size).
wk = torch.empty(
num_experts,
hidden_states_3d.shape[1],
inter_dim * 2,
dtype=hidden_states_3d.dtype,
device=hidden_states.device,
)
flashinfer_cutedsl_moe_masked(
hidden_states_3d.to(hidden_states.device),
input_global_scale,
w1_fp4.permute(2, 0, 1),
w1_blockscale,
w1_alpha,
w2_fp4.permute(2, 0, 1),
a2_global_scale,
w2_blockscale,
w2_alpha,
masked_m.to(hidden_states.device),
wk,
out,
)
# reference
a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale)
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
input_global_scale,
dtype=hidden_states.dtype,
device=hidden_states.device,
block_size=16,
)
w1_d = torch.empty(
(num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype
)
w2_d = torch.empty(
(num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype
)
for idx in range(0, num_experts):
w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
w1[idx], w1_global_scale[idx]
)
w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
w2[idx], w2_global_scale[idx]
)
w1_d[idx] = dequantize_nvfp4_to_dtype(
w1_fp4_sliced,
w1_blockscale_sliced,
w1_global_scale[idx],
dtype=w1.dtype,
device=w1.device,
block_size=16,
)
w2_d[idx] = dequantize_nvfp4_to_dtype(
w2_fp4_sliced,
w2_blockscale_sliced,
w2_global_scale[idx],
dtype=w2.dtype,
device=w2.device,
block_size=16,
)
ref_output = torch_moe_nvfp4(
a_in_dtype,
w1_d,
w2_d,
topk,
routing_weights.to(a_in_dtype.device),
topk_idx.to(a_in_dtype.device),
)
out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype)
positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
rows, cols = positions[:, 0], positions[:, 1]
experts = topk_idx[rows, cols]
for i in range(num_experts):
mask = experts == i
if mask.any():
idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
r, c = rows[idx], cols[idx]
out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to(
out.device
).unsqueeze(-1)
torch.testing.assert_close(
out_weighted.cpu(), ref_output.cpu(), atol=2e-1, rtol=2e-1
)
@pytest.mark.parametrize(
"bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
)
@torch.inference_mode()
def test_grouped_gemm_nt_masked(
bs: int, hidden_dim: int, inter_dim: int, topk: int
) -> None:
torch.manual_seed(42)
B = bs
D = hidden_dim
N = inter_dim
# CuteDSL grouped GEMM has issues when not all experts are active,
# e.g. masked = [2, 3, 0, 0, 1], where experts 2 and 3 are inactive.
# See https://github.com/flashinfer-ai/flashinfer/issues/1856
num_experts = bs
hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
router_logits = torch.randn(B, num_experts, dtype=torch.float32)
hidden_states_expanded = (
hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
)
hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
hidden_states_expanded, router_logits, num_experts, topk
)
a_amax = (
hidden_states_3d.abs()
.amax(dim=(1, 2))
.to(torch.float32)
.to(hidden_states.device)
)
b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
)
# reference
out_ref = grouped_gemm_ref(
hidden_states_expanded=hidden_states_expanded,
hidden_states_3d=hidden_states_3d,
weights=weights,
topk_idx=topk_idx,
masked_m=masked_m,
B=B,
topk=topk,
num_experts=num_experts,
)
# Note: only compare the masked positions, since CuteDSL may write NaN
# into unmasked positions.
for i in range(num_experts):
torch.testing.assert_close(
out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
atol=1e-1,
rtol=1e-1,
)
if __name__ == "__main__":
test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4)
test_grouped_gemm_nt_masked(16, 128, 512, 4)

View File

@@ -157,7 +157,9 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "latency"
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
"latency"
)
VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
@@ -1238,7 +1240,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "latency":
# Uses TensorRT-LLM kernels optimized for low-latency inference.
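# - "masked_gemm":
# Uses FlashInfer CuteDSL masked grouped GEMM (grouped_gemm_nt_masked)
# kernels for NVFP4 fused MoE.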
"VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
"VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"]
"VLLM_FLASHINFER_MOE_BACKEND",
"latency",
["throughput", "latency", "masked_gemm"],
),
# Control the workspace buffer size for the FlashInfer backend.
"VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int(

View File

@@ -6,6 +6,7 @@ import deep_ep
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -27,6 +28,8 @@ logger = init_logger(__name__)
DEEPEP_QUANT_BLOCK_SIZE = 128
DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]
logger = init_logger(__name__)
def dequant_fp8(
expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor
@@ -187,16 +190,25 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim))
q_dtype = quant_config.quant_dtype
if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
logger.info_once(
"Skip quantization when using FlashInfer CUTEDSL(masked_gemm) "
"for ModelOptNvFp4FusedMoE."
)
q_dtype = None
x, x_scales = moe_kernel_quantize_input(
x,
quant_config.a1_scale,
quant_config.quant_dtype,
q_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
)
x = x.view((num_experts, -1, hidden_dim))
if quant_config.quant_dtype is not None:
if q_dtype is not None:
assert x_scales is not None
x_scales = normalize_batched_scales_shape(x_scales, num_experts)

View File

@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
from vllm.utils.flashinfer import (
flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
scaled_fp4_grouped_quantize,
silu_and_mul_scaled_nvfp4_experts_quantize,
)
logger = init_logger(__name__)
def is_valid_flashinfer_cutedsl_fused_moe(
hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor
) -> bool:
"""
Check if the given problem size is supported by the FlashInfer CuteDSL MoE
kernel.
"""
if not has_flashinfer_cutedsl_grouped_gemm_nt_masked():
logger.debug_once(
"FlashInferCuteDSLExperts disabled: "
"flashinfer_cutedsl_fused_moe not available."
)
return False
# Data type checks
if (
w1.dtype != torch.uint8
or w2.dtype != torch.uint8
or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16]
):
logger.debug_once(
"FlashInferCuteDSLExperts disabled: w1/w2 must be torch.uint8 "
f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
f"float32, float16, or bfloat16 (got {hidden_states.dtype})."
)
return False
return True
class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
out_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
):
super().__init__(quant_config)
assert quant_config.quant_dtype == "nvfp4", (
"Only nvfp4 quantization are currently supported."
)
self.out_dtype = out_dtype
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts,
)
def supports_expert_map(self) -> bool:
return False
def supports_chunking(self) -> bool:
# This refers to TP chunking; DP chunking is handled separately.
# TODO(shuw@nvidia.com): Set to False to be consistent with
# batched_deep_gemm_moe
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
"""
Compute the shapes for the temporary and final outputs of the two gemms
and activation in the fused expert function. Since the gemms are
independent, the workspace for the first gemm can be shared with the
workspace for the last gemm.
Returns a tuple of:
- workspace13 shape tuple: must be large enough to hold the
result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the
result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
output_shape = (local_num_experts, M, K)
workspace2 = (local_num_experts, M, N)
workspace1 = output_shape
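# For example, local_num_experts=8, M=64, N=512, K=128 yields
# workspace1=(8, 64, 128), workspace2=(8, 64, 512), output=(8, 64, 128).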
return (workspace1, workspace2, output_shape)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None, # Not used
workspace13: torch.Tensor | None,
workspace2: torch.Tensor | None,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool | None,
):
assert self.quant_dtype == "nvfp4", (
"Only nvfp4 quantization are currently supported."
)
# Ensure w1_scale and w2_scale are not None before calling view
assert self.w1_scale is not None and self.w2_scale is not None, (
"w1_scale and w2_scale must not be None for FlashInferExperts"
)
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
assert hidden_states.ndim == 3
assert self.w1_scale.ndim == 3
assert self.w2_scale.ndim == 3
flashinfer_cutedsl_moe_masked(
hidden_states=hidden_states,
input_global_scale=self.a1_gscale,
w1=w1,
w1_blockscale=self.w1_scale,
w1_alpha=self.g1_alphas,
w2=w2,
a2_global_scale=self.a2_gscale,
w2_blockscale=self.w2_scale,
w2_alpha=self.g2_alphas,
masked_m=expert_num_tokens,
workspace=workspace2,
out=output,
)
def get_cute_dtype(input: torch.Tensor) -> str:
if input.dtype == torch.bfloat16:
return "bfloat16"
elif input.dtype == torch.float16:
return "float16"
elif input.dtype == torch.float32:
return "float32"
else:
raise ValueError(f"Unsupported cute dtype {input.dtype}")
def flashinfer_cutedsl_moe_masked(
hidden_states: torch.Tensor,
input_global_scale: torch.Tensor,
w1: torch.Tensor,
w1_blockscale: torch.Tensor,
w1_alpha,
w2: torch.Tensor,
a2_global_scale: torch.Tensor,
w2_blockscale: torch.Tensor,
w2_alpha,
masked_m: torch.Tensor,
workspace: torch.Tensor,
out: torch.Tensor,
):
"""
Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
kernels.
Args:
hidden_states (torch.Tensor): [num_experts, m, k], bf16
input_global_scale (torch.Tensor): (l,)
w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
w1_blockscale (torch.Tensor): blockscale factors, e4m3,
w1_alpha (torch.Tensor): (l,)
w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
a2_global_scale (torch.Tensor): (l,)
w2_blockscale (torch.Tensor): blockscale factors, e4m3,
w2_alpha (torch.Tensor): (l,)
masked_m (torch.Tensor): (l,), number of valid tokens per expert
workspace (torch.Tensor): [l, m, 2 * n] scratch buffer for the gemm1
(gate/up) output
out (torch.Tensor): [l, m, k] output buffer, written in place
Notes:
- Assumes max(masked_m) <= m.
"""
# === Assertions on dtypes ===
assert input_global_scale.dtype == torch.float32, (
f"input_global_scale must be float32, got {input_global_scale.dtype}"
)
assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
assert w1_blockscale.dtype == torch.float8_e4m3fn, (
f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
)
assert w1_alpha.dtype == torch.float32, (
f"w1_alpha must be float32, got {w1_alpha.dtype}"
)
assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}"
assert a2_global_scale.dtype == torch.float32, (
f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
)
assert w2_blockscale.dtype == torch.float8_e4m3fn, (
f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
)
assert w2_alpha.dtype == torch.float32, (
f"w2_alpha must be float32, got {w2_alpha.dtype}"
)
# === Assertions on shapes ===
n = w2.shape[-1] * 2 # intermediate dimension
num_experts, m, k = hidden_states.shape
assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
assert w1.shape[-1] * 2 == k, (
f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
)
assert w2.shape[-2:] == (
k,
n // 2,
), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"
assert input_global_scale.shape == (num_experts,), (
f"input_global_scale must be (l,), got {input_global_scale.shape}"
)
assert w1_alpha.shape == (num_experts,), (
f"w1_alpha must be (l,), got {w1_alpha.shape}"
)
assert a2_global_scale.shape == (num_experts,), (
f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
)
assert w2_alpha.shape == (num_experts,), (
f"w2_alpha must be (l,), got {w2_alpha.shape}"
)
aq, aq_sf = scaled_fp4_grouped_quantize(
hidden_states,
masked_m,
input_global_scale,
)
workspace = workspace.permute(1, 2, 0) # requirement of kernel
sf_vec_size = 16
assert aq_sf.dtype == torch.float8_e4m3fn
assert aq.dtype == torch.uint8
ab_dtype = "float4_e2m1fn"
sf_dtype = "float8_e4m3fn"
c_dtype = get_cute_dtype(hidden_states)
# Gemm1
flashinfer_cutedsl_grouped_gemm_nt_masked(
(aq, aq_sf),
(w1.permute(1, 2, 0), w1_blockscale),
workspace,
masked_m,
ab_dtype=ab_dtype,
sf_dtype=sf_dtype,
c_dtype=c_dtype,
sf_vec_size=sf_vec_size,
alpha=w1_alpha.view(1, 1, num_experts),
alpha_dtype=get_cute_dtype(w1_alpha),
) # in logical [m, n, l]
# SILU and quantization
diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
workspace.permute(2, 0, 1),
masked_m,
a2_global_scale,
)
# Gemm2
out = out.permute(1, 2, 0) # requirement of kernel
flashinfer_cutedsl_grouped_gemm_nt_masked(
(diq, diq_sf),
(w2.permute(1, 2, 0), w2_blockscale),
out,
masked_m,
ab_dtype=ab_dtype,
sf_dtype=sf_dtype,
c_dtype=c_dtype,
sf_vec_size=sf_vec_size,
alpha=w2_alpha.view(1, 1, num_experts),
alpha_dtype=get_cute_dtype(w2_alpha),
) # in logical [m, k, l]
out = out.permute(2, 0, 1)
def flashinfer_cutedsl_moe_fp4(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
fused_experts = mk.FusedMoEModularKernel(
create_flashinfer_prepare_finalize(use_dp=False), # could be swapped later
FlashInferCuteDSLExperts(
out_dtype=hidden_states.dtype,
quant_config=quant_config,
),
)
return fused_experts(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
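# A minimal sketch of enabling this path end to end, assuming an
# NVFP4-quantized MoE checkpoint (the model name below is illustrative):
#
#   VLLM_USE_FLASHINFER_MOE_FP4=1 \
#   VLLM_FLASHINFER_MOE_BACKEND=masked_gemm \
#   vllm serve <nvfp4-moe-model>
#
# With masked_gemm selected, ModelOptNvFp4FusedMoE resolves the backend to
# FlashinferMoeBackend.CUTEDSL and dispatches to flashinfer_cutedsl_moe_fp4.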

View File

@@ -1468,7 +1468,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
gemm1_weight = layer.w13_weight.data
gemm1_weight_scale = layer.w13_weight_scale.data
if self.allow_flashinfer:
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
):
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
gemm1_weight, gemm1_weight_scale, dim=-2
)
@@ -1746,17 +1749,26 @@
workspace=layer.workspace,
)
elif (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4,
elif self.allow_flashinfer:
assert self.flashinfer_moe_backend in (
FlashinferMoeBackend.CUTLASS,
FlashinferMoeBackend.CUTEDSL,
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4,
)
flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4
else:
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( # noqa: E501
flashinfer_cutedsl_moe_fp4,
)
flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4
assert self.moe_quant_config is not None
return flashinfer_cutlass_moe_fp4(
return flashinfer_fn_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,

View File

@@ -10,6 +10,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
FlashInferCuteDSLExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
@@ -17,10 +20,14 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize im
create_flashinfer_prepare_finalize,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.flashinfer import (
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutlass_fused_moe,
)
__all__ = [
"is_flashinfer_fp4_cutlass_moe_available",
"is_flashinfer_fp4_cutedsl_moe_available",
"reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
]
@@ -36,6 +43,16 @@ def is_flashinfer_fp4_cutlass_moe_available() -> bool:
)
def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
"""Return ``True`` when FlashInfer CUTEDSL NV-FP4 kernels can be used."""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
and current_platform.is_cuda()
and current_platform.is_device_capability(100)
)
def reorder_w1w3_to_w3w1(
weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -72,15 +89,21 @@ def select_nvfp4_gemm_impl(
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
if allow_flashinfer:
return FlashInferExperts(
out_dtype=moe.in_dtype,
quant_config=moe_quant_config,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size,
use_dp=moe.moe_parallel_config.dp_size > 1,
)
if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
return FlashInferCuteDSLExperts(
out_dtype=moe.in_dtype,
quant_config=moe_quant_config,
)
elif envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput":
return FlashInferExperts(
out_dtype=moe.in_dtype,
quant_config=moe_quant_config,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size,
use_dp=moe.moe_parallel_config.dp_size > 1,
)
# native cutlass experts currently don't support DP; TP case won't call this
raise ValueError(

View File

@@ -25,6 +25,7 @@ logger = init_logger(__name__)
class FlashinferMoeBackend(Enum):
TENSORRT_LLM = "TensorRT-LLM"
CUTLASS = "CUTLASS"
CUTEDSL = "CUTEDSL"
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
@@ -273,19 +274,21 @@ def flashinfer_cutlass_moe_fp8(
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
# Prefer CUTLASS on SM90 to cover both SM90/SM100 generations
if flashinfer_moe_backend == "throughput" or current_platform.is_device_capability(
90
):
return FlashinferMoeBackend.CUTLASS
elif flashinfer_moe_backend == "latency":
return FlashinferMoeBackend.TENSORRT_LLM
backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS,
"latency": FlashinferMoeBackend.TENSORRT_LLM,
"masked_gemm": FlashinferMoeBackend.CUTEDSL,
}
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
if flashinfer_moe_backend in backend_map:
return backend_map[flashinfer_moe_backend]
elif current_platform.is_device_capability(90):
return FlashinferMoeBackend.CUTLASS
allowed_backends = ["throughput", "latency"]
raise ValueError(
f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
f" expected one of {allowed_backends}"
f"Unknown flashinfer moe backend: {flashinfer_moe_backend!r}. "
f"Expected one of {list(backend_map.keys())}."
)

View File

@@ -5,6 +5,7 @@ from dataclasses import dataclass
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_flashinfer_fp4_cutedsl_moe_available,
is_flashinfer_fp4_cutlass_moe_available,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
@@ -32,7 +33,10 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
"""Detect platform support for NV-FP4 fused-MoE path"""
cutlass_supported = cutlass_fp4_supported()
allow_flashinfer = cutlass_supported and is_flashinfer_fp4_cutlass_moe_available()
allow_flashinfer = cutlass_supported and (
is_flashinfer_fp4_cutlass_moe_available()
or is_flashinfer_fp4_cutedsl_moe_available()
)
if allow_flashinfer:
_logger.info_once(

View File

@@ -114,7 +114,17 @@ flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
"flashinfer.fused_moe", "cutlass_fused_moe"
)
flashinfer_cutedsl_grouped_gemm_nt_masked = _lazy_import_wrapper(
"flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"
)
flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_batched_quantize = _lazy_import_wrapper("flashinfer", "nvfp4_batched_quantize")
silu_and_mul_scaled_nvfp4_experts_quantize = _lazy_import_wrapper(
"flashinfer", "silu_and_mul_scaled_nvfp4_experts_quantize"
)
scaled_fp4_grouped_quantize = _lazy_import_wrapper(
"flashinfer", "scaled_fp4_grouped_quantize"
)
nvfp4_block_scale_interleave = _lazy_import_wrapper(
"flashinfer", "nvfp4_block_scale_interleave"
)
@@ -166,6 +176,14 @@ def has_flashinfer_moe() -> bool:
)
@functools.cache
def has_flashinfer_cutedsl() -> bool:
"""Return ``True`` if FlashInfer cutedsl module is available."""
return (
has_flashinfer() and importlib.util.find_spec("flashinfer.cute_dsl") is not None
)
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
"""Return `True` if FlashInfer CUTLASS fused MoE is available."""
@@ -187,6 +205,26 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
return True
@functools.cache
def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool:
"""Return ``True`` if FlashInfer CUTLASS fused MoE is available."""
if not has_flashinfer_cutedsl():
return False
# Check if all required functions are available
required_functions = [
("flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"),
("flashinfer", "scaled_fp4_grouped_quantize"),
("flashinfer", "silu_and_scaled_nvfp4_experts_quantize"),
]
for module_name, attr_name in required_functions:
mod = _get_submodule(module_name)
if not mod or not hasattr(mod, attr_name):
return False
return True
@functools.cache
def has_nvidia_artifactory() -> bool:
"""Return `True` if NVIDIA's artifactory is accessible.
@@ -472,7 +510,10 @@ __all__ = [
"has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe",
"flashinfer_cutlass_fused_moe",
"flashinfer_cutedsl_grouped_gemm_nt_masked",
"flashinfer_fp4_quantize",
"silu_and_mul_scaled_nvfp4_experts_quantize",
"scaled_fp4_grouped_quantize",
"nvfp4_block_scale_interleave",
"trtllm_fp4_block_scale_moe",
"autotune",
@@ -480,6 +521,7 @@ __all__ = [
"has_flashinfer_comm",
"has_flashinfer_all2all",
"has_flashinfer_cutlass_fused_moe",
"has_flashinfer_cutedsl_grouped_gemm_nt_masked",
"has_nvidia_artifactory",
"supports_trtllm_attention",
"can_use_trtllm_attention",