revert input scale upper bounds

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-10-31 15:06:51 +00:00
parent 7d361487f7
commit 1f65cd56e5
6 changed files with 17 additions and 7 deletions

View File

@@ -48,6 +48,7 @@ FP8ParamsT = tuple[
     torch.Tensor,  # weight
     torch.Tensor,  # weight_scale
     torch.Tensor | None,  # input_scale,
+    torch.Tensor | None,  # input_scale_ub,
 ]
 Int8ParamsT = tuple[
     torch.Tensor,  # weight
@@ -122,11 +123,12 @@ class FP8ScaledMMLinearKernel(
         pass

     def _get_layer_params(self, layer) -> FP8ParamsT:
-        w, w_s, x_s = self.layer_param_names
+        w, w_s, x_s, x_s_ub = self.layer_param_names
         return (
             getattr(layer, w),
             getattr(layer, w_s),
             getattr(layer, x_s),
+            getattr(layer, x_s_ub),
         )

View File

@@ -180,7 +180,7 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ):
-        w, w_s, x_s = self._get_layer_params(layer)
+        w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
         return apply_weights_fp8(
             cutlass_w8a8_scaled_mm_fp8,
             self.quant_fp8,
@@ -189,5 +189,6 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
             w_s,
             x_s,
             bias,
+            x_s_ub,
             self.config.out_dtype,
         )

View File

@@ -77,7 +77,7 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ):
-        w, w_s, x_s = self._get_layer_params(layer)
+        w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
         return apply_weights_fp8(
             flashinfer_w8a8_scaled_mm,
             self.quant_fp8,
@@ -86,5 +86,6 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
             w_s,
             x_s,
             bias,
+            x_s_ub,
             self.config.out_dtype,
         )

View File

@@ -162,7 +162,7 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ):
-        w, w_s, x_s = self._get_layer_params(layer)
+        w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
         return apply_weights_fp8(
             torch_per_tensor_w8a8_scaled_mm,
             self.quant_fp8,
@@ -171,6 +171,7 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
             w_s,
             x_s,
             bias,
+            x_s_ub,
             self.config.out_dtype,
         )
@@ -215,7 +216,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ):
-        w, w_s, x_s = self._get_layer_params(layer)
+        w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
         return apply_weights_fp8(
             torch_row_wise_w8a8_scaled_mm,
             self.quant_fp8,
@@ -224,6 +225,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
             w_s,
             x_s,
             bias,
+            x_s_ub,
             self.config.out_dtype,
         )
@@ -255,7 +257,7 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ):
-        w, w_s, x_s = self._get_layer_params(layer)
+        w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
         return apply_weights_fp8(
             torch_channelwise_w8a8_scaled_mm,
             self.quant_fp8,
@@ -264,5 +266,6 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
             w_s,
             x_s,
             bias,
+            x_s_ub,
             self.config.out_dtype,
         )

View File

@@ -132,7 +132,7 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ):
-        w, w_s, x_s = self._get_layer_params(layer)
+        w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
         return apply_weights_fp8(
             rocm_per_tensor_float_w8a8_scaled_mm,
             self.quant_fp8,
@@ -141,5 +141,6 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
             w_s,
             x_s,
             bias,
+            x_s_ub,
             self.config.out_dtype,
         )

View File

@@ -18,6 +18,7 @@ def apply_weights_fp8(
     w_s: torch.Tensor,
     x_s: torch.Tensor,
     bias: torch.Tensor,
+    x_s_ub: torch.Tensor | None,
     maybe_out_dtype: torch.dtype | None,
 ) -> torch.Tensor:
     # ops.scaled_fp8_quant supports both dynamic and static quant.
@@ -36,6 +37,7 @@ def apply_weights_fp8(
     x_2d_q, x_s = quant_fp8_func(
         x_2d,
         x_s,
+        x_s_ub,
     )
     return scaled_mm_func(