mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 01:04:57 +08:00
[ROCm][Bugfix] Fix the case where there's bias (#24895)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
parent
de2cc3d867
commit
2891603efd
@ -5,6 +5,8 @@ import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
rocm_per_tensor_w8a8_scaled_mm_impl)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float16]
|
||||
@ -116,3 +118,32 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
|
||||
current_platform.get_cu_count())
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
not (current_platform.is_rocm() and current_platform.supports_fp8()),
|
||||
reason="only test for rocm fp8")
|
||||
def test_rocm_per_tensor_w8a8_scaled_mm_impl(n, k, m, dtype, seed, use_bias):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
A = torch.rand(n, k, device="cuda")
|
||||
B = torch.rand(m, k, device="cuda")
|
||||
|
||||
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
|
||||
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
|
||||
|
||||
bias = torch.rand(1, m, dtype=dtype, device="cuda") if use_bias else None
|
||||
|
||||
output = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a,
|
||||
scale_b, bias)
|
||||
ref_out = torch._scaled_mm(A,
|
||||
B.t(),
|
||||
out_dtype=dtype,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=bias)
|
||||
assert torch.allclose(output, ref_out, rtol=0.01)
|
||||
|
||||
@ -179,7 +179,7 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
|
||||
bias: torch.Tensor) -> torch.Tensor:
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
|
||||
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
|
||||
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0 and bias is None:
|
||||
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
|
||||
current_platform.get_cu_count())
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user