[ROCm][Bugfix] Fix the case where there's bias (#24895)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Authored by Gregory Shtrasberg on 2025-09-15 22:05:12 -04:00; committed via GitHub
parent de2cc3d867
commit 2891603efd
2 changed files with 32 additions and 1 deletion
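In short: when a bias tensor was supplied, the ROCm per-tensor FP8 path could still dispatch to the skinny-GEMM kernel ops.wvSplitKQ, which takes no bias argument, so the bias was never applied to the output. The fix gates that fast path on bias is None (leaving biased GEMMs to the general path in the else branch) and adds a parametrized regression test covering both the bias and no-bias cases.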

File 1: tests/kernels/quantization/test_rocm_skinny_gemms.py

@@ -5,6 +5,8 @@ import torch
 import vllm._custom_ops as ops
 from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    rocm_per_tensor_w8a8_scaled_mm_impl)
 from vllm.platforms import current_platform
 
 DTYPES = [torch.bfloat16, torch.float16]
@@ -116,3 +118,32 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
                            current_platform.get_cu_count())
 
     assert torch.allclose(out, ref_out, rtol=0.01)
+
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.skipif(
+    not (current_platform.is_rocm() and current_platform.supports_fp8()),
+    reason="only test for rocm fp8")
+def test_rocm_per_tensor_w8a8_scaled_mm_impl(n, k, m, dtype, seed, use_bias):
+    torch.manual_seed(seed)
+    A = torch.rand(n, k, device="cuda")
+    B = torch.rand(m, k, device="cuda")
+
+    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
+    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
+
+    bias = torch.rand(1, m, dtype=dtype, device="cuda") if use_bias else None
+    output = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a,
+                                                 scale_b, bias)
+
+    ref_out = torch._scaled_mm(A,
+                               B.t(),
+                               out_dtype=dtype,
+                               scale_a=scale_a,
+                               scale_b=scale_b,
+                               bias=bias)
+
+    assert torch.allclose(output, ref_out, rtol=0.01)
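For a quick standalone check of the newly covered case, the test body can be exercised directly. A minimal sketch, assuming a ROCm GPU with FP8 support and an installed vLLM; the shapes n=1, k=4096, m=4096 are illustrative, chosen so that (absent the new bias is None guard) the skinny-GEMM path would have been selected:

    # Standalone sketch of the biased case (assumes ROCm FP8 hardware).
    import torch
    from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        rocm_per_tensor_w8a8_scaled_mm_impl)

    torch.manual_seed(0)
    n, k, m, dtype = 1, 4096, 4096, torch.bfloat16  # illustrative shapes

    A, scale_a = ref_dynamic_per_tensor_fp8_quant(torch.rand(n, k, device="cuda"))
    B, scale_b = ref_dynamic_per_tensor_fp8_quant(torch.rand(m, k, device="cuda"))
    bias = torch.rand(1, m, dtype=dtype, device="cuda")  # the previously broken case

    out = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a, scale_b, bias)
    ref = torch._scaled_mm(A, B.t(), out_dtype=dtype,
                           scale_a=scale_a, scale_b=scale_b, bias=bias)
    assert torch.allclose(out, ref, rtol=0.01)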

File 2: vllm/model_executor/layers/quantization/utils/w8a8_utils.py

@@ -179,7 +179,7 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
                                         bias: torch.Tensor) -> torch.Tensor:
     from vllm.platforms.rocm import on_mi3xx
     if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
-    ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
+    ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0 and bias is None:
         output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
                                current_platform.get_cu_count())
     else:
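Since the hunk shows only the changed guard, here is a sketch of the full dispatch after the fix. Only the if condition and the wvSplitKQ call are taken verbatim from the diff; the else branch is an assumption, inferred from the reference computation used by the new test:

    # Sketch of the dispatch after the fix (else branch assumed, not shown in the diff).
    def rocm_per_tensor_w8a8_scaled_mm_impl(qinput, weight, out_dtype,
                                            scale_a, scale_b, bias):
        from vllm.platforms.rocm import on_mi3xx

        # Fast path: the skinny-GEMM kernel wvSplitKQ takes no bias argument,
        # so it is now used only when there is no bias to apply.
        if (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx()
                and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0
                and bias is None):
            output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a,
                                   scale_b, current_platform.get_cu_count())
        else:
            # Assumed fallback, consistent with the test's reference path:
            # PyTorch's scaled matmul applies the bias itself.
            output = torch._scaled_mm(qinput, weight, out_dtype=out_dtype,
                                      scale_a=scale_a, scale_b=scale_b, bias=bias)
        return output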