# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the FlashInfer FP8 scaled-matmul path against a dequantized
torch.mm reference."""

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm

if not current_platform.has_device_capability(100):
    pytest.skip(
        reason=
        "Flashinfer FP8 gemms require compute capability of 10.0 or above.",
        allow_module_level=True,
    )

DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
# Additional shapes that exercise padding.
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)

SEEDS = [42]
CUDA_DEVICES = ["cuda:0"]


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("autotune", [False, True])
@torch.inference_mode()
def test_flashinfer_fp8_gemm(
    dtype: torch.dtype,
    shape: tuple[int, int, int],
    use_bias: bool,
    seed: int,
    device: str,
    autotune: bool,
) -> None:
    current_platform.seed_everything(seed)
    m, n, k = shape
    a = torch.randn((m, k), dtype=dtype, device=device)
    b = torch.randn((n, k), dtype=dtype, device=device) / k

    # Quantize both operands to FP8 with per-tensor scales.
    a_fp8, a_scale = ops.scaled_fp8_quant(a)
    b_fp8, b_scale = ops.scaled_fp8_quant(b)

    # Reference: dequantize to fp32 and run a plain matmul.
    expected_out = torch.mm(
        a_scale * a_fp8.to(dtype=torch.float32),
        b_scale * b_fp8.to(dtype=torch.float32).t(),
    ).to(dtype=dtype)

    if use_bias:
        bias = torch.randn((n,), dtype=dtype, device=device)
        expected_out = expected_out + bias
    else:
        bias = None

    # Imported lazily so that test collection does not require flashinfer.
    import flashinfer

    with flashinfer.autotune(autotune):
        out = flashinfer_scaled_fp8_mm(
            a_fp8,
            b_fp8.t(),
            a_scale,
            b_scale,
            dtype,
            bias=bias,
        )

    torch.testing.assert_close(out, expected_out, atol=1e-2, rtol=1e-2)