[MoE] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked (#25990)
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent cdeec2e606
commit 613abb50d5
@@ -921,6 +921,7 @@ steps:
     - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
     - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
     - pytest -v -s tests/kernels/moe/test_flashinfer.py
+    - pytest -v -s tests/kernels/moe/test_cutedsl_moe.py

 - label: Blackwell Fusion and Compile Tests # 30 min
   timeout_in_minutes: 40
tests/kernels/moe/test_cutedsl_moe.py (new file, 582 lines)
@@ -0,0 +1,582 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import pytest

from vllm.platforms import current_platform

if not current_platform.has_device_capability(100):
    pytest.skip(
        reason="Nvfp4 Requires compute capability of 10 or above.",
        allow_module_level=True,
    )

import torch
from flashinfer import fp4_quantize
from torch.nn import functional as F

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
    flashinfer_cutedsl_moe_masked,
)
from vllm.utils.flashinfer import (
    flashinfer_cutedsl_grouped_gemm_nt_masked as cutedsl_gmm_masked,
)
from vllm.utils.flashinfer import (
    scaled_fp4_grouped_quantize,
)

kE2M1ToFloat = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)

FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0


def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    m_tiles = (m + 128 - 1) // 128
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
    return out[0:m, 0:k]


def dequantize_nvfp4_to_dtype(
    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
    """Dequantize the fp4 tensor back to high precision."""
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8
    m, packed_k = tensor_fp4.shape
    k = packed_k * 2
    tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale

    # scale the tensor
    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
    return out.to(dtype=dtype)


def break_fp4_bytes(a, dtype):
    assert a.dtype == torch.uint8
    m, n = a.shape

    # Vectorized nibble processing
    a_flat = a.flatten()
    high = (a_flat & 0xF0) >> 4  # Upper nibbles
    low = a_flat & 0x0F  # Lower nibbles

    # Combine nibbles for batch processing
    combined = torch.stack((low, high), dim=1).flatten()

    # Vectorized sign and magnitude extraction
    signs = (combined & 0x08).to(torch.bool)  # Sign bits
    abs_vals = (combined & 0x07).to(torch.long)  # Magnitude indices

    # Device-aware lookup and sign application
    kE2M1 = kE2M1ToFloat.to(device=a.device)
    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)

    # Reshape to final form
    return values.reshape(m, n * 2).to(dtype=dtype)


def generate_balanced_routing(
    hidden_states: torch.Tensor, num_experts: int, top_k: int
):
    """
    Generate routing weights and topk indices such that every expert is active.
    Returns routing_weights, topk_idx
    """

    num_tokens, hidden_dim = hidden_states.shape
    # num_tokens = batch_size * seq_len

    # First, assign at least one token per expert
    tokens_per_expert = torch.arange(num_tokens) % num_experts
    tokens_per_expert = tokens_per_expert[torch.randperm(num_tokens)]  # shuffle

    # Each token has top_k experts; start with one guaranteed expert
    topk_idx = torch.full((num_tokens, top_k), -1, dtype=torch.long)
    topk_idx[:, 0] = tokens_per_expert

    # For remaining top_k - 1 experts, pick randomly (allowing repeats)
    if top_k > 1:
        random_choices = torch.randint(0, num_experts, (num_tokens, top_k - 1))
        topk_idx[:, 1:] = random_choices

    # Normalize routing weights so each token's weights sum to 1
    routing_weights = torch.rand(num_tokens, top_k)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

    # Reshape back if needed
    routing_weights = routing_weights.view(num_tokens, top_k)
    topk_idx = topk_idx.view(num_tokens, top_k)

    return routing_weights, topk_idx


def prepare_inputs(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    num_experts: int,
    topk: int,
):
    routing_weights, topk_idx = generate_balanced_routing(
        router_logits, num_experts, topk
    )

    masked_m = []
    for i in range(num_experts):
        mask = topk_idx.view(-1) == i
        masked_m.append(mask.sum())

    masked_m = torch.tensor(masked_m, dtype=torch.int32)
    # Initialize hidden_states_3d with ones instead of empty to avoid a nan
    # issue.
    hidden_states_3d = torch.ones(
        (num_experts, max(masked_m), hidden_states.shape[1]), dtype=hidden_states.dtype
    )
    for i in range(num_experts):
        hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i]

    return hidden_states_3d, masked_m, topk_idx, routing_weights


MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 2048, 1024),
    (224, 1024, 1024),
    (224, 1024, 1536),
]


# Reference implementation of torch_moe
def torch_moe(a, w1, w2, score, topk, expert_map):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
                0, 1
            )
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            m = w1[i].shape[0]
            assert m % 2 == 0
            # Note: w1 and w3 are swapped!
            w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
            inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
            inter_gs = torch.tensor(1.0).cuda()
            inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
            inter = dequantize_nvfp4_to_dtype(
                inter_q,
                inter_blockscale,
                inter_gs,
                dtype=inter.dtype,
                device=inter.device,
                block_size=16,
            ).cuda()
            out[mask] = inter @ w2[i].transpose(0, 1)
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


def grouped_gemm_ref(
    hidden_states_expanded: torch.Tensor,
    hidden_states_3d: torch.Tensor,
    weights: torch.Tensor,
    topk_idx: torch.Tensor,
    masked_m: torch.Tensor,
    B: int,
    topk: int,
    num_experts: int,
    *,
    block_size: int = 16,
) -> torch.Tensor:
    """
    Computes the reference grouped GEMM (fp4 quantized per-expert loop),
    computes flashinfer grouped GEMM (for scale consistency),
    and returns ONLY the repacked reference output: out_ref.

    Returns:
        out_ref: Tensor [num_experts, max_m, n_out]
    """
    device_hs = hidden_states_expanded.device
    device_w = weights.device
    out_dtype = weights.dtype
    n_out = weights.shape[1]

    # Flattened reference output (B*topk, n_out)
    out = torch.zeros((B * topk, n_out), dtype=out_dtype, device=device_w)

    # Per-expert reference compute loop
    for i in range(num_experts):
        mask = topk_idx.view(-1) == i
        if mask.any():
            lhs = hidden_states_expanded[mask]
            rhs = weights[i]

            a_amax = lhs.abs().max().to(torch.float32).to(device_hs)
            b_amax = rhs.abs().max().to(torch.float32).to(device_w)

            a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
            b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax

            lhsq, lhsq_sf = fp4_quantize(lhs, a_gs)
            rhsq, rhsq_sf = fp4_quantize(rhs, b_gs)

            lhs_in_dtype = dequantize_nvfp4_to_dtype(
                lhsq,
                lhsq_sf,
                a_gs,
                dtype=lhs.dtype,
                device=device_hs,
                block_size=block_size,
            )
            rhs_in_dtype = dequantize_nvfp4_to_dtype(
                rhsq,
                rhsq_sf,
                b_gs,
                dtype=rhs.dtype,
                device=device_w,
                block_size=block_size,
            )

            out[mask] = lhs_in_dtype @ rhs_in_dtype.t()

    # Determine per-expert max_m
    max_m_val = int(masked_m.max().item())

    # Repack into [num_experts, max_m, n_out]
    out_ref = torch.zeros(
        (num_experts, max_m_val, n_out),
        dtype=out.dtype,
        device=out.device,
    )
    expert_slot = [0] * num_experts

    for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
        slot = expert_slot[expert_id]
        if slot < max_m_val:
            out_ref[expert_id, slot, :] = out[i]
            expert_slot[expert_id] += 1
        else:
            raise IndexError(
                f"Expert {expert_id} exceeded max slots ({max_m_val}). "
                "Increase max_m or check masked_m."
            )

    return out_ref


def flashinfer_cutedsl_grouped_gemm_nt_masked(
    hidden_states: torch.Tensor,  # 3d
    input_global_scale: torch.Tensor,  # (l,)
    weights: torch.Tensor,
    w_global_scale: torch.Tensor,  # (l,)
    masked_m: torch.Tensor,
):
    # hidden_states: [l, m, k]
    # weights: [l, n, k]
    aq, aq_sf = scaled_fp4_grouped_quantize(
        hidden_states,
        masked_m.to(hidden_states.device),
        input_global_scale,
    )
    num_experts, n, k = weights.shape
    bq, bq_sf = scaled_fp4_grouped_quantize(
        weights,
        torch.full((num_experts,), n, device=weights.device, dtype=torch.int32),
        w_global_scale,
    )

    out = torch.zeros(
        (num_experts, max(masked_m), n), dtype=weights.dtype, device=aq.device
    )
    out = out.permute(1, 2, 0)  # requirement of kernel
    sf_vec_size = 16
    ab_dtype = "float4_e2m1fn"
    sf_dtype = "float8_e4m3fn"
    c_dtype = "bfloat16"
    alpha = 1.0 / (input_global_scale * w_global_scale).to(out.dtype).view(
        1, 1, num_experts
    )

    def get_cute_dtype(input: torch.Tensor) -> str:
        if input.dtype == torch.bfloat16:
            return "bfloat16"
        elif input.dtype == torch.float16:
            return "float16"
        elif input.dtype == torch.float32:
            return "float32"
        else:
            raise ValueError(f"Unsupported cute dtype {input.dtype}")

    cutedsl_gmm_masked(
        (aq, aq_sf),
        (bq, bq_sf),
        out,
        masked_m.to(aq.device),
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=alpha,
        alpha_dtype=get_cute_dtype(alpha),
    )

    return out


@pytest.mark.parametrize("bs, hidden_dim, inter_dim", [(2, 128, 256), (16, 128, 512)])
@pytest.mark.parametrize("topk", [1, 2, 4])
@torch.inference_mode()
def test_flashinfer_cutedsl_moe_masked(
    bs: int, hidden_dim: int, inter_dim: int, topk: int
):
    torch.manual_seed(42)
    device = "cuda"
    num_experts = 8
    hidden_states = (
        torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0
    )
    w1 = (
        torch.randn(
            num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device
        )
        / 10.0
    )
    w2 = (
        torch.randn(
            num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device
        )
        / 10.0
    )
    router_logits = torch.randn(bs, num_experts, dtype=torch.float32)

    hidden_states_expanded = (
        hidden_states.view(bs, -1, hidden_dim)
        .repeat(1, topk, 1)
        .reshape(-1, hidden_dim)
    )
    hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs(
        hidden_states_expanded, router_logits, num_experts, topk
    )

    w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
    w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
    input_global_scale = torch.ones(
        (num_experts,), dtype=torch.float32, device=hidden_states.device
    )

    w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
    w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
    a2_global_scale = torch.ones(
        (num_experts,), dtype=torch.float32, device=hidden_states.device
    )  # assume intermediate scale is 1.0

    w1_fp4, w1_blockscale = scaled_fp4_grouped_quantize(
        w1,
        torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
        w1_global_scale,
    )
    w2_fp4, w2_blockscale = scaled_fp4_grouped_quantize(
        w2,
        torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
        w2_global_scale,
    )

    w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
    w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)

    out = torch.empty_like(hidden_states_3d)
    # Note: the 1st dim shouldn't be bs
    wk = torch.empty(
        num_experts,
        hidden_states_3d.shape[1],
        inter_dim * 2,
        dtype=hidden_states_3d.dtype,
        device=hidden_states.device,
    )
    flashinfer_cutedsl_moe_masked(
        hidden_states_3d.to(hidden_states.device),
        input_global_scale,
        w1_fp4.permute(2, 0, 1),
        w1_blockscale,
        w1_alpha,
        w2_fp4.permute(2, 0, 1),
        a2_global_scale,
        w2_blockscale,
        w2_alpha,
        masked_m.to(hidden_states.device),
        wk,
        out,
    )

    # reference
    a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale)
    a_in_dtype = dequantize_nvfp4_to_dtype(
        a_fp4,
        a_scale_interleaved,
        input_global_scale,
        dtype=hidden_states.dtype,
        device=hidden_states.device,
        block_size=16,
    )
    w1_d = torch.empty(
        (num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype
    )
    w2_d = torch.empty(
        (num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype
    )

    for idx in range(0, num_experts):
        w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
            w1[idx], w1_global_scale[idx]
        )
        w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
            w2[idx], w2_global_scale[idx]
        )
        w1_d[idx] = dequantize_nvfp4_to_dtype(
            w1_fp4_sliced,
            w1_blockscale_sliced,
            w1_global_scale[idx],
            dtype=w1.dtype,
            device=w1.device,
            block_size=16,
        )
        w2_d[idx] = dequantize_nvfp4_to_dtype(
            w2_fp4_sliced,
            w2_blockscale_sliced,
            w2_global_scale[idx],
            dtype=w2.dtype,
            device=w2.device,
            block_size=16,
        )

    ref_output = torch_moe_nvfp4(
        a_in_dtype,
        w1_d,
        w2_d,
        topk,
        routing_weights.to(a_in_dtype.device),
        topk_idx.to(a_in_dtype.device),
    )
    out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype)

    positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
    rows, cols = positions[:, 0], positions[:, 1]
    experts = topk_idx[rows, cols]
    for i in range(num_experts):
        mask = experts == i
        if mask.any():
            idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
            r, c = rows[idx], cols[idx]
            out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to(
                out.device
            ).unsqueeze(-1)
    torch.testing.assert_close(
        out_weighted.cpu(), ref_output.cpu(), atol=2e-1, rtol=2e-1
    )


@pytest.mark.parametrize(
    "bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
)
@torch.inference_mode()
def test_grouped_gemm_nt_masked(
    bs: int, hidden_dim: int, inter_dim: int, topk: int
) -> None:
    torch.manual_seed(42)
    B = bs
    D = hidden_dim
    N = inter_dim
    # CuteDSL group gemm has an issue when not all experts are active,
    # i.e. masked = [2, 3, 0, 0, 1] where the 2nd and 3rd experts are inactive;
    # see https://github.com/flashinfer-ai/flashinfer/issues/1856
    num_experts = bs
    hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
    weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
    router_logits = torch.randn(B, num_experts, dtype=torch.float32)

    hidden_states_expanded = (
        hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    )
    hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
        hidden_states_expanded, router_logits, num_experts, topk
    )

    a_amax = (
        hidden_states_3d.abs()
        .amax(dim=(1, 2))
        .to(torch.float32)
        .to(hidden_states.device)
    )
    b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
    a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
    b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
    out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
        hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
    )
    # reference
    out_ref = grouped_gemm_ref(
        hidden_states_expanded=hidden_states_expanded,
        hidden_states_3d=hidden_states_3d,
        weights=weights,
        topk_idx=topk_idx,
        masked_m=masked_m,
        B=B,
        topk=topk,
        num_experts=num_experts,
    )
    # Note: compare only the masked positions, since CuteDSL may write nan
    # into unmasked positions.
    for i in range(num_experts):
        torch.testing.assert_close(
            out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
            out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
            atol=1e-1,
            rtol=1e-1,
        )


if __name__ == "__main__":
    test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4)
    test_grouped_gemm_nt_masked(16, 128, 512, 4)
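For orientation, the FP4 decoding helpers in the test above (kE2M1ToFloat and break_fp4_bytes) boil down to one sign bit plus a 3-bit E2M1 magnitude lookup per nibble. A standalone, CPU-only sketch of the same decode on a single packed byte (torch only, not part of the commit):

import torch

kE2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

packed = torch.tensor([[0x9A]], dtype=torch.uint8)  # two FP4 values packed into one byte
low, high = packed & 0x0F, (packed & 0xF0) >> 4     # low nibble is stored first
for nib in (low, high):
    sign = torch.where((nib & 0x08).bool(), -1.0, 1.0)  # bit 3 is the sign
    mag = kE2M1[(nib & 0x07).long()]                     # bits 0..2 index the E2M1 table
    print(sign * mag)  # 0x0A -> -1.0, 0x09 -> -0.5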
@@ -157,7 +157,9 @@ if TYPE_CHECKING:
     VLLM_USE_FLASHINFER_MOE_FP16: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
-    VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "latency"
+    VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
+        "latency"
+    )
     VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
@@ -1238,7 +1240,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # - "latency":
     #   Uses TensorRT-LLM kernels optimized for low-latency inference.
     "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
-        "VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"]
+        "VLLM_FLASHINFER_MOE_BACKEND",
+        "latency",
+        ["throughput", "latency", "masked_gemm"],
     ),
     # Control the workspace buffer size for the FlashInfer backend.
     "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int(
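As a usage note, not part of the diff: the new "masked_gemm" choice is read through this environment variable, so selecting the CuteDSL path is purely a launch-time configuration. A minimal sketch, assuming a Blackwell (SM100) GPU, a FlashInfer build that ships the cute_dsl module, and an NVFP4 (ModelOpt) checkpoint; the model id is a placeholder:

import os

# Hypothetical launch-side configuration; only the env var names and the
# "masked_gemm" literal come from the commit, everything else is assumed.
os.environ["VLLM_USE_FLASHINFER_MOE_FP4"] = "1"
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "masked_gemm"

from vllm import LLM

llm = LLM(model="some-org/some-nvfp4-moe-checkpoint")  # placeholder model id
print(llm.generate(["Hello"])[0].outputs[0].text)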
@@ -6,6 +6,7 @@ import deep_ep
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -27,6 +28,8 @@ logger = init_logger(__name__)
 DEEPEP_QUANT_BLOCK_SIZE = 128
 DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]

+logger = init_logger(__name__)
+

 def dequant_fp8(
     expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor
@@ -187,16 +190,25 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

         # TODO (varun): Optimization - Use a batched version of quant
         x = x.view((-1, hidden_dim))
+        q_dtype = quant_config.quant_dtype
+
+        if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
+            logger.info_once(
+                "Skip quantization when using FlashInfer CUTEDSL(masked_gemm) "
+                "for ModelOptNvFp4FusedMoE."
+            )
+            q_dtype = None
+
         x, x_scales = moe_kernel_quantize_input(
             x,
             quant_config.a1_scale,
-            quant_config.quant_dtype,
+            q_dtype,
             quant_config.per_act_token_quant,
             quant_config.block_shape,
         )
         x = x.view((num_experts, -1, hidden_dim))

-        if quant_config.quant_dtype is not None:
+        if q_dtype is not None:
             assert x_scales is not None
             x_scales = normalize_batched_scales_shape(x_scales, num_experts)

vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py (new file, 346 lines)
@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate,
)
from vllm.utils.flashinfer import (
    flashinfer_cutedsl_grouped_gemm_nt_masked,
    has_flashinfer_cutedsl_grouped_gemm_nt_masked,
    scaled_fp4_grouped_quantize,
    silu_and_mul_scaled_nvfp4_experts_quantize,
)

logger = init_logger(__name__)


def is_valid_flashinfer_cutedsl_fused_moe(
    hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor
) -> bool:
    """
    Check if the given problem size is supported by the FlashInfer CuteDSL MoE
    kernel.
    """
    if not has_flashinfer_cutedsl_grouped_gemm_nt_masked():
        logger.debug_once(
            "FlashInferCuteDSLExperts disabled: "
            "flashinfer_cutedsl_fused_moe not available."
        )
        return False
    # Data type checks
    if (
        w1.dtype != torch.uint8
        or w2.dtype != torch.uint8
        or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16]
    ):
        logger.debug_once(
            "FlashInferCuteDSLExperts disabled: w1/w2 must be torch.uint8 "
            f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
            f"float32, float16, or bfloat16 (got {hidden_states.dtype})."
        )
        return False
    return True


class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
    def __init__(
        self,
        out_dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(quant_config)
        assert quant_config.quant_dtype == "nvfp4", (
            "Only nvfp4 quantization are currently supported."
        )
        self.out_dtype = out_dtype

    @property
    def activation_formats(
        self,
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        return (
            mk.FusedMoEActivationFormat.BatchedExperts,
            mk.FusedMoEActivationFormat.BatchedExperts,
        )

    def supports_expert_map(self) -> bool:
        return False

    def supports_chunking(self) -> bool:
        # This refers to TP chunking; DP chunking is handled separately.
        # TODO(shuw@nvidia.com): Set to False to be consistent with
        # batched_deep_gemm_moe
        return False

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # Let PrepareAndFinalize::finalize() decide the impl.
        return TopKWeightAndReduceDelegate()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        # We use global_num_experts due to how moe_align_block_size handles
        # expert_maps.
        """
        Compute the shapes for the temporary and final outputs of the two gemms
        and activation in the fused expert function. Since the gemms are
        independent, the workspace for the first gemm can be shared with the
        workspace for the last gemm.

        Returns a tuple of:
        - workspace13 shape tuple: must be large enough to hold the
          result of either expert gemm.
        - workspace2 shape tuple: must be large enough to hold the
          result of the activation function.
        - output shape tuple: must be exact size of the final gemm output.
        - Workspace type: The dtype to use for the workspace tensors.
        - Note: in order for activation chunking to work, the first dimension
          of each tuple must be the number of tokens.
        """
        output_shape = (local_num_experts, M, K)
        workspace2 = (local_num_experts, M, N)
        workspace1 = output_shape
        return (workspace1, workspace2, output_shape)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,  # Not used
        workspace13: torch.Tensor | None,
        workspace2: torch.Tensor | None,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool | None,
    ):
        assert self.quant_dtype == "nvfp4", (
            "Only nvfp4 quantization are currently supported."
        )
        # Ensure w1_scale and w2_scale are not None before calling view
        assert self.w1_scale is not None and self.w2_scale is not None, (
            "w1_scale and w2_scale must not be None for FlashInferExperts"
        )
        assert expert_tokens_meta is not None
        expert_num_tokens = expert_tokens_meta.expert_num_tokens
        assert hidden_states.ndim == 3
        assert self.w1_scale.ndim == 3
        assert self.w2_scale.ndim == 3
        flashinfer_cutedsl_moe_masked(
            hidden_states=hidden_states,
            input_global_scale=self.a1_gscale,
            w1=w1,
            w1_blockscale=self.w1_scale,
            w1_alpha=self.g1_alphas,
            w2=w2,
            a2_global_scale=self.a2_gscale,
            w2_blockscale=self.w2_scale,
            w2_alpha=self.g2_alphas,
            masked_m=expert_num_tokens,
            workspace=workspace2,
            out=output,
        )


def get_cute_dtype(input: torch.Tensor) -> str:
    if input.dtype == torch.bfloat16:
        return "bfloat16"
    elif input.dtype == torch.float16:
        return "float16"
    elif input.dtype == torch.float32:
        return "float32"
    else:
        raise ValueError(f"Unsupported cute dtype {input.dtype}")


def flashinfer_cutedsl_moe_masked(
    hidden_states: torch.Tensor,
    input_global_scale: torch.Tensor,
    w1: torch.Tensor,
    w1_blockscale: torch.Tensor,
    w1_alpha,
    w2: torch.Tensor,
    a2_global_scale: torch.Tensor,
    w2_blockscale: torch.Tensor,
    w2_alpha,
    masked_m: torch.Tensor,
    workspace: torch.Tensor,
    out: torch.Tensor,
):
    """
    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
    kernels.

    Args:
        hidden_states (torch.Tensor): [num_experts, m, k], bf16
        input_global_scale (torch.Tensor): (l,)
        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
        w1_blockscale (torch.Tensor): blockscale factors, e4m3
        w1_alpha (torch.Tensor): (l,)
        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
        a2_global_scale (torch.Tensor): (l,)
        w2_blockscale (torch.Tensor): blockscale factors, e4m3
        w2_alpha (torch.Tensor): (l,)
        masked_m (torch.Tensor): Masked dimension indices
        workspace (torch.Tensor): For gateup_output

    Notes:
        - Assumes max(masked_m) <= m.
    """

    # === Assertions on dtypes ===
    assert input_global_scale.dtype == torch.float32, (
        f"input_global_scale must be float32, got {input_global_scale.dtype}"
    )
    assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
    assert w1_blockscale.dtype == torch.float8_e4m3fn, (
        f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
    )
    assert w1_alpha.dtype == torch.float32, (
        f"w1_alpha must be float32, got {w1_alpha.dtype}"
    )
    assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}"
    assert a2_global_scale.dtype == torch.float32, (
        f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
    )
    assert w2_blockscale.dtype == torch.float8_e4m3fn, (
        f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
    )
    assert w2_alpha.dtype == torch.float32, (
        f"w2_alpha must be float32, got {w2_alpha.dtype}"
    )

    # === Assertions on shapes ===
    n = w2.shape[-1] * 2  # intermediate dimension
    num_experts, m, k = hidden_states.shape

    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
    assert w1.shape[-1] * 2 == k, (
        f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
    )
    assert w2.shape[-2:] == (
        k,
        n // 2,
    ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"

    assert input_global_scale.shape == (num_experts,), (
        f"input_global_scale must be (l,), got {input_global_scale.shape}"
    )
    assert w1_alpha.shape == (num_experts,), (
        f"w1_alpha must be (l,), got {w1_alpha.shape}"
    )
    assert a2_global_scale.shape == (num_experts,), (
        f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
    )
    assert w2_alpha.shape == (num_experts,), (
        f"w2_alpha must be (l,), got {w2_alpha.shape}"
    )

    aq, aq_sf = scaled_fp4_grouped_quantize(
        hidden_states,
        masked_m,
        input_global_scale,
    )

    workspace = workspace.permute(1, 2, 0)  # requirement of kernel
    sf_vec_size = 16
    assert aq_sf.dtype == torch.float8_e4m3fn
    assert aq.dtype == torch.uint8
    ab_dtype = "float4_e2m1fn"
    sf_dtype = "float8_e4m3fn"

    c_dtype = get_cute_dtype(hidden_states)

    # Gemm1
    flashinfer_cutedsl_grouped_gemm_nt_masked(
        (aq, aq_sf),
        (w1.permute(1, 2, 0), w1_blockscale),
        workspace,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w1_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w1_alpha),
    )  # in logical [m, n, l]

    # SILU and quantization
    diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
        workspace.permute(2, 0, 1),
        masked_m,
        a2_global_scale,
    )

    # Gemm2
    out = out.permute(1, 2, 0)  # requirement of kernel
    flashinfer_cutedsl_grouped_gemm_nt_masked(
        (diq, diq_sf),
        (w2.permute(1, 2, 0), w2_blockscale),
        out,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w2_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w2_alpha),
    )  # in logical [m, k, l]
    out = out.permute(2, 0, 1)


def flashinfer_cutedsl_moe_fp4(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
        create_flashinfer_prepare_finalize,
    )

    fused_experts = mk.FusedMoEModularKernel(
        create_flashinfer_prepare_finalize(use_dp=False),  # could be swapped later
        FlashInferCuteDSLExperts(
            out_dtype=hidden_states.dtype,
            quant_config=quant_config,
        ),
    )

    return fused_experts(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
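To make the layout conventions in flashinfer_cutedsl_moe_masked() easier to follow, here is a small shape sketch derived from the docstring and assertions above. It is an illustration only: the sizes are arbitrary, it builds no real FP4 data, and it never calls the kernel.

import torch

num_experts, m, k, n = 8, 64, 128, 256  # example sizes, not from the commit

hidden_states = torch.zeros(num_experts, m, k, dtype=torch.bfloat16)
w1 = torch.zeros(num_experts, 2 * n, k // 2, dtype=torch.uint8)       # packed FP4, [l, 2n, k//2]
w2 = torch.zeros(num_experts, k, n // 2, dtype=torch.uint8)           # packed FP4, [l, k, n//2]
workspace = torch.zeros(num_experts, m, 2 * n, dtype=torch.bfloat16)  # gate-up output of Gemm1
out = torch.zeros(num_experts, m, k, dtype=torch.bfloat16)            # final MoE output

# The CuteDSL kernels expect [m, ..., l]-major views, which is why the module
# permutes workspace and out with .permute(1, 2, 0) before each GEMM.
print(workspace.permute(1, 2, 0).shape)  # torch.Size([64, 512, 8])
print(out.permute(1, 2, 0).shape)        # torch.Size([64, 128, 8])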
@@ -1468,7 +1468,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         gemm1_weight = layer.w13_weight.data
         gemm1_weight_scale = layer.w13_weight_scale.data

-        if self.allow_flashinfer:
+        if (
+            self.allow_flashinfer
+            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+        ):
             gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                 gemm1_weight, gemm1_weight_scale, dim=-2
             )
@@ -1746,17 +1749,26 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 workspace=layer.workspace,
             )

-        elif (
-            self.allow_flashinfer
-            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
-        ):
-            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
-                flashinfer_cutlass_moe_fp4,
-            )
+        elif self.allow_flashinfer:
+            assert self.flashinfer_moe_backend in (
+                FlashinferMoeBackend.CUTLASS,
+                FlashinferMoeBackend.CUTEDSL,
+            )
+            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
+                from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
+                    flashinfer_cutlass_moe_fp4,
+                )
+
+                flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4
+            else:
+                from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (  # noqa: E501
+                    flashinfer_cutedsl_moe_fp4,
+                )
+
+                flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4

             assert self.moe_quant_config is not None

-            return flashinfer_cutlass_moe_fp4(
+            return flashinfer_fn_moe_fp4(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
@@ -10,6 +10,9 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
+    FlashInferCuteDSLExperts,
+)
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     FlashInferExperts,
 )
@@ -17,10 +20,14 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
     create_flashinfer_prepare_finalize,
 )
 from vllm.platforms import current_platform
-from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+from vllm.utils.flashinfer import (
+    has_flashinfer_cutedsl_grouped_gemm_nt_masked,
+    has_flashinfer_cutlass_fused_moe,
+)

 __all__ = [
     "is_flashinfer_fp4_cutlass_moe_available",
+    "is_flashinfer_fp4_cutedsl_moe_available",
     "reorder_w1w3_to_w3w1",
     "build_flashinfer_fp4_cutlass_moe_prepare_finalize",
 ]
@@ -36,6 +43,16 @@ def is_flashinfer_fp4_cutlass_moe_available() -> bool:
     )


+def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
+    """Return ``True`` when FlashInfer CUTEDSL NV-FP4 kernels can be used."""
+    return (
+        envs.VLLM_USE_FLASHINFER_MOE_FP4
+        and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
+        and current_platform.is_cuda()
+        and current_platform.is_device_capability(100)
+    )
+
+
 def reorder_w1w3_to_w3w1(
     weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -72,15 +89,21 @@ def select_nvfp4_gemm_impl(
     """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""

     if allow_flashinfer:
-        return FlashInferExperts(
-            out_dtype=moe.in_dtype,
-            quant_config=moe_quant_config,
-            ep_rank=moe.moe_parallel_config.ep_rank,
-            ep_size=moe.moe_parallel_config.ep_size,
-            tp_rank=moe.moe_parallel_config.tp_rank,
-            tp_size=moe.moe_parallel_config.tp_size,
-            use_dp=moe.moe_parallel_config.dp_size > 1,
-        )
+        if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
+            return FlashInferCuteDSLExperts(
+                out_dtype=moe.in_dtype,
+                quant_config=moe_quant_config,
+            )
+        elif envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput":
+            return FlashInferExperts(
+                out_dtype=moe.in_dtype,
+                quant_config=moe_quant_config,
+                ep_rank=moe.moe_parallel_config.ep_rank,
+                ep_size=moe.moe_parallel_config.ep_size,
+                tp_rank=moe.moe_parallel_config.tp_rank,
+                tp_size=moe.moe_parallel_config.tp_size,
+                use_dp=moe.moe_parallel_config.dp_size > 1,
+            )

     # native cutlass experts currently don't support DP; TP case won't call this
     raise ValueError(
@@ -25,6 +25,7 @@ logger = init_logger(__name__)
 class FlashinferMoeBackend(Enum):
     TENSORRT_LLM = "TensorRT-LLM"
     CUTLASS = "CUTLASS"
+    CUTEDSL = "CUTEDSL"


 def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
@@ -273,19 +274,21 @@ def flashinfer_cutlass_moe_fp8(


 def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
-    flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
-    # Prefer CUTLASS on SM90 to cover both SM90/SM100 generations
-    if flashinfer_moe_backend == "throughput" or current_platform.is_device_capability(
-        90
-    ):
-        return FlashinferMoeBackend.CUTLASS
-    elif flashinfer_moe_backend == "latency":
-        return FlashinferMoeBackend.TENSORRT_LLM
+    backend_map = {
+        "throughput": FlashinferMoeBackend.CUTLASS,
+        "latency": FlashinferMoeBackend.TENSORRT_LLM,
+        "masked_gemm": FlashinferMoeBackend.CUTEDSL,
+    }
+
+    flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
+    if flashinfer_moe_backend in backend_map:
+        return backend_map[flashinfer_moe_backend]
+    elif current_platform.is_device_capability(90):
+        return FlashinferMoeBackend.CUTLASS

-    allowed_backends = ["throughput", "latency"]
     raise ValueError(
-        f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
-        f" expected one of {allowed_backends}"
+        f"Unknown flashinfer moe backend: {flashinfer_moe_backend!r}. "
+        f"Expected one of {list(backend_map.keys())}."
     )
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
+    is_flashinfer_fp4_cutedsl_moe_available,
     is_flashinfer_fp4_cutlass_moe_available,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
@@ -32,7 +33,10 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
     """Detect platform support for NV-FP4 fused-MoE path"""
     cutlass_supported = cutlass_fp4_supported()

-    allow_flashinfer = cutlass_supported and is_flashinfer_fp4_cutlass_moe_available()
+    allow_flashinfer = cutlass_supported and (
+        is_flashinfer_fp4_cutlass_moe_available()
+        or is_flashinfer_fp4_cutedsl_moe_available()
+    )

     if allow_flashinfer:
         _logger.info_once(
@@ -114,7 +114,17 @@ flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
 flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
     "flashinfer.fused_moe", "cutlass_fused_moe"
 )
+flashinfer_cutedsl_grouped_gemm_nt_masked = _lazy_import_wrapper(
+    "flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"
+)
 flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
 nvfp4_batched_quantize = _lazy_import_wrapper("flashinfer", "nvfp4_batched_quantize")
+silu_and_mul_scaled_nvfp4_experts_quantize = _lazy_import_wrapper(
+    "flashinfer", "silu_and_mul_scaled_nvfp4_experts_quantize"
+)
+scaled_fp4_grouped_quantize = _lazy_import_wrapper(
+    "flashinfer", "scaled_fp4_grouped_quantize"
+)
 nvfp4_block_scale_interleave = _lazy_import_wrapper(
     "flashinfer", "nvfp4_block_scale_interleave"
 )
@@ -166,6 +176,14 @@ def has_flashinfer_moe() -> bool:
     )


+@functools.cache
+def has_flashinfer_cutedsl() -> bool:
+    """Return ``True`` if the FlashInfer cute_dsl module is available."""
+    return (
+        has_flashinfer() and importlib.util.find_spec("flashinfer.cute_dsl") is not None
+    )
+
+
 @functools.cache
 def has_flashinfer_cutlass_fused_moe() -> bool:
     """Return `True` if FlashInfer CUTLASS fused MoE is available."""
@@ -187,6 +205,26 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
     return True


+@functools.cache
+def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool:
+    """Return ``True`` if FlashInfer CuteDSL grouped_gemm_nt_masked is available."""
+    if not has_flashinfer_cutedsl():
+        return False
+
+    # Check if all required functions are available
+    required_functions = [
+        ("flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"),
+        ("flashinfer", "scaled_fp4_grouped_quantize"),
+        ("flashinfer", "silu_and_scaled_nvfp4_experts_quantize"),
+    ]
+
+    for module_name, attr_name in required_functions:
+        mod = _get_submodule(module_name)
+        if not mod or not hasattr(mod, attr_name):
+            return False
+    return True
+
+
 @functools.cache
 def has_nvidia_artifactory() -> bool:
     """Return `True` if NVIDIA's artifactory is accessible.
@@ -472,7 +510,10 @@ __all__ = [
     "has_flashinfer",
     "flashinfer_trtllm_fp8_block_scale_moe",
     "flashinfer_cutlass_fused_moe",
+    "flashinfer_cutedsl_grouped_gemm_nt_masked",
     "flashinfer_fp4_quantize",
+    "silu_and_mul_scaled_nvfp4_experts_quantize",
+    "scaled_fp4_grouped_quantize",
     "nvfp4_block_scale_interleave",
     "trtllm_fp4_block_scale_moe",
     "autotune",
@@ -480,6 +521,7 @@ __all__ = [
     "has_flashinfer_comm",
     "has_flashinfer_all2all",
     "has_flashinfer_cutlass_fused_moe",
+    "has_flashinfer_cutedsl_grouped_gemm_nt_masked",
     "has_nvidia_artifactory",
     "supports_trtllm_attention",
     "can_use_trtllm_attention",
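A caller-side sketch, not from the commit, of how these availability probes are meant to be paired with the lazily imported wrappers; the fallback branch is an assumption about the surrounding code.

from vllm.utils.flashinfer import (
    flashinfer_cutedsl_grouped_gemm_nt_masked,
    has_flashinfer_cutedsl_grouped_gemm_nt_masked,
)

if has_flashinfer_cutedsl_grouped_gemm_nt_masked():
    # Safe to call the lazily imported CuteDSL kernel wrapper.
    masked_gemm = flashinfer_cutedsl_grouped_gemm_nt_masked
else:
    masked_gemm = None  # caller falls back to another FlashInfer MoE backend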