From 2e090bd5df974949651ad439517e0da4e981b508 Mon Sep 17 00:00:00 2001
From: rasmith
Date: Thu, 12 Jun 2025 02:14:24 -0500
Subject: [PATCH] [AMD][Kernel][BugFix] fix test_rocm_compressed_tensors_w8a8
 for rocm (#19509)

Signed-off-by: Randall Smith
---
 vllm/_custom_ops.py                                      | 7 ++-----
 .../quantization/compressed_tensors/triton_scaled_mm.py  | 8 ++++----
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index d6bbfbc32886..fe5b386c4d25 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import contextlib
-import importlib
 from typing import TYPE_CHECKING, Optional, Union
 
 import torch
@@ -706,10 +705,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
 
     cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     if current_platform.is_rocm() or not cutlass_compatible_b:
-        triton_scaled_mm_module = importlib.import_module(
-            "vllm.model_executor.layers.quantization.compressed_tensors."
-            "triton_scaled_mm")
-        triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
+        from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa
+            triton_scaled_mm)
         return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
index 9bcf1aa2bc1c..d926b4c12db1 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
@@ -144,10 +144,10 @@ def triton_scaled_mm(input: torch.Tensor,
     scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
 
     assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
-    assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
-        [M, 1])
-    assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
-        [N, 1])
+    assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1
+                                      or scale_a.shape[0] == M)
+    assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1
+                                      or scale_b.shape[0] == N)
     assert out_dtype.is_floating_point
     assert bias is None or bias.is_floating_point()
     assert is_weak_contiguous(input)
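
A minimal standalone sketch (not part of the patch; M, N, and the tensors
below are made-up values) of the scale shapes the reworked asserts accept:
per-tensor [1, 1] or per-token [M, 1] for scale_a, and per-tensor [1, 1] or
per-channel [N, 1] for scale_b. The membership checks are equivalent to the
new assert conditions in triton_scaled_mm.

    import torch

    # Illustrative shapes only; any (1, 1) / (M, 1) choice for scale_a and
    # (1, 1) / (N, 1) choice for scale_b passes.
    M, N = 4, 8
    scale_a = torch.rand(M, 1)  # per-token activation scales
    scale_b = torch.rand(1, 1)  # single per-tensor weight scale

    # Equivalent to the new asserts: second dim is 1, first dim is 1 or M/N.
    assert scale_a.shape[1] == 1 and scale_a.shape[0] in (1, M)
    assert scale_b.shape[1] == 1 and scale_b.shape[0] in (1, N)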