mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 17:45:19 +08:00
[AMD][Kernel][BugFix] fix test_rocm_compressed_tensors_w8a8 for rocm (#19509)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
parent
1b0b065eb5
commit
2e090bd5df
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -706,10 +705,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
|
||||
|
||||
cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
|
||||
if current_platform.is_rocm() or not cutlass_compatible_b:
|
||||
triton_scaled_mm_module = importlib.import_module(
|
||||
"vllm.model_executor.layers.quantization.compressed_tensors."
|
||||
"triton_scaled_mm")
|
||||
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa
|
||||
triton_scaled_mm)
|
||||
return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
||||
|
||||
@ -144,10 +144,10 @@ def triton_scaled_mm(input: torch.Tensor,
|
||||
scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
|
||||
|
||||
assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
|
||||
assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
|
||||
[M, 1])
|
||||
assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
|
||||
[N, 1])
|
||||
assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1
|
||||
or scale_a.shape[0] == M)
|
||||
assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1
|
||||
or scale_b.shape[0] == N)
|
||||
assert out_dtype.is_floating_point
|
||||
assert bias is None or bias.is_floating_point()
|
||||
assert is_weak_contiguous(input)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user