fix output padding for torch _scaled_mm

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-11-10 11:17:46 +00:00
parent edb6d43a37
commit 858765f59f
2 changed files with 10 additions and 5 deletions

View File

@ -88,7 +88,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=weight_quant_key,
out_dtype=torch.get_default_dtype(),
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)

View File

@ -44,7 +44,7 @@ def torch_per_tensor_w8a8_scaled_mm(
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)
def torch_row_wise_w8a8_scaled_mm(
@ -77,7 +77,7 @@ def torch_row_wise_w8a8_scaled_mm(
bias=bias,
)
output = torch.narrow(output, 0, 0, A.shape[0])
output = torch.narrow(output, 0, 0, output_shape[0])
output = output.view(*output_shape)
return output
@ -121,8 +121,8 @@ def torch_channelwise_w8a8_scaled_mm(
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, A.shape[0])
x_scale = torch.narrow(As, 0, 0, A.shape[0])
output = torch.narrow(output, 0, 0, output_shape[0])
x_scale = torch.narrow(As, 0, 0, output_shape[0])
# DQ
# C = sw * sx * (X * W) + bias
@ -142,6 +142,11 @@ class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel):
"""
def get_ouput_padding(self) -> int | None:
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
# We also don't pad when using torch.compile,
# as it breaks with dynamic shapes.
vllm_config = get_current_vllm_config().compilation_config
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
output_padding = 17 if pad_output else None