[AMD][Kernel][BugFix] fix test_rocm_compressed_tensors_w8a8 for rocm (#19509)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Author: rasmith
Date: 2025-06-12 02:14:24 -05:00
Committed by: GitHub
Parent: 1b0b065eb5
Commit: 2e090bd5df
2 changed files with 6 additions and 9 deletions


@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import contextlib
-import importlib
 from typing import TYPE_CHECKING, Optional, Union

 import torch
@@ -706,10 +705,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
     cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     if current_platform.is_rocm() or not cutlass_compatible_b:
-        triton_scaled_mm_module = importlib.import_module(
-            "vllm.model_executor.layers.quantization.compressed_tensors."
-            "triton_scaled_mm")
-        triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
+        from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa
+            triton_scaled_mm)
         return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
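
For context on the branch this hunk touches: only the import style changes here (a direct `from ... import` instead of `importlib.import_module`); the fallback dispatch itself is untouched. The sketch below is an illustrative, standalone approximation of that dispatch and of the scaled-matmul math, using hypothetical helper names (`should_use_triton_fallback`, `reference_scaled_mm`) that are not vLLM APIs.

```python
# Illustrative sketch only -- not vLLM code.
from typing import Optional

import torch


def should_use_triton_fallback(b: torch.Tensor, is_rocm: bool) -> bool:
    # The CUTLASS kernel needs both dims of the (K, N) weight to be multiples
    # of 16; on ROCm the Triton kernel is always used instead.
    cutlass_compatible_b = b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    return is_rocm or not cutlass_compatible_b


def reference_scaled_mm(a: torch.Tensor,
                        b: torch.Tensor,
                        scale_a: torch.Tensor,
                        scale_b: torch.Tensor,
                        out_dtype: torch.dtype,
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Dequantize-then-matmul reference: a is (M, K), b is (K, N),
    # scale_a is (M, 1) or (1, 1), scale_b is (N, 1) or (1, 1).
    out = (scale_a * a.to(torch.float32)) @ (scale_b.t() * b.to(torch.float32))
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)
```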


@@ -144,10 +144,10 @@ def triton_scaled_mm(input: torch.Tensor,
     scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b

     assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
-    assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
-        [M, 1])
-    assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
-        [N, 1])
+    assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1
+                                      or scale_a.shape[0] == M)
+    assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1
+                                      or scale_b.shape[0] == N)
     assert out_dtype.is_floating_point
     assert bias is None or bias.is_floating_point()
     assert is_weak_contiguous(input)
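
The reworked assertions accept the same per-tensor ([1, 1]) and per-row ([M, 1] / [N, 1]) scale layouts as before, but compare the dimensions directly rather than building `torch.Size` objects for an exact equality test. A minimal standalone illustration of what the new check admits (the helper name `scale_shape_ok` is hypothetical, not vLLM code):

```python
import torch


def scale_shape_ok(scale: torch.Tensor, rows: int) -> bool:
    # Mirrors the new assert: exactly one column, and either a single row
    # (per-tensor scale) or one row per token/channel (per-row scale).
    return scale.shape[1] == 1 and scale.shape[0] in (1, rows)


M, N = 8, 4
assert scale_shape_ok(torch.ones(1, 1), M)      # per-tensor activation scale
assert scale_shape_ok(torch.ones(M, 1), M)      # per-token activation scale
assert scale_shape_ok(torch.ones(N, 1), N)      # per-channel weight scale
assert not scale_shape_ok(torch.ones(M, N), M)  # full scale matrix is rejected
```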