From f76e85c29984df2b0312efa5bfb80218689b9e17 Mon Sep 17 00:00:00 2001
From: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com>
Date: Wed, 12 Nov 2025 10:51:43 -0500
Subject: [PATCH] [Performance][Hopper] Avoid M dim padding to 4x for most
 cases (due to cuda graphs paddings) (#28492)

Signed-off-by: Alexander Matveev
---
 .../layers/quantization/utils/fp8_utils.py | 31 ++++++++++++-------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 0c54cf4def005..4384857f9270d 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -115,20 +115,27 @@ def _padded_cutlass(
         dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
     )
 
-    padded_shape = [padded, *qx.shape[1:]]
-    padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
-    padded_qx[0 : qx.shape[0], ...].copy_(qx)
+    has_pad = padded > dim
 
-    padded_x_scale_shape = [*x_scale.shape[1:], padded]
-    padded_x_scale = torch.ones(
-        padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
-    ).permute(-1, -2)
-    padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
+    if has_pad:
+        padded_shape = [padded, *qx.shape[1:]]
+        padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
+        padded_qx[0 : qx.shape[0], ...].copy_(qx)
 
-    output = cutlass_scaled_mm(
-        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
-    )
-    return output[0 : qx.shape[0], ...]
+        padded_x_scale_shape = [*x_scale.shape[1:], padded]
+        padded_x_scale = torch.ones(
+            padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
+        ).permute(-1, -2)
+        padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
+
+        output = cutlass_scaled_mm(
+            padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
+        )
+        return output[0 : qx.shape[0], ...]
+    else:
+        return cutlass_scaled_mm(
+            qx, weight, x_scale, weight_scale, block_size, output_dtype
+        )
 
 
 def _padded_cutlass_fake(
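
Note (not part of the patch): a minimal, standalone Python sketch of the round-up-to-multiple logic that the hunk builds on. The `round_up` helper and the example batch sizes below are illustrative only; the point is that when the M dimension is already a multiple of `pad_multiple` (as it typically is under CUDA graph batch paddings, per the subject line), `padded == dim`, so the new `has_pad` branch skips the extra zero-filled allocations and copies of `qx` and `x_scale` and calls the matmul on the original tensors.

```python
# Standalone sketch, not the patched vLLM code: shows the same rounding
# expression used in _padded_cutlass and which M sizes would take the new
# fast path. The batch sizes below are hypothetical examples.

def round_up(dim: int, pad_multiple: int = 4) -> int:
    """Round dim up to the next multiple of pad_multiple (same expression as the hunk)."""
    return dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)


if __name__ == "__main__":
    for dim in (1, 2, 3, 4, 8, 16, 250, 256):
        padded = round_up(dim)
        has_pad = padded > dim
        # has_pad == False -> the patched code avoids padding entirely.
        print(f"M={dim:4d} -> padded={padded:4d}, has_pad={has_pad}")
```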