From 2891603efdbd38ac7197d5f41b0e245fb6a53b82 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Date: Mon, 15 Sep 2025 22:05:12 -0400
Subject: [PATCH] [ROCm][Bugfix] Fix the case where there's bias (#24895)

Signed-off-by: Gregory Shtrasberg
---
 .../quantization/test_rocm_skinny_gemms.py    | 31 +++++++++++++++++++
 .../layers/quantization/utils/w8a8_utils.py   |  2 +-
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py
index 03d5d98739c5..a9b1c71ef071 100644
--- a/tests/kernels/quantization/test_rocm_skinny_gemms.py
+++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py
@@ -5,6 +5,8 @@ import torch
 
 import vllm._custom_ops as ops
 from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    rocm_per_tensor_w8a8_scaled_mm_impl)
 from vllm.platforms import current_platform
 
 DTYPES = [torch.bfloat16, torch.float16]
@@ -116,3 +118,32 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
                         current_platform.get_cu_count())
 
     assert torch.allclose(out, ref_out, rtol=0.01)
+
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.skipif(
+    not (current_platform.is_rocm() and current_platform.supports_fp8()),
+    reason="only test for rocm fp8")
+def test_rocm_per_tensor_w8a8_scaled_mm_impl(n, k, m, dtype, seed, use_bias):
+    torch.manual_seed(seed)
+
+    A = torch.rand(n, k, device="cuda")
+    B = torch.rand(m, k, device="cuda")
+
+    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
+    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
+
+    bias = torch.rand(1, m, dtype=dtype, device="cuda") if use_bias else None
+
+    output = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a,
+                                                 scale_b, bias)
+    ref_out = torch._scaled_mm(A,
+                               B.t(),
+                               out_dtype=dtype,
+                               scale_a=scale_a,
+                               scale_b=scale_b,
+                               bias=bias)
+    assert torch.allclose(output, ref_out, rtol=0.01)
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index e89a5e643b0e..8cda1789e6c9 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -179,7 +179,7 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
                                         bias: torch.Tensor) -> torch.Tensor:
     from vllm.platforms.rocm import on_mi3xx
     if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
-    ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
+    ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0 and bias is None:
         output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
                                current_platform.get_cu_count())
     else:
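
The change keeps the wvSplitKQ skinny-GEMM path for bias-free calls only, since
ops.wvSplitKQ takes no bias argument; any call that passes a bias now falls
through to the generic branch. Below is a minimal sketch of the resulting
dispatch, assuming the existing fallback branch calls torch._scaled_mm with the
bias (as the new test's reference computation does); it is illustrative, not a
copy of the surrounding w8a8_utils.py code:

    if (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx()
            and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0
            and bias is None):
        # Skinny-GEMM path: wvSplitKQ cannot apply a bias itself.
        output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
                               current_platform.get_cu_count())
    else:
        # Fallback path (assumed): torch._scaled_mm applies the bias directly.
        output = torch._scaled_mm(qinput,
                                  weight,
                                  out_dtype=out_dtype,
                                  scale_a=scale_a,
                                  scale_b=scale_b,
                                  bias=bias)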