fix output padding for torch _scaled_mm

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>

parent edb6d43a37
commit 858765f59f
@@ -88,7 +88,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         self.fp8_linear = init_fp8_linear_kernel(
             activation_quant_key=activation_quant_key,
             weight_quant_key=weight_quant_key,
-            out_dtype=torch.get_default_dtype(),
+            out_dtype=self.out_dtype,
             module_name=self.__class__.__name__,
         )
@@ -44,7 +44,7 @@ def torch_per_tensor_w8a8_scaled_mm(
     if type(output) is tuple and len(output) == 2:
         output = output[0]
 
-    return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
+    return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)
 
 
 def torch_row_wise_w8a8_scaled_mm(
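This hunk and the row-wise one below swap the unpad length from A.shape[0] to output_shape[0]: when the activation has been padded before the matmul (presumably by the quantization step's num_token_padding), A.shape[0] is the padded row count, while output_shape still carries the caller's original token count. A minimal sketch of the narrow-and-view step, with made-up shapes (5 real rows padded to 17):

import torch

# Hypothetical shapes: 5 real tokens padded to 17 rows before the matmul.
output_shape = [5, 64]            # original, unpadded result shape
output = torch.randn(17, 64)      # padded matmul output (A.shape[0] == 17)
unpadded = torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)
assert unpadded.shape == (5, 64)  # padding rows dropped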
@@ -77,7 +77,7 @@ def torch_row_wise_w8a8_scaled_mm(
         bias=bias,
     )
 
-    output = torch.narrow(output, 0, 0, A.shape[0])
+    output = torch.narrow(output, 0, 0, output_shape[0])
     output = output.view(*output_shape)
     return output
@@ -121,8 +121,8 @@ def torch_channelwise_w8a8_scaled_mm(
     if type(output) is tuple and len(output) == 2:
         output = output[0]
     # Unpad (undo num_token_padding)
-    output = torch.narrow(output, 0, 0, A.shape[0])
-    x_scale = torch.narrow(As, 0, 0, A.shape[0])
+    output = torch.narrow(output, 0, 0, output_shape[0])
+    x_scale = torch.narrow(As, 0, 0, output_shape[0])
 
     # DQ
     # C = sw * sx * (X * W) + bias
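After the unpad, the channelwise path dequantizes as noted in the comment at the end of this hunk, C = sw * sx * (X * W) + bias. A minimal numeric sketch of that broadcast, with hypothetical sizes (sx per token, sw per output channel):

import torch

# Hypothetical sizes: 4 tokens, 8 output channels.
raw = torch.randn(4, 8)   # X @ W accumulated in higher precision
sx = torch.rand(4, 1)     # per-token activation scales
sw = torch.rand(1, 8)     # per-channel weight scales
bias = torch.zeros(8)
C = sw * sx * raw + bias  # broadcasts to (4, 8)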
@@ -142,6 +142,11 @@ class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel):
     """
 
+    def get_ouput_padding(self) -> int | None:
+        # Note: we pad the input because torch._scaled_mm is more performant
+        # for matrices with batch dimension > 16.
+        # This could change in the future.
+        # We also don't pad when using torch.compile,
+        # as it breaks with dynamic shapes.
+        vllm_config = get_current_vllm_config().compilation_config
+        pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
+        output_padding = 17 if pad_output else None
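The comments in the new method give the rationale for the constant 17: torch._scaled_mm tends to perform better once the batch (token) dimension exceeds 16, so eager-mode runs pad the input up to 17 rows, while VLLM_COMPILE runs skip padding because it breaks dynamic shapes. A minimal sketch, not the vLLM implementation, of how such a threshold could be applied to an activation before the matmul; maybe_pad_rows is a hypothetical helper:

import torch
import torch.nn.functional as F

def maybe_pad_rows(a: torch.Tensor, output_padding: int | None) -> tuple[torch.Tensor, int]:
    # Return the (possibly padded) activation and its original row count,
    # so the matmul result can be narrowed back afterwards.
    n_rows = a.shape[0]
    if output_padding is None or n_rows >= output_padding:
        return a, n_rows
    # F.pad lists pads for the last dim first: (left, right, top, bottom) for 2-D.
    return F.pad(a, (0, 0, 0, output_padding - n_rows)), n_rows

The padded result would then be trimmed back with torch.narrow(output, 0, 0, n_rows), which is exactly the unpad step the earlier hunks correct.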