From 858765f59f7485cf458084ad6a64a5e018f44d3a Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 10 Nov 2025 11:17:46 +0000 Subject: [PATCH] fix output padding for torch _scaled_mm Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 2 +- .../quantization/kernels/scaled_mm/pytorch.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 2cd29e0905d06..dc735a115e0b1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -88,7 +88,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): self.fp8_linear = init_fp8_linear_kernel( activation_quant_key=activation_quant_key, weight_quant_key=weight_quant_key, - out_dtype=torch.get_default_dtype(), + out_dtype=self.out_dtype, module_name=self.__class__.__name__, ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index 1736f145de02b..c272f579d8bcc 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -44,7 +44,7 @@ def torch_per_tensor_w8a8_scaled_mm( if type(output) is tuple and len(output) == 2: output = output[0] - return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape) + return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape) def torch_row_wise_w8a8_scaled_mm( @@ -77,7 +77,7 @@ def torch_row_wise_w8a8_scaled_mm( bias=bias, ) - output = torch.narrow(output, 0, 0, A.shape[0]) + output = torch.narrow(output, 0, 0, output_shape[0]) output = output.view(*output_shape) return output @@ -121,8 +121,8 @@ def torch_channelwise_w8a8_scaled_mm( if type(output) is tuple and len(output) == 2: output = output[0] # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, A.shape[0]) - x_scale = torch.narrow(As, 0, 0, A.shape[0]) + output = torch.narrow(output, 0, 0, output_shape[0]) + x_scale = torch.narrow(As, 0, 0, output_shape[0]) # DQ # C = sw * sx * (X * W) + bias @@ -142,6 +142,11 @@ class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel): """ def get_ouput_padding(self) -> int | None: + # Note: we pad the input because torch._scaled_mm is more performant + # for matrices with batch dimension > 16. + # This could change in the future. + # We also don't pad when using torch.compile, + # as it breaks with dynamic shapes. vllm_config = get_current_vllm_config().compilation_config pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE output_padding = 17 if pad_output else None