vllm/tests/kernels/quantization/test_flashinfer_scaled_mm.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm

if not current_platform.has_device_capability(100):
    pytest.skip(
        reason="FlashInfer FP8 GEMMs require compute capability 10.0 or above.",
        allow_module_level=True,
    )

DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)
SEEDS = [42]
CUDA_DEVICES = ["cuda:0"]


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("autotune", [False, True])
@torch.inference_mode()
def test_flashinfer_fp8_gemm(
    dtype: torch.dtype,
    shape: tuple[int, int, int],
    use_bias: bool,
    seed: int,
    device: str,
    autotune: bool,
) -> None:
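    """Check flashinfer_scaled_fp8_mm against a dequantized torch.mm reference."""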
    current_platform.seed_everything(seed)
    m, n, k = shape
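    # Random inputs; B is scaled by 1/k to keep the accumulated dot products
    # small after per-tensor fp8 quantization.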
    a = torch.randn((m, k), dtype=dtype, device=device)
    b = torch.randn((n, k), dtype=dtype, device=device) / k
    a_fp8, a_scale = ops.scaled_fp8_quant(a)
    b_fp8, b_scale = ops.scaled_fp8_quant(b)
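
    # Reference result: dequantize both operands to fp32, run a plain matmul,
    # then cast to the requested output dtype.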
    expected_out = torch.mm(
        a_scale * a_fp8.to(dtype=torch.float32),
        b_scale * b_fp8.to(dtype=torch.float32).t(),
    ).to(dtype=dtype)

    if use_bias:
        bias = torch.randn((n,), dtype=dtype, device=device)
        expected_out = expected_out + bias
    else:
        bias = None
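
    # flashinfer is imported lazily, inside the test body, so collecting this
    # module does not require flashinfer to be installed.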
    import flashinfer

    with flashinfer.autotune(autotune):
        out = flashinfer_scaled_fp8_mm(
            a_fp8,
            b_fp8.t(),
            a_scale,
            b_scale,
            dtype,
            bias=bias,
        )

    torch.testing.assert_close(out, expected_out, atol=1e-2, rtol=1e-2)
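
# To exercise this test locally (assuming a GPU with compute capability >= 10.0
# and flashinfer installed), something like the following should work from the
# repository root:
#   pytest tests/kernels/quantization/test_flashinfer_scaled_mm.py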