revert input scale upper bounds

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-10-31 15:06:51 +00:00
parent 7d361487f7
commit 1f65cd56e5
6 changed files with 17 additions and 7 deletions

View File

@@ -48,6 +48,7 @@ FP8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_scale_ub,
]
Int8ParamsT = tuple[
torch.Tensor, # weight
@@ -122,11 +123,12 @@ class FP8ScaledMMLinearKernel(
pass
def _get_layer_params(self, layer) -> FP8ParamsT:
w, w_s, x_s = self.layer_param_names
w, w_s, x_s, x_s_ub = self.layer_param_names
return (
getattr(layer, w),
getattr(layer, w_s),
getattr(layer, x_s),
getattr(layer, x_s_ub),
)

View File

@@ -180,7 +180,7 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
w, w_s, x_s = self._get_layer_params(layer)
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
return apply_weights_fp8(
cutlass_w8a8_scaled_mm_fp8,
self.quant_fp8,
@@ -189,5 +189,6 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
w_s,
x_s,
bias,
x_s_ub,
self.config.out_dtype,
)

View File

@@ -77,7 +77,7 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
w, w_s, x_s = self._get_layer_params(layer)
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
return apply_weights_fp8(
flashinfer_w8a8_scaled_mm,
self.quant_fp8,
@@ -86,5 +86,6 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
w_s,
x_s,
bias,
x_s_ub,
self.config.out_dtype,
)

View File

@@ -162,7 +162,7 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
w, w_s, x_s = self._get_layer_params(layer)
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
return apply_weights_fp8(
torch_per_tensor_w8a8_scaled_mm,
self.quant_fp8,
@@ -171,6 +171,7 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
w_s,
x_s,
bias,
x_s_ub,
self.config.out_dtype,
)
@@ -215,7 +216,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
w, w_s, x_s = self._get_layer_params(layer)
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
return apply_weights_fp8(
torch_row_wise_w8a8_scaled_mm,
self.quant_fp8,
@@ -224,6 +225,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
w_s,
x_s,
bias,
x_s_ub,
self.config.out_dtype,
)
@@ -255,7 +257,7 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
w, w_s, x_s = self._get_layer_params(layer)
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
return apply_weights_fp8(
torch_channelwise_w8a8_scaled_mm,
self.quant_fp8,
@@ -264,5 +266,6 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
w_s,
x_s,
bias,
x_s_ub,
self.config.out_dtype,
)

View File

@@ -132,7 +132,7 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
w, w_s, x_s = self._get_layer_params(layer)
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
return apply_weights_fp8(
rocm_per_tensor_float_w8a8_scaled_mm,
self.quant_fp8,
@@ -141,5 +141,6 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
w_s,
x_s,
bias,
x_s_ub,
self.config.out_dtype,
)

View File

@@ -18,6 +18,7 @@ def apply_weights_fp8(
w_s: torch.Tensor,
x_s: torch.Tensor,
bias: torch.Tensor,
x_s_ub: torch.Tensor | None,
maybe_out_dtype: torch.dtype | None,
) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
@@ -36,6 +37,7 @@ def apply_weights_fp8(
x_2d_q, x_s = quant_fp8_func(
x_2d,
x_s,
x_s_ub,
)
return scaled_mm_func(