diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index e6f69e2344efa..92de39418054b 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -710,23 +710,25 @@ def cutlass_scaled_mm(a: torch.Tensor,
         scale_b.shape * [128, 128] == b.shape
     """
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
-    assert bias is None or bias.shape[0] == b.shape[
-        1] and bias.dtype == out_dtype
+    assert bias is None or bias.numel(
+    ) == b.shape[1] and bias.dtype == out_dtype
 
-    m = a.shape[0]
-    n = b.shape[1]
+    # Massage the input to be 2D
+    target_shape = (*a.shape[:-1], b.shape[1])
+    a = a.view(-1, a.shape[-1])
 
     cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     if current_platform.is_rocm() or not cutlass_compatible_b:
         from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa
             triton_scaled_mm)
-        return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+        out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    else:
+        out = torch.empty((a.shape[0], b.shape[1]),
+                          dtype=out_dtype,
+                          device=a.device)
+        torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
 
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-
-    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
-
-    return out
+    return out.view(*target_shape)
 
 
 def cutlass_scaled_mm_azp(a: torch.Tensor,
@@ -746,15 +748,18 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     assert bias is None or bias.numel(
     ) == b.shape[1] and bias.dtype == out_dtype
+
+    # Massage the input to be 2D
+    target_shape = (*a.shape[:-1], b.shape[1])
+    a = a.view(-1, a.shape[-1])
     assert azp is None or azp.numel() == a.shape[0]
 
-    m = a.shape[0]
-    n = b.shape[1]
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-
+    out = torch.empty((a.shape[0], b.shape[1]),
+                      dtype=out_dtype,
+                      device=a.device)
     torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b,
                                        azp_adj, azp, bias)
-    return out
+    return out.view(*target_shape)
 
 
 def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: