[Bugfix] Fix 3D input passed into cutlass_scaled_mm (#22278)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent 35509fc5be
commit 6a51530437
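Before this change, both wrappers sized their output from `a.shape[0]`, which silently assumes a 2D activation; a 3D input such as `(batch, seq, hidden)` would produce a wrongly shaped result. As a point of reference, the op computes a fused scaled matmul; here is a pure-PyTorch sketch of the intended semantics, assuming per-tensor scalar scales and illustrative shapes (this is not the kernel's actual signature):

import torch

def ref_scaled_mm(a: torch.Tensor, b: torch.Tensor,
                  scale_a: float, scale_b: float,
                  out_dtype: torch.dtype) -> torch.Tensor:
    # Reference semantics only: dequantize, matmul, cast down.
    return ((scale_a * a.float()) @ (scale_b * b.float())).to(out_dtype)

a = torch.randint(-128, 127, (2, 8, 64), dtype=torch.int8)  # 3D activations
b = torch.randint(-128, 127, (64, 32), dtype=torch.int8)    # (K, N) weight
out = ref_scaled_mm(a, b, 0.1, 0.2, torch.float16)
assert out.shape == (2, 8, 32)  # leading dims must survive the matmul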
@@ -710,23 +710,25 @@ def cutlass_scaled_mm(a: torch.Tensor,
         scale_b.shape * [128, 128] == b.shape
     """
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
-    assert bias is None or bias.shape[0] == b.shape[
-        1] and bias.dtype == out_dtype
+    assert bias is None or bias.numel(
+    ) == b.shape[1] and bias.dtype == out_dtype
 
-    m = a.shape[0]
-    n = b.shape[1]
+    # Massage the input to be 2D
+    target_shape = (*a.shape[:-1], b.shape[1])
+    a = a.view(-1, a.shape[-1])
 
     cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     if current_platform.is_rocm() or not cutlass_compatible_b:
         from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa
             triton_scaled_mm)
-        return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+        out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    else:
+        out = torch.empty((a.shape[0], b.shape[1]),
+                          dtype=out_dtype,
+                          device=a.device)
+        torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
 
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-
-    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
-
-    return out
+    return out.view(*target_shape)
 
 
 def cutlass_scaled_mm_azp(a: torch.Tensor,
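The fix is the standard flatten-then-restore pattern: collapse all leading dimensions of `a` into one row dimension, run the 2D kernel, then view the result back. A minimal standalone sketch in plain PyTorch, with `torch.mm` standing in for `torch.ops._C.cutlass_scaled_mm`:

import torch

def scaled_mm_nd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Record the output shape: every leading dim of `a` plus N from `b`.
    target_shape = (*a.shape[:-1], b.shape[1])
    # Collapse all leading dims into one row dimension -> (M, K).
    a = a.view(-1, a.shape[-1])
    # 2D kernel; torch.mm stands in for torch.ops._C.cutlass_scaled_mm.
    out = torch.mm(a, b)
    # Restore the caller's leading dimensions.
    return out.view(*target_shape)

a = torch.randn(2, 8, 64)  # (batch, seq, K): the 3D case this commit fixes
b = torch.randn(64, 32)    # (K, N)
assert scaled_mm_nd(a, b).shape == (2, 8, 32)

Note that `target_shape` must be captured before `a` is rebound to its 2D view, which is why the diff computes it first.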
@@ -746,15 +748,18 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     assert bias is None or bias.numel(
     ) == b.shape[1] and bias.dtype == out_dtype
 
+    # Massage the input to be 2D
+    target_shape = (*a.shape[:-1], b.shape[1])
+    a = a.view(-1, a.shape[-1])
     assert azp is None or azp.numel() == a.shape[0]
 
-    m = a.shape[0]
-    n = b.shape[1]
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-
+    out = torch.empty((a.shape[0], b.shape[1]),
+                      dtype=out_dtype,
+                      device=a.device)
     torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
                                        azp, bias)
-    return out
+    return out.view(*target_shape)
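In the azp variant, the added flatten lands just above the existing per-token assertion, so `azp.numel() == a.shape[0]` is now checked against the 2D view: `azp` carries one zero-point per flattened token row, which for a 3D input is the product of the leading dimensions. A small illustration with assumed shapes:

import torch

a = torch.randn(2, 8, 64)        # (batch, seq, K)
a2d = a.view(-1, a.shape[-1])    # (batch * seq, K) == (16, 64)

# One zero-point per flattened token row, matching the post-flatten
# check `azp.numel() == a.shape[0]` in the diff above.
azp = torch.zeros(a2d.shape[0], dtype=torch.int32)
assert azp.numel() == a2d.shape[0] == 2 * 8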