# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the CUTLASS-based W4A8 grouped GEMM kernel and the full MoE layer.
"""

import random
from dataclasses import dataclass

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_rows,
    quantize_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
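# Compute capability >= 9 means SM90 (Hopper) or newer; the W4A8 grouped GEMM
# tests below are skipped on older GPUs via this flag.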


def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
    # Saturating cast: clamp to the finite e4m3 range before converting to fp8.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)


def cutlass_quantize(
    atype: torch.dtype,
    w: torch.Tensor,
    wtype: ScalarType,
    stype: torch.dtype | None,
    group_size: int | None,
    zero_points: bool = False,
):
    """
    Quantize weights into W4 and compute reference dequantized weights.

    Encoding/reordering of weights and packing of scales are deferred
    until after all experts are combined.
    """
    assert wtype.is_integer(), "TODO: support floating point weights"

    w_ref, w_q, w_s, w_zp = quantize_weights(
        w, wtype, group_size=group_size, zero_points=zero_points
    )

    # Since scales are later cast to fp8, recompute w_ref in atype here.
    w_ref = (
        w_q.to(torch.float32)
        * w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0)
    ).to(atype)

    # Bit mask prevents sign extension of int4 when packing.
    w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
    # Make weights row-major (N, K).
    w_q = w_q.t().contiguous()

    return w_ref, w_q, w_s.to(atype), w_zp


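# Rough sketch of how these helpers compose (mirroring make_moe_test_setup below):
#   per expert : w_ref, w_q, w_s, _ = cutlass_quantize(atype, w, wtype, stype, GROUP_SIZE)
#   once       : w_q_packed, w_s_packed, packed_layout = cutlass_preprocess(w_qs, w_ss)
#   kernel call: ops.cutlass_w4a8_moe_mm(out, a, w_q_packed, ..., w_s_packed, GROUP_SIZE, ...)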
def cutlass_preprocess(
    w_q_experts: list[torch.Tensor], w_s_experts: list[torch.Tensor]
):
    """
    Reorder/encode expert weights and pack scales.

    Returns:
        w_q_packed: Packed/encoded int4 weights for all experts.
        w_s_packed: Packed fp8 scales for all experts.
        packed_layout: Layout/stride metadata for grouped GEMM.
    """
    w_s_packed = ops.cutlass_pack_scale_fp8(torch.stack(w_s_experts))
    w_q_packed, packed_layout = ops.cutlass_encode_and_reorder_int4b_grouped(
        torch.stack(w_q_experts)
    )  # expects a 3-D stacked tensor (one slice per expert)
    return w_q_packed, w_s_packed, packed_layout


GROUP_SIZE = 128
# (num_experts, N, K)
TEST_SHAPES = [
    (8, 512, 2048),
    (8, 2048, 2048),
    (64, 512, 1024),
    (64, 2048, 2048),
    (4, 2048, 768),
    (8, 768, 2048),
    (64, 1536, 2048),
    (128, 8192, 4096),  # tests int32 overflow
]
ALIGNMENT = 16  # M alignment required by torch._scaled_mm for the reference check


@dataclass
class MoETestSetup:
    num_experts: int
    K: int
    N: int
    Ms: list[int]
    M_full: int
    a: torch.Tensor
    a_ref: torch.Tensor
    a_strides: torch.Tensor
    out: torch.Tensor
    c_strides: torch.Tensor
    per_tok_scales: torch.Tensor
    per_chan_scales: torch.Tensor
    w_refs: list[torch.Tensor]
    w_q_packed: torch.Tensor
    w_s_packed: torch.Tensor
    problem_sizes: torch.Tensor
    expert_offsets: torch.Tensor
    b_strides: torch.Tensor
    group_scale_strides: torch.Tensor


def make_moe_test_setup(
    num_experts: int,
    K: int,
    N: int,
    *,
    alignment: int = ALIGNMENT,
    max_blocks: int = 64,
    device: str = "cuda",
    random_zero: bool = False,
) -> MoETestSetup:
    """Create a full set of tensors for testing cutlass_w4a8_moe_mm."""

    assert K % GROUP_SIZE == 0
    # Token counts per expert (multiples of `alignment`).
    Ms = [alignment * random.randint(1, max_blocks) for _ in range(num_experts)]

    # Zero out the token count for a few randomly chosen experts.
    if random_zero and num_experts > 1:
        num_zero = max(1, num_experts // 8)
        zero_indices = random.sample(range(num_experts), k=num_zero)
        for idx in zero_indices:
            Ms[idx] = 0

    M_full = sum(Ms)
    assert M_full > 0

    # Activations.
    a = to_fp8(torch.randn((M_full, K), device=device))
    a_ref = a.to(torch.float32)
    a_strides = torch.full((num_experts,), K, dtype=torch.int64, device=device)
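    # All experts share one contiguous (M_full, K) activation buffer, so every
    # expert's slice has row stride K; the output rows below likewise use stride N.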

    # Output buffer.
    out = torch.empty((M_full, N), dtype=torch.bfloat16, device=device)
    c_strides = torch.full((num_experts,), N, dtype=torch.int64, device=device)

    # Channel/token scales.
    per_tok_scales = torch.randn((M_full, 1), dtype=torch.float32, device=device)
    per_chan_scales = torch.randn(
        (num_experts, N, 1), dtype=torch.float32, device=device
    )

    # Expert weights and scales.
    wtype = scalar_types.int4
    atype = stype = torch.float8_e4m3fn
    w_refs, w_qs, w_ss = [], [], []
    for _ in range(num_experts):
        b = to_fp8(torch.randn((K, N), device=device))
        w_ref, w_q, w_s, _ = cutlass_quantize(
            atype, b.to(torch.float16), wtype, stype, GROUP_SIZE, zero_points=False
        )
        w_refs.append(w_ref)
        w_qs.append(w_q)
        w_ss.append(w_s)

    w_q_packed, w_s_packed, packed_layout = cutlass_preprocess(w_qs, w_ss)

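    # Per-expert GEMM shapes in (N, M, K) order; experts with zero tokens simply
    # get M = 0 here.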
    problem_sizes = torch.tensor(
        [[N, M, K] for M in Ms], dtype=torch.int32, device=device
    )

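    # expert_offsets is an exclusive prefix sum of Ms, e.g. Ms = [32, 0, 16]
    # gives expert_offsets = [0, 32, 32]; expert i owns rows
    # [expert_offsets[i], expert_offsets[i] + Ms[i]) of the stacked activations.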
    expert_offsets = torch.cat(
        [
            torch.tensor([0], dtype=torch.int64),
            torch.cumsum(torch.tensor(Ms, dtype=torch.int64), dim=0)[:-1],
        ]
    ).to(device=device)

    # B strides and group scale strides.
    b_strides = packed_layout
    group_scale_strides = torch.zeros(
        (num_experts, 2), dtype=torch.int64, device=device
    )
    group_scale_strides[:, 0] = N
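    # Each expert's group-scale stride row ends up as (N, 0); presumably the packed
    # fp8 scales advance by N entries per K-group and the second stride is unused.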

    return MoETestSetup(
        num_experts=num_experts,
        K=K,
        N=N,
        Ms=Ms,
        M_full=M_full,
        a=a,
        a_ref=a_ref,
        a_strides=a_strides,
        out=out,
        c_strides=c_strides,
        per_tok_scales=per_tok_scales,
        per_chan_scales=per_chan_scales,
        w_refs=w_refs,
        w_q_packed=w_q_packed,
        w_s_packed=w_s_packed,
        problem_sizes=problem_sizes,
        expert_offsets=expert_offsets,
        b_strides=b_strides,
        group_scale_strides=group_scale_strides,
    )


def compute_moe_reference_output(setup: MoETestSetup) -> torch.Tensor:
    """Compute reference output using torch._scaled_mm per expert."""
    out_ref = torch.empty_like(setup.out)

    ends = torch.cumsum(torch.tensor(setup.Ms), 0).tolist()
    starts = setup.expert_offsets.cpu().tolist()

    for i in range(setup.num_experts):
        start, end = starts[i], ends[i]
        if start == end:
            continue

        out_ref_i = torch._scaled_mm(
            setup.a_ref[start:end].to(torch.float8_e4m3fn),
            # .t().contiguous().t() yields the column-major layout that
            # torch._scaled_mm expects for its second operand.
            setup.w_refs[i].to(torch.float8_e4m3fn).t().contiguous().t(),
            setup.per_tok_scales[start:end],  # (M, 1)
            setup.per_chan_scales[i].reshape(1, -1),  # (1, N)
            out_dtype=torch.bfloat16,
            use_fast_accum=True,
        )
        out_ref[start:end] = out_ref_i

    return out_ref


@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU,
    reason="W4A8 Grouped GEMM is not supported on this GPU type.",
)
@pytest.mark.parametrize("shape", TEST_SHAPES)
@pytest.mark.parametrize("random_zero", [True, False])
def test_cutlass_w4a8_moe_mm_end_to_end(shape, random_zero):
    num_experts, N, K = shape
    current_platform.seed_everything(42)
    setup = make_moe_test_setup(
        num_experts=num_experts, K=K, N=N, max_blocks=64, random_zero=random_zero
    )

    ops.cutlass_w4a8_moe_mm(
        setup.out,
        setup.a,
        setup.w_q_packed,
        setup.per_tok_scales,
        setup.per_chan_scales,
        setup.w_s_packed,
        GROUP_SIZE,
        setup.expert_offsets,
        setup.problem_sizes,
        setup.a_strides,
        setup.b_strides,
        setup.c_strides,
        setup.group_scale_strides,
    )
    torch.cuda.synchronize()

    out_ref = compute_moe_reference_output(setup)
    torch.testing.assert_close(setup.out, out_ref, rtol=1e-2, atol=1e-2)


class W4A8MoELayer(torch.nn.Module):
    """
    Minimal wrapper module for exercising the grouped GEMM under CUDA graph capture.
    """

    def __init__(self, setup: MoETestSetup):
        super().__init__()
        self.setup = setup

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        s = self.setup
        ops.cutlass_w4a8_moe_mm(
            s.out,
            a,
            s.w_q_packed,
            s.per_tok_scales,
            s.per_chan_scales,
            s.w_s_packed,
            GROUP_SIZE,
            s.expert_offsets,
            s.problem_sizes,
            s.a_strides,
            s.b_strides,
            s.c_strides,
            s.group_scale_strides,
        )
        return s.out


@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU,
    reason="W4A8 Grouped GEMM is not supported on this GPU type.",
)
def test_cutlass_w4a8_moe_mm_cuda_graph():
    current_platform.seed_everything(42)
    # Fixed config for CUDA graph test (single parameter point).
    num_experts = 8
    K = 512
    N = 2048

    setup = make_moe_test_setup(
        num_experts=num_experts,
        K=K,
        N=N,
        max_blocks=32,
    )

    # Construct model that calls the grouped GEMM kernel.
    model = W4A8MoELayer(setup)

    # Build reference output once.
    out_ref = compute_moe_reference_output(setup)

    # Capture and run the model in a CUDA graph.
    a_static = setup.a.clone()  # static input tensor for graph replay

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            out_static = model(a_static)

    out_static.zero_()
    g.replay()
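    # Replay re-executes the captured kernel launch against the same static
    # buffers, so the replayed output should match the eager reference above.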

    torch.testing.assert_close(out_static, out_ref, rtol=1e-2, atol=1e-2)