mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Performance][Hopper] Avoid M dim padding to 4x for most cases (due to cuda graphs paddings) (#28492)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
parent 54aecd9ed5
commit f76e85c299
@@ -115,20 +115,27 @@ def _padded_cutlass(
         dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
     )
 
-    padded_shape = [padded, *qx.shape[1:]]
-    padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
-    padded_qx[0 : qx.shape[0], ...].copy_(qx)
+    has_pad = padded > dim
 
-    padded_x_scale_shape = [*x_scale.shape[1:], padded]
-    padded_x_scale = torch.ones(
-        padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
-    ).permute(-1, -2)
-    padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
+    if has_pad:
+        padded_shape = [padded, *qx.shape[1:]]
+        padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
+        padded_qx[0 : qx.shape[0], ...].copy_(qx)
 
-    output = cutlass_scaled_mm(
-        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
-    )
-    return output[0 : qx.shape[0], ...]
+        padded_x_scale_shape = [*x_scale.shape[1:], padded]
+        padded_x_scale = torch.ones(
+            padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
+        ).permute(-1, -2)
+        padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
+
+        output = cutlass_scaled_mm(
+            padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
+        )
+        return output[0 : qx.shape[0], ...]
+    else:
+        return cutlass_scaled_mm(
+            qx, weight, x_scale, weight_scale, block_size, output_dtype
+        )
 
 
 def _padded_cutlass_fake(
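For context outside the diff: the change computes has_pad and skips the padded-buffer path entirely when the M dimension is already a multiple of pad_multiple, which per the commit title is the common case because CUDA graph paddings already produce aligned batch sizes. Below is a minimal, self-contained sketch of that round-up arithmetic and fast path; the helper names round_up and maybe_pad_rows are illustrative only (not vLLM APIs), and the cutlass_scaled_mm call is omitted since it is a vLLM custom op.

    # Illustrative sketch only -- round_up and maybe_pad_rows are hypothetical
    # helpers, not vLLM code.
    import torch

    def round_up(dim: int, pad_multiple: int) -> int:
        # Same arithmetic as the diff: round dim up to the next multiple.
        return dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)

    def maybe_pad_rows(qx: torch.Tensor, pad_multiple: int = 4) -> torch.Tensor:
        padded = round_up(qx.shape[0], pad_multiple)
        if padded == qx.shape[0]:  # equivalent to "not has_pad" in the diff
            return qx  # fast path: no allocation, no copy
        # Slow path: zero-fill a padded buffer and copy the real rows in.
        padded_qx = torch.zeros((padded, *qx.shape[1:]), device=qx.device, dtype=qx.dtype)
        padded_qx[: qx.shape[0]].copy_(qx)
        return padded_qx

    # CUDA graph capture sizes are aligned, so most calls hit the fast path:
    x = torch.randn(16, 128)
    assert maybe_pad_rows(x) is x                               # 16 % 4 == 0: untouched
    assert maybe_pad_rows(torch.randn(13, 128)).shape[0] == 16  # 13 rounds up to 16

Per the diff, the aligned case now avoids two zero-initialized allocations, two copies, and the output slice on every call, which is where the performance win comes from.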