[Performance][Hopper] Avoid M dim padding to 4x in most cases (due to CUDA graphs padding) (#28492)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
Alexander Matveev 2025-11-12 10:51:43 -05:00 committed by GitHub
parent 54aecd9ed5
commit f76e85c299
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -115,20 +115,27 @@ def _padded_cutlass(
dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple) dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
) )
padded_shape = [padded, *qx.shape[1:]] has_pad = padded > dim
padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
padded_qx[0 : qx.shape[0], ...].copy_(qx)
padded_x_scale_shape = [*x_scale.shape[1:], padded] if has_pad:
padded_x_scale = torch.ones( padded_shape = [padded, *qx.shape[1:]]
padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
).permute(-1, -2) padded_qx[0 : qx.shape[0], ...].copy_(qx)
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
output = cutlass_scaled_mm( padded_x_scale_shape = [*x_scale.shape[1:], padded]
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype padded_x_scale = torch.ones(
) padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
return output[0 : qx.shape[0], ...] ).permute(-1, -2)
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
output = cutlass_scaled_mm(
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
)
return output[0 : qx.shape[0], ...]
else:
return cutlass_scaled_mm(
qx, weight, x_scale, weight_scale, block_size, output_dtype
)
def _padded_cutlass_fake( def _padded_cutlass_fake(