# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.platforms import current_platform

if not current_platform.has_device_capability(100):
    pytest.skip(
        reason="NVFP4 requires compute capability of 10 or above.",
        allow_module_level=True,
    )

import torch
from flashinfer import fp4_quantize
from torch.nn import functional as F

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
    flashinfer_cutedsl_moe_masked,
)
from vllm.utils.flashinfer import (
    flashinfer_cutedsl_grouped_gemm_nt_masked as cutedsl_gmm_masked,
)
from vllm.utils.flashinfer import (
    scaled_fp4_grouped_quantize,
)

kE2M1ToFloat = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)

FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0
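
# NVFP4 scaling convention used throughout this file: values are quantized
# to E2M1 with one float8_e4m3fn scale per 16-element block plus a single
# fp32 global scale per tensor (or per expert), chosen as
#     global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax
# so that both the block scales and the quantized values can use their full
# representable ranges.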


def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    """Undo the 128x4 block-scale swizzle, returning the scale factors in
    row-major (m, k // block_size) order."""
    m_tiles = (m + 128 - 1) // 128
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
    return out[0:m, 0:k]
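
# Layout note: block scales are stored in 128-row x 4-column tiles whose
# rows are interleaved; within a tile, the scale for logical row b * 32 + a
# sits at stored offset (a, b). The reshape/permute above inverts exactly
# that interleaving.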


def dequantize_nvfp4_to_dtype(
    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
    """Dequantize the fp4 tensor back to high precision."""
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8
    m, packed_k = tensor_fp4.shape
    k = packed_k * 2
    tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale

    # scale the tensor
    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
    return out.to(dtype=dtype)


def break_fp4_bytes(a, dtype):
    """Unpack a uint8 tensor of packed fp4 (E2M1) pairs into float values."""
    assert a.dtype == torch.uint8
    m, n = a.shape

    # Vectorized nibble processing
    a_flat = a.flatten()
    high = (a_flat & 0xF0) >> 4  # Upper nibbles
    low = a_flat & 0x0F  # Lower nibbles

    # Combine nibbles for batch processing
    combined = torch.stack((low, high), dim=1).flatten()

    # Vectorized sign and magnitude extraction
    signs = (combined & 0x08).to(torch.bool)  # Sign bits
    abs_vals = (combined & 0x07).to(torch.long)  # Magnitude indices

    # Device-aware lookup and sign application
    kE2M1 = kE2M1ToFloat.to(device=a.device)
    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)

    # Reshape to final form
    return values.reshape(m, n * 2).to(dtype=dtype)
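
# Worked example (illustrative): bytes unpack low nibble first. For the
# byte 0x2B, the low nibble 0xB has the sign bit set and magnitude code 3
# (-1.5); the high nibble 0x2 has magnitude code 2 (1.0). So 0x2B decodes
# to the pair (-1.5, 1.0).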


def generate_balanced_routing(
    hidden_states: torch.Tensor, num_experts: int, top_k: int
):
    """
    Generate routing weights and topk indices such that every expert is
    active (guaranteed only when num_tokens >= num_experts).
    Returns routing_weights, topk_idx
    """

    num_tokens, hidden_dim = hidden_states.shape
    # num_tokens = batch_size * seq_len

    # First, assign at least one token per expert
    tokens_per_expert = torch.arange(num_tokens) % num_experts
    tokens_per_expert = tokens_per_expert[torch.randperm(num_tokens)]  # shuffle

    # Each token has top_k experts; start with one guaranteed expert
    topk_idx = torch.full((num_tokens, top_k), -1, dtype=torch.long)
    topk_idx[:, 0] = tokens_per_expert

    # For the remaining top_k - 1 experts, pick randomly (allowing repeats)
    if top_k > 1:
        random_choices = torch.randint(0, num_experts, (num_tokens, top_k - 1))
        topk_idx[:, 1:] = random_choices

    # Normalize routing weights so each token's weights sum to 1
    routing_weights = torch.rand(num_tokens, top_k)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

    # Reshape back if needed
    routing_weights = routing_weights.view(num_tokens, top_k)
    topk_idx = topk_idx.view(num_tokens, top_k)

    return routing_weights, topk_idx
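
# Example (illustrative): num_tokens = 4, num_experts = 2, top_k = 2 could
# yield topk_idx = [[0, 1], [1, 1], [1, 0], [0, 0]]; column 0 is the
# shuffled round-robin assignment that guarantees coverage, the remaining
# columns are uniform random.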


def prepare_inputs(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    num_experts: int,
    topk: int,
):
    """Route tokens and gather them into a per-expert padded 3-D layout.

    Returns (hidden_states_3d, masked_m, topk_idx, routing_weights), where
    masked_m[i] is the number of valid rows for expert i.
    """
    routing_weights, topk_idx = generate_balanced_routing(
        router_logits, num_experts, topk
    )

    masked_m = []
    for i in range(num_experts):
        mask = topk_idx.view(-1) == i
        masked_m.append(mask.sum())

    masked_m = torch.tensor(masked_m, dtype=torch.int32)
    # Initialize hidden_states_3d with ones instead of empty to avoid NaN
    # issues.
    hidden_states_3d = torch.ones(
        (num_experts, max(masked_m), hidden_states.shape[1]), dtype=hidden_states.dtype
    )
    for i in range(num_experts):
        hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i]

    return hidden_states_3d, masked_m, topk_idx, routing_weights
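
# Example (illustrative): with four expanded rows, num_experts = 2 and
# topk_idx flattening to [0, 1, 1, 0], masked_m is [2, 2] and
# hidden_states_3d has shape (2, 2, hidden_dim), with each expert's rows
# packed at the front of its slice.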


MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 2048, 1024),
    (224, 1024, 1024),
    (224, 1024, 1536),
]


# Reference MoE implementation in plain PyTorch.
def torch_moe(a, w1, w2, score, topk, expert_map):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
                0, 1
            )
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
    """Reference MoE that round-trips the intermediate activations through
    NVFP4 quantization, mirroring what the fused kernel does internally."""
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            m = w1[i].shape[0]
            assert m % 2 == 0
            # Note: w1 and w3 are swapped!
            w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
            inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
            inter_gs = torch.tensor(1.0).cuda()
            inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
            inter = dequantize_nvfp4_to_dtype(
                inter_q,
                inter_blockscale,
                inter_gs,
                dtype=inter.dtype,
                device=inter.device,
                block_size=16,
            ).cuda()
            out[mask] = inter @ w2[i].transpose(0, 1)
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
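
# Both references compute SwiGLU with the gate in the first half of w1's
# output dimension: vLLM's SiluAndMul splits its input's last dim in half
# and returns silu(first_half) * second_half, matching the manual
# w1_expert / w3_expert split above.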


def grouped_gemm_ref(
    hidden_states_expanded: torch.Tensor,
    hidden_states_3d: torch.Tensor,
    weights: torch.Tensor,
    topk_idx: torch.Tensor,
    masked_m: torch.Tensor,
    B: int,
    topk: int,
    num_experts: int,
    *,
    block_size: int = 16,
) -> torch.Tensor:
    """
    Computes the reference grouped GEMM with a per-expert fp4
    quantize/dequantize loop (mirroring the quantization of the flashinfer
    path, for scale consistency) and repacks it into the same layout the
    masked grouped GEMM produces.

    Returns:
        out_ref: Tensor [num_experts, max_m, n_out]
    """
    device_hs = hidden_states_expanded.device
    device_w = weights.device
    out_dtype = weights.dtype
    n_out = weights.shape[1]

    # Flattened reference output (B*topk, n_out)
    out = torch.zeros((B * topk, n_out), dtype=out_dtype, device=device_w)

    # Per-expert reference compute loop
    for i in range(num_experts):
        mask = topk_idx.view(-1) == i
        if mask.any():
            lhs = hidden_states_expanded[mask]
            rhs = weights[i]

            a_amax = lhs.abs().max().to(torch.float32).to(device_hs)
            b_amax = rhs.abs().max().to(torch.float32).to(device_w)

            a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
            b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax

            lhsq, lhsq_sf = fp4_quantize(lhs, a_gs)
            rhsq, rhsq_sf = fp4_quantize(rhs, b_gs)

            lhs_in_dtype = dequantize_nvfp4_to_dtype(
                lhsq,
                lhsq_sf,
                a_gs,
                dtype=lhs.dtype,
                device=device_hs,
                block_size=block_size,
            )
            rhs_in_dtype = dequantize_nvfp4_to_dtype(
                rhsq,
                rhsq_sf,
                b_gs,
                dtype=rhs.dtype,
                device=device_w,
                block_size=block_size,
            )

            out[mask] = lhs_in_dtype @ rhs_in_dtype.t()

    # Determine per-expert max_m
    max_m_val = int(masked_m.max().item())

    # Repack into [num_experts, max_m, n_out]
    out_ref = torch.zeros(
        (num_experts, max_m_val, n_out),
        dtype=out.dtype,
        device=out.device,
    )
    expert_slot = [0] * num_experts

    for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
        slot = expert_slot[expert_id]
        if slot < max_m_val:
            out_ref[expert_id, slot, :] = out[i]
            expert_slot[expert_id] += 1
        else:
            raise IndexError(
                f"Expert {expert_id} exceeded max slots ({max_m_val}). "
                "Increase max_m or check masked_m."
            )

    return out_ref


def flashinfer_cutedsl_grouped_gemm_nt_masked(
    hidden_states: torch.Tensor,  # 3d
    input_global_scale: torch.Tensor,  # (l,)
    weights: torch.Tensor,
    w_global_scale: torch.Tensor,  # (l,)
    masked_m: torch.Tensor,
):
    """Quantize both operands to NVFP4 per expert and run the CuteDSL masked
    grouped GEMM; alpha folds the two global scales back out of the output."""
    # hidden_states: [l, m, k]
    # weights: [l, n, k]
    aq, aq_sf = scaled_fp4_grouped_quantize(
        hidden_states,
        masked_m.to(hidden_states.device),
        input_global_scale,
    )
    num_experts, n, k = weights.shape
    bq, bq_sf = scaled_fp4_grouped_quantize(
        weights,
        torch.full((num_experts,), n, device=weights.device, dtype=torch.int32),
        w_global_scale,
    )

    out = torch.zeros(
        (num_experts, max(masked_m), n), dtype=weights.dtype, device=aq.device
    )
    out = out.permute(1, 2, 0)  # requirement of kernel
    sf_vec_size = 16
    ab_dtype = "float4_e2m1fn"
    sf_dtype = "float8_e4m3fn"
    c_dtype = "bfloat16"
    alpha = 1.0 / (input_global_scale * w_global_scale).to(out.dtype).view(
        1, 1, num_experts
    )

    def get_cute_dtype(input: torch.Tensor) -> str:
        if input.dtype == torch.bfloat16:
            return "bfloat16"
        elif input.dtype == torch.float16:
            return "float16"
        elif input.dtype == torch.float32:
            return "float32"
        else:
            raise ValueError(f"Unsupported cute dtype {input.dtype}")

    cutedsl_gmm_masked(
        (aq, aq_sf),
        (bq, bq_sf),
        out,
        masked_m.to(aq.device),
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=alpha,
        alpha_dtype=get_cute_dtype(alpha),
    )

    return out
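
# Usage sketch (illustrative) for reading the wrapper's output: the kernel
# fills an (m, n, l) view with experts last, so callers permute back and
# trust only the first masked_m[i] rows of each expert's slice, e.g.
#
#     out_lmn = out_flashinfer.permute(2, 0, 1)  # (num_experts, max_m, n)
#     valid = out_lmn[i, : masked_m[i]]
#
# Unmasked rows may contain NaN (see the comparison loop in the last test).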
@pytest.mark.parametrize("bs, hidden_dim, inter_dim", [(2, 128, 256), (16, 128, 512)])
|
|
@pytest.mark.parametrize("topk", [1, 2, 4])
|
|
@torch.inference_mode()
|
|
def test_flashinfer_cutedsl_moe_masked(
|
|
bs: int, hidden_dim: int, inter_dim: int, topk: int
|
|
):
|
|
torch.manual_seed(42)
|
|
device = "cuda"
|
|
num_experts = 8
|
|
hidden_states = (
|
|
torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0
|
|
)
|
|
w1 = (
|
|
torch.randn(
|
|
num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device
|
|
)
|
|
/ 10.0
|
|
)
|
|
w2 = (
|
|
torch.randn(
|
|
num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device
|
|
)
|
|
/ 10.0
|
|
)
|
|
router_logits = torch.randn(bs, num_experts, dtype=torch.float32)
|
|
|
|
hidden_states_expanded = (
|
|
hidden_states.view(bs, -1, hidden_dim)
|
|
.repeat(1, topk, 1)
|
|
.reshape(-1, hidden_dim)
|
|
)
|
|
hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs(
|
|
hidden_states_expanded, router_logits, num_experts, topk
|
|
)
|
|
|
|
w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
|
|
w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
|
|
input_global_scale = torch.ones(
|
|
(num_experts,), dtype=torch.float32, device=hidden_states.device
|
|
)
|
|
|
|
w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
|
|
w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
|
|
a2_global_scale = torch.ones(
|
|
(num_experts,), dtype=torch.float32, device=hidden_states.device
|
|
) # assume intermediate scale is 1.0
|
|
|
|
w1_fp4, w1_blockscale = scaled_fp4_grouped_quantize(
|
|
w1,
|
|
torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
|
|
w1_global_scale,
|
|
)
|
|
w2_fp4, w2_blockscale = scaled_fp4_grouped_quantize(
|
|
w2,
|
|
torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
|
|
w2_global_scale,
|
|
)
|
|
|
|
w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
|
|
w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)
|
|
|
|
out = torch.empty_like(hidden_states_3d)
|
|
# Note: the 1st dim shouldn't be bs
|
|
wk = torch.empty(
|
|
num_experts,
|
|
hidden_states_3d.shape[1],
|
|
inter_dim * 2,
|
|
dtype=hidden_states_3d.dtype,
|
|
device=hidden_states.device,
|
|
)
|
|
flashinfer_cutedsl_moe_masked(
|
|
hidden_states_3d.to(hidden_states.device),
|
|
input_global_scale,
|
|
w1_fp4.permute(2, 0, 1),
|
|
w1_blockscale,
|
|
w1_alpha,
|
|
w2_fp4.permute(2, 0, 1),
|
|
a2_global_scale,
|
|
w2_blockscale,
|
|
w2_alpha,
|
|
masked_m.to(hidden_states.device),
|
|
wk,
|
|
out,
|
|
)
|
|
|
|
# reference
|
|
a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale)
|
|
a_in_dtype = dequantize_nvfp4_to_dtype(
|
|
a_fp4,
|
|
a_scale_interleaved,
|
|
input_global_scale,
|
|
dtype=hidden_states.dtype,
|
|
device=hidden_states.device,
|
|
block_size=16,
|
|
)
|
|
w1_d = torch.empty(
|
|
(num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype
|
|
)
|
|
w2_d = torch.empty(
|
|
(num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype
|
|
)
|
|
|
|
for idx in range(0, num_experts):
|
|
w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
|
|
w1[idx], w1_global_scale[idx]
|
|
)
|
|
w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
|
|
w2[idx], w2_global_scale[idx]
|
|
)
|
|
w1_d[idx] = dequantize_nvfp4_to_dtype(
|
|
w1_fp4_sliced,
|
|
w1_blockscale_sliced,
|
|
w1_global_scale[idx],
|
|
dtype=w1.dtype,
|
|
device=w1.device,
|
|
block_size=16,
|
|
)
|
|
w2_d[idx] = dequantize_nvfp4_to_dtype(
|
|
w2_fp4_sliced,
|
|
w2_blockscale_sliced,
|
|
w2_global_scale[idx],
|
|
dtype=w2.dtype,
|
|
device=w2.device,
|
|
block_size=16,
|
|
)
|
|
|
|
ref_output = torch_moe_nvfp4(
|
|
a_in_dtype,
|
|
w1_d,
|
|
w2_d,
|
|
topk,
|
|
routing_weights.to(a_in_dtype.device),
|
|
topk_idx.to(a_in_dtype.device),
|
|
)
|
|
out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype)
|
|
|
|
positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
|
|
rows, cols = positions[:, 0], positions[:, 1]
|
|
experts = topk_idx[rows, cols]
|
|
for i in range(num_experts):
|
|
mask = experts == i
|
|
if mask.any():
|
|
idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
|
|
r, c = rows[idx], cols[idx]
|
|
out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to(
|
|
out.device
|
|
).unsqueeze(-1)
|
|
torch.testing.assert_close(
|
|
out_weighted.cpu(), ref_output.cpu(), atol=2e-1, rtol=2e-1
|
|
)
|
|
|
|
|
|


@pytest.mark.parametrize(
    "bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
)
@torch.inference_mode()
def test_grouped_gemm_nt_masked(
    bs: int, hidden_dim: int, inter_dim: int, topk: int
) -> None:
    torch.manual_seed(42)
    B = bs
    D = hidden_dim
    N = inter_dim
    # CuteDSL grouped GEMM has issues when not all experts are active,
    # e.g. masked_m = [2, 3, 0, 0, 1], where experts 2 and 3 (0-indexed)
    # are inactive; see
    # https://github.com/flashinfer-ai/flashinfer/issues/1856.
    # Setting num_experts = bs keeps every expert active here.
    num_experts = bs
    hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
    weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
    router_logits = torch.randn(B, num_experts, dtype=torch.float32)

    hidden_states_expanded = (
        hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    )
    hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
        hidden_states_expanded, router_logits, num_experts, topk
    )

    a_amax = (
        hidden_states_3d.abs()
        .amax(dim=(1, 2))
        .to(torch.float32)
        .to(hidden_states.device)
    )
    b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
    a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
    b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
    out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
        hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
    )
    # reference
    out_ref = grouped_gemm_ref(
        hidden_states_expanded=hidden_states_expanded,
        hidden_states_3d=hidden_states_3d,
        weights=weights,
        topk_idx=topk_idx,
        masked_m=masked_m,
        B=B,
        topk=topk,
        num_experts=num_experts,
    )
    # Compare only the masked positions: CuteDSL may write NaN into the
    # unmasked positions.
    for i in range(num_experts):
        torch.testing.assert_close(
            out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
            out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
            atol=1e-1,
            rtol=1e-1,
        )
if __name__ == "__main__":
|
|
test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4)
|
|
test_grouped_gemm_nt_masked(16, 128, 512, 4)
|