# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm
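
# FlashInfer's FP8 GEMM kernels need compute capability 10.0 (Blackwell-class
# hardware) or newer, so skip the whole module on older devices.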
if not current_platform.has_device_capability(100):
    pytest.skip(
        reason="Flashinfer FP8 gemms require compute capability of 10.0 or above.",
        allow_module_level=True,
    )

DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
# Extra shapes whose dimensions do not line up with typical tile sizes,
# intended to exercise any padding path in the kernel.
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)

SEEDS = [42]
CUDA_DEVICES = ["cuda:0"]
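

# Sweep output dtype, GEMM shape, bias on/off, seed, device, and autotuning on/off.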
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("autotune", [False, True])
@torch.inference_mode()
def test_flashinfer_fp8_gemm(
    dtype: torch.dtype,
    shape: tuple[int, int, int],
    use_bias: bool,
    seed: int,
    device: str,
    autotune: bool,
) -> None:
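    """Compare flashinfer_scaled_fp8_mm against a dequantized float32 reference GEMM."""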
    current_platform.seed_everything(seed)
    m, n, k = shape
    # Activation-like a is (m, k); weight-like b is (n, k) and scaled down by k.
    a = torch.randn((m, k), dtype=dtype, device=device)
    b = torch.randn((n, k), dtype=dtype, device=device) / k
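
    # scaled_fp8_quant here does dynamic per-tensor quantization, returning the FP8
    # tensor together with the float32 scale that was applied.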
    a_fp8, a_scale = ops.scaled_fp8_quant(a)
    b_fp8, b_scale = ops.scaled_fp8_quant(b)

    # Reference result: dequantize both operands back to float32 and run a plain matmul.
    expected_out = torch.mm(
        a_scale * a_fp8.to(dtype=torch.float32),
        b_scale * b_fp8.to(dtype=torch.float32).t(),
    ).to(dtype=dtype)

    if use_bias:
        bias = torch.randn((n,), dtype=dtype, device=device)
        expected_out = expected_out + bias
    else:
        bias = None

    import flashinfer

    with flashinfer.autotune(autotune):
        out = flashinfer_scaled_fp8_mm(
            a_fp8,
            b_fp8.t(),
            a_scale,
            b_scale,
            dtype,
            bias=bias,
        )

    # FP8 quantization is lossy, so compare against the reference with loose tolerances.
    torch.testing.assert_close(out, expected_out, atol=1e-2, rtol=1e-2)
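
# A typical invocation on a supported GPU (the file path below is illustrative):
#   pytest -q tests/kernels/test_flashinfer_fp8_gemm.py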