mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 09:42:14 +08:00
Fix output padding for torch._scaled_mm
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
edb6d43a37
commit
858765f59f
@ -88,7 +88,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
self.fp8_linear = init_fp8_linear_kernel(
|
self.fp8_linear = init_fp8_linear_kernel(
|
||||||
activation_quant_key=activation_quant_key,
|
activation_quant_key=activation_quant_key,
|
||||||
weight_quant_key=weight_quant_key,
|
weight_quant_key=weight_quant_key,
|
||||||
out_dtype=torch.get_default_dtype(),
|
out_dtype=self.out_dtype,
|
||||||
module_name=self.__class__.__name__,
|
module_name=self.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -44,7 +44,7 @@ def torch_per_tensor_w8a8_scaled_mm(
|
|||||||
if type(output) is tuple and len(output) == 2:
|
if type(output) is tuple and len(output) == 2:
|
||||||
output = output[0]
|
output = output[0]
|
||||||
|
|
||||||
return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
|
return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
def torch_row_wise_w8a8_scaled_mm(
|
def torch_row_wise_w8a8_scaled_mm(
|
||||||
@ -77,7 +77,7 @@ def torch_row_wise_w8a8_scaled_mm(
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
output = torch.narrow(output, 0, 0, A.shape[0])
|
output = torch.narrow(output, 0, 0, output_shape[0])
|
||||||
output = output.view(*output_shape)
|
output = output.view(*output_shape)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -121,8 +121,8 @@ def torch_channelwise_w8a8_scaled_mm(
|
|||||||
if type(output) is tuple and len(output) == 2:
|
if type(output) is tuple and len(output) == 2:
|
||||||
output = output[0]
|
output = output[0]
|
||||||
# Unpad (undo num_token_padding)
|
# Unpad (undo num_token_padding)
|
||||||
output = torch.narrow(output, 0, 0, A.shape[0])
|
output = torch.narrow(output, 0, 0, output_shape[0])
|
||||||
x_scale = torch.narrow(As, 0, 0, A.shape[0])
|
x_scale = torch.narrow(As, 0, 0, output_shape[0])
|
||||||
|
|
||||||
# DQ
|
# DQ
|
||||||
# C = sw * sx * (X * W) + bias
|
# C = sw * sx * (X * W) + bias
|
||||||
@ -142,6 +142,11 @@ class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def get_ouput_padding(self) -> int | None:
|
def get_ouput_padding(self) -> int | None:
|
||||||
|
# Note: we pad the input because torch._scaled_mm is more performant
|
||||||
|
# for matrices with batch dimension > 16.
|
||||||
|
# This could change in the future.
|
||||||
|
# We also don't pad when using torch.compile,
|
||||||
|
# as it breaks with dynamic shapes.
|
||||||
vllm_config = get_current_vllm_config().compilation_config
|
vllm_config = get_current_vllm_config().compilation_config
|
||||||
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
|
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
|
||||||
output_padding = 17 if pad_output else None
|
output_padding = 17 if pad_output else None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user