[Performance][Hopper] Avoid M dim padding to 4x for most cases (due to cuda graphs paddings) (#28492)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Author: Alexander Matveev
Date: 2025-11-12 10:51:43 -05:00 (committed by GitHub)
parent 54aecd9ed5
commit f76e85c299


@@ -115,20 +115,27 @@ def _padded_cutlass(
     padded = (
         dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
     )
-    padded_shape = [padded, *qx.shape[1:]]
-    padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
-    padded_qx[0 : qx.shape[0], ...].copy_(qx)
+    has_pad = padded > dim
 
-    padded_x_scale_shape = [*x_scale.shape[1:], padded]
-    padded_x_scale = torch.ones(
-        padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
-    ).permute(-1, -2)
-    padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
-
-    output = cutlass_scaled_mm(
-        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
-    )
-    return output[0 : qx.shape[0], ...]
+    if has_pad:
+        padded_shape = [padded, *qx.shape[1:]]
+        padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
+        padded_qx[0 : qx.shape[0], ...].copy_(qx)
+
+        padded_x_scale_shape = [*x_scale.shape[1:], padded]
+        padded_x_scale = torch.ones(
+            padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
+        ).permute(-1, -2)
+        padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
+
+        output = cutlass_scaled_mm(
+            padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
+        )
+        return output[0 : qx.shape[0], ...]
+    else:
+        return cutlass_scaled_mm(
+            qx, weight, x_scale, weight_scale, block_size, output_dtype
+        )
 
 
 def _padded_cutlass_fake(
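
In outline, the change gates the zero-fill and copies behind a has_pad check: when M is already a multiple of the pad multiple, cutlass_scaled_mm now runs directly on the unpadded tensors. Since CUDA graph capture already pads the batch (M) dimension to a fixed set of sizes, most of which are multiples of 4, the padding path is rarely taken in practice. Below is a minimal standalone sketch (not part of the commit) of the two pieces of logic above: the round-up expression that decides has_pad, and the transposed allocation that keeps the padded scales column-major like the original x_scale. pad_multiple = 4 is assumed from the commit title, and all shapes and values are illustrative.

import torch

PAD_MULTIPLE = 4  # assumed from the commit title ("padding to 4x")

def round_up(dim: int, pad_multiple: int = PAD_MULTIPLE) -> int:
    # Same expression as in the diff: round dim up to the next multiple.
    return dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)

for dim in (1, 3, 4, 7, 8):
    padded = round_up(dim)
    print(dim, "->", padded, "has_pad =", padded > dim)
# 1 -> 4 True, 3 -> 4 True, 4 -> 4 False, 7 -> 8 True, 8 -> 8 False

# The padded scale tensor is allocated transposed and then permuted back, so
# its storage stays column-major (stride 1 along M), matching a column-major
# x_scale. Illustrative shapes: m tokens, k_groups scale groups per token.
m, k_groups = 7, 2
padded_m = round_up(m)
x_scale = torch.rand(k_groups, m).permute(-1, -2)                # [m, k_groups], stride (1, m)
padded_x_scale = torch.ones(k_groups, padded_m).permute(-1, -2)  # [padded_m, k_groups]
padded_x_scale[0:m, ...].copy_(x_scale)
print(padded_x_scale.stride())  # (1, 8): column-major, same layout as x_scale

Note that the padding rows of qx are zero-filled while the padding rows of the scales are filled with ones, presumably so the padded rows stay numerically benign before being sliced off the output.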