mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 22:44:54 +08:00)
[AMD][Kernel][BugFix] fix test_rocm_compressed_tensors_w8a8 for rocm (#19509)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
parent 1b0b065eb5
commit 2e090bd5df
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import contextlib
-import importlib
 from typing import TYPE_CHECKING, Optional, Union

 import torch
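The module-level importlib import becomes unused after the change below, which replaces the importlib.import_module call inside cutlass_scaled_mm with a direct import, so it is removed here.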
@@ -706,10 +705,8 @@ def cutlass_scaled_mm(a: torch.Tensor,

     cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     if current_platform.is_rocm() or not cutlass_compatible_b:
-        triton_scaled_mm_module = importlib.import_module(
-            "vllm.model_executor.layers.quantization.compressed_tensors."
-            "triton_scaled_mm")
-        triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
+        from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa
+            triton_scaled_mm)
         return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
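The change swaps a string-based importlib.import_module lookup for a plain function-local import. Both defer loading the Triton module until the ROCm/fallback branch actually runs, but the direct import resolves the name at a single, statically visible site. A minimal sketch of the two patterns, using the standard-library math module as a stand-in for the Triton module so the sketch runs without vLLM:

import importlib


def fallback_old_style(x: float) -> float:
    # Old pattern: resolve the module from a dotted-path string at call
    # time, then pull the function off the returned module object.
    mod = importlib.import_module("math")  # stand-in for the Triton module
    return mod.sqrt(x)


def fallback_new_style(x: float) -> float:
    # New pattern: a function-local import. Loading is still deferred
    # until this branch executes, but the name is statically resolvable.
    from math import sqrt  # stand-in for triton_scaled_mm
    return sqrt(x)


assert fallback_old_style(4.0) == fallback_new_style(4.0) == 2.0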
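For context, a plain-PyTorch sketch of the computation this fallback path performs: a low-precision matmul accumulated in higher precision, with per-tensor or per-row scales applied to the result. The function name and the fp32 accumulation here are illustrative assumptions, not vLLM's actual kernel:

from typing import Optional

import torch


def scaled_mm_reference(a: torch.Tensor, b: torch.Tensor,
                        scale_a: torch.Tensor, scale_b: torch.Tensor,
                        out_dtype: torch.dtype,
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # a: [M, K] and b: [K, N], e.g. int8; scale_a is [1, 1] (per-tensor)
    # or [M, 1] (per-row); scale_b is [1, 1] or [N, 1] (per output column).
    acc = a.to(torch.float32) @ b.to(torch.float32)  # accumulate in fp32
    out = scale_a * acc * scale_b.t()                # broadcasts to [M, N]
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)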
@@ -144,10 +144,10 @@ def triton_scaled_mm(input: torch.Tensor,
     scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b

     assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
-    assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
-        [M, 1])
-    assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
-        [N, 1])
+    assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1
+                                      or scale_a.shape[0] == M)
+    assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1
+                                      or scale_b.shape[0] == N)
     assert out_dtype.is_floating_point
     assert bias is None or bias.is_floating_point()
     assert is_weak_contiguous(input)
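The rewritten assertions check the two dimensions element-wise instead of comparing against freshly constructed torch.Size objects, while accepting the same per-tensor ([1, 1]) and per-row/per-column ([M, 1] / [N, 1]) scale layouts. A small runnable sketch of the two checks side by side, with M chosen arbitrarily for illustration:

import torch

M = 4
for scale_a in (torch.ones(1, 1), torch.ones(M, 1)):
    # Old check: exact shape equality against torch.Size objects.
    old_ok = (scale_a.shape == torch.Size([1, 1])
              or scale_a.shape == torch.Size([M, 1]))
    # New check: element-wise conditions on the two dimensions.
    new_ok = scale_a.shape[1] == 1 and (scale_a.shape[0] == 1
                                        or scale_a.shape[0] == M)
    assert old_ok and new_ok
    print(tuple(scale_a.shape), "accepted by both checks")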