# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the CUTLASS W4A8 kernel.

Run `pytest tests/kernels/quantization/test_cutlass_w4a8.py`.
"""

from dataclasses import dataclass

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_rows,
    quantize_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

# TODO: in a future PR, refactor this and `is_quant_method_supported` in the
# kernel unit tests into a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods,
# an assumption which is breaking down as quantization methods can have
# multiple kernels and some kernels support multiple quantization methods.
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9

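# (M, N, K) problem sizes: activations are (M, K), weights are (K, N).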
MNK_SHAPES = [
    (1, 128, 128),
    (1, 512, 1024),
    (1, 4096, 4096),
    (1, 8192, 28672),
    (13, 8192, 4096),
    (26, 4096, 8192),
    (64, 4096, 4096),
    (64, 8192, 28672),
    (257, 128, 4096),
    (257, 4096, 4096),
    (1024, 4096, 8192),
    (1024, 8192, 4096),
]

# TODO(czhu): get supported schedules from fn
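# NOTE: each schedule name appears to encode a CTA tile shape and a cluster
# shape (e.g. "128x256_2x1x1"); this is an assumption -- the kernel's dispatch
# code is the authoritative list of supported schedules.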
SCHEDULES = [
    "128x16_1x1x1",
    "256x16_1x1x1",
    "128x32_1x1x1",
    "256x32_1x1x1",
    "128x64_1x1x1",
    "256x64_1x1x1",
    "128x128_1x1x1",
    "256x128_1x1x1",
    "128x256_1x1x1",
    "128x256_2x1x1",
]


@dataclass
class TypeConfig:
    act_type: torch.dtype
    weight_type: ScalarType
    output_type: torch.dtype | None
    group_scale_type: torch.dtype | None
    channel_scale_type: torch.dtype | None
    token_scale_type: torch.dtype | None


@dataclass
class Tensors:
    w_ref: torch.Tensor
    a_ref: torch.Tensor
    a: torch.Tensor
    w_q: torch.Tensor
    w_g_s: torch.Tensor
    w_ch_s: torch.Tensor
    w_tok_s: torch.Tensor


# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
#  Ch Scales Type, Tok Scales Type)
TestTypeTuple = tuple[
    list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool
]
TEST_TYPES = [
    *(
        TypeConfig(
            act_type=torch.float8_e4m3fn,
            weight_type=w_type,
            output_type=o_type,
            group_scale_type=torch.float8_e4m3fn,
            channel_scale_type=torch.float32,
            token_scale_type=torch.float32,
        )
        for w_type in [scalar_types.int4]
        # TODO(czhu): fp16 out type
        for o_type in [torch.bfloat16]
    ),
]


# For testing quantized linear kernels
def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)


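# Quantize w to int4 with per-group scales, cast the scales to the activation
# dtype (fp8), and pack/reorder both into the layout expected by the CUTLASS
# W4A8 kernel. Also returns the dequantized reference weight (w_ref).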
def cutlass_quantize_and_pack(
    atype: torch.dtype,
    w: torch.Tensor,
    wtype: ScalarType,
    stype: torch.dtype | None,
    group_size: int | None,
    zero_points: bool = False,
):
    assert wtype.is_integer(), "TODO: support floating point weights"

    w_ref, w_q, w_s, w_zp = quantize_weights(
        w, wtype, group_size=group_size, zero_points=zero_points
    )

    # since scales are cast to fp8, we need to compute w_ref this way
    w_ref = (
        (w_q).to(torch.float32)
        * w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0)
    ).to(atype)

    # bit mask prevents sign extending int4 when packing
    w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
    w_q = w_q.t().contiguous().t()  # convert to col major

    w_q_packed = ops.cutlass_encode_and_reorder_int4b(w_q)
    w_s_packed = ops.cutlass_pack_scale_fp8(w_s.to(atype))

    return w_ref, w_q_packed, w_s_packed, w_zp


def create_test_tensors(
    shape: tuple[int, int, int], types: TypeConfig, group_size: int | None
) -> Tensors:
    m, n, k = shape

    print(
        "create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size
    )

    a = to_fp8(torch.randn((m, k), device="cuda"))
    w = to_fp8(torch.randn((k, n), device="cuda"))

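    # quantize_weights operates on a higher-precision tensor, so 1-byte (fp8)
    # weights are upcast to fp16 before quantization.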
    if types.group_scale_type is not None:
        w = w.to(types.group_scale_type)
    if w.dtype.itemsize == 1:
        w = w.to(torch.float16)

    w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
        a.dtype, w, types.weight_type, types.group_scale_type, group_size, False
    )

    a_ref = a.to(torch.float32)
    w_ref = w_ref.to(torch.float32)

    # for the practical use case we need per-tok scales for fp8 activations
    w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type)
    # weights are already per-group quantized, use placeholder here
    w_ch_s = torch.ones((n,), device="cuda", dtype=types.channel_scale_type)

    return Tensors(
        w_ref=w_ref,
        a_ref=a_ref,
        a=a,
        w_q=w_q_packed,
        w_g_s=w_s,
        w_ch_s=w_ch_s,
        w_tok_s=w_tok_s,
    )


def mm_test_helper(
    types: TypeConfig,
    tensors: Tensors,
    group_size: int | None = None,
    schedule: str | None = None,
):
    # CUTLASS upstream uses fp8 with fastaccum as reference
    # https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406
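    # _scaled_mm with (M, 1) and (1, N) scale tensors applies the per-token and
    # per-channel scales directly, so the reference includes the scaling that
    # cutlass_w4a8_mm is expected to fuse.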
    output_ref = torch._scaled_mm(
        tensors.a_ref.to(types.act_type),
        tensors.w_ref.to(types.act_type).t().contiguous().t(),  # col major
        tensors.w_tok_s.unsqueeze(1),
        tensors.w_ch_s.unsqueeze(0),
        out_dtype=types.output_type,
        use_fast_accum=True,
    )

    output = ops.cutlass_w4a8_mm(
        a=tensors.a,
        b_q=tensors.w_q,
        b_group_scales=tensors.w_g_s,
        b_group_size=group_size,
        b_channel_scales=tensors.w_ch_s,
        a_token_scales=tensors.w_tok_s,
    )

    print(output)
    print(output_ref)

    torch.testing.assert_close(
        output, output_ref.to(output.dtype), rtol=1e-3, atol=1e-3
    )


@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
@pytest.mark.parametrize("schedule", SCHEDULES)
def test_cutlass_w4a8(shape, types: TypeConfig, schedule):
    group_sizes = [128]
    for group_size in group_sizes:
        tensors = create_test_tensors(shape, types, group_size)
        mm_test_helper(types, tensors, group_size, schedule)


# Test to make sure cuda graphs work
class W4A8Layer(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.kwargs = kwargs

    def forward(self, a):
        return ops.cutlass_w4a8_mm(a=a, **self.kwargs)


@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
)
def test_w4a8_cuda_graph():
    m, n, k = 512, 4096, 4096

    a = to_fp8(torch.randn((m, k), device="cuda"))
    b = to_fp8(torch.randn((k, n), device="cuda"))

    wtype = scalar_types.int4
    stype = torch.float8_e4m3fn
    group_size = 128
    zero_points = False

    w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
        a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points
    )

    w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32)
    w_ch_s = torch.ones((n,), device="cuda", dtype=torch.float32)

    # Construct a trivial model with a single layer that calls the kernel
    model = W4A8Layer(
        b_q=w_q_packed,
        b_group_scales=w_s,
        b_group_size=group_size,
        b_channel_scales=w_ch_s,
        a_token_scales=w_tok_s,
    )

    output_ref = torch._scaled_mm(
        a,
        w_ref.to(a.dtype).t().contiguous().t(),  # col major
        w_tok_s.unsqueeze(1),
        w_ch_s.unsqueeze(0),
        out_dtype=torch.bfloat16,
        use_fast_accum=True,
    )

    # Run the model with a cuda graph
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            output = model(a)

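    # Zeroing the output before replay ensures the result really comes from the
    # graph replay, not from the value left behind by the capture run.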
    output.zero_()
    g.replay()

    torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3)