[Bugfix] Fix 3D input passed into cutlass_scaled_mm (#22278)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-08-05 22:35:20 -04:00 committed by GitHub
parent 35509fc5be
commit 6a51530437
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -710,23 +710,25 @@ def cutlass_scaled_mm(a: torch.Tensor,
scale_b.shape * [128, 128] == b.shape
"""
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.shape[0] == b.shape[
1] and bias.dtype == out_dtype
assert bias is None or bias.numel(
) == b.shape[1] and bias.dtype == out_dtype
m = a.shape[0]
n = b.shape[1]
# Massage the input to be 2D
target_shape = (*a.shape[:-1], b.shape[1])
a = a.view(-1, a.shape[-1])
cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
if current_platform.is_rocm() or not cutlass_compatible_b:
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa
triton_scaled_mm)
return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
else:
out = torch.empty((a.shape[0], b.shape[1]),
dtype=out_dtype,
device=a.device)
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
return out
return out.view(*target_shape)
def cutlass_scaled_mm_azp(a: torch.Tensor,
@@ -746,15 +748,18 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.numel(
) == b.shape[1] and bias.dtype == out_dtype
# Massage the input to be 2D
target_shape = (*a.shape[:-1], b.shape[1])
a = a.view(-1, a.shape[-1])
assert azp is None or azp.numel() == a.shape[0]
m = a.shape[0]
n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
out = torch.empty((a.shape[0], b.shape[1]),
dtype=out_dtype,
device=a.device)
torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
azp, bias)
return out
return out.view(*target_shape)
def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: