mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 03:57:02 +08:00
revert input scale upper bounds
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
7d361487f7
commit
1f65cd56e5
@ -48,6 +48,7 @@ FP8ParamsT = tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
torch.Tensor | None, # input_scale_ub,
|
||||
]
|
||||
Int8ParamsT = tuple[
|
||||
torch.Tensor, # weight
|
||||
@ -122,11 +123,12 @@ class FP8ScaledMMLinearKernel(
|
||||
pass
|
||||
|
||||
def _get_layer_params(self, layer) -> FP8ParamsT:
|
||||
w, w_s, x_s = self.layer_param_names
|
||||
w, w_s, x_s, x_s_ub = self.layer_param_names
|
||||
return (
|
||||
getattr(layer, w),
|
||||
getattr(layer, w_s),
|
||||
getattr(layer, x_s),
|
||||
getattr(layer, x_s_ub),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -180,7 +180,7 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
w, w_s, x_s = self._get_layer_params(layer)
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
cutlass_w8a8_scaled_mm_fp8,
|
||||
self.quant_fp8,
|
||||
@ -189,5 +189,6 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
x_s_ub,
|
||||
self.config.out_dtype,
|
||||
)
|
||||
|
||||
@ -77,7 +77,7 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
w, w_s, x_s = self._get_layer_params(layer)
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
flashinfer_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
@ -86,5 +86,6 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
x_s_ub,
|
||||
self.config.out_dtype,
|
||||
)
|
||||
|
||||
@ -162,7 +162,7 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
w, w_s, x_s = self._get_layer_params(layer)
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
torch_per_tensor_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
@ -171,6 +171,7 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
x_s_ub,
|
||||
self.config.out_dtype,
|
||||
)
|
||||
|
||||
@ -215,7 +216,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
w, w_s, x_s = self._get_layer_params(layer)
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
torch_row_wise_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
@ -224,6 +225,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
x_s_ub,
|
||||
self.config.out_dtype,
|
||||
)
|
||||
|
||||
@ -255,7 +257,7 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
w, w_s, x_s = self._get_layer_params(layer)
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
torch_channelwise_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
@ -264,5 +266,6 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
x_s_ub,
|
||||
self.config.out_dtype,
|
||||
)
|
||||
|
||||
@ -132,7 +132,7 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
w, w_s, x_s = self._get_layer_params(layer)
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
rocm_per_tensor_float_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
@ -141,5 +141,6 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
x_s_ub,
|
||||
self.config.out_dtype,
|
||||
)
|
||||
|
||||
@ -18,6 +18,7 @@ def apply_weights_fp8(
|
||||
w_s: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
x_s_ub: torch.Tensor | None,
|
||||
maybe_out_dtype: torch.dtype | None,
|
||||
) -> torch.Tensor:
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
@ -36,6 +37,7 @@ def apply_weights_fp8(
|
||||
x_2d_q, x_s = quant_fp8_func(
|
||||
x_2d,
|
||||
x_s,
|
||||
x_s_ub,
|
||||
)
|
||||
|
||||
return scaled_mm_func(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user